diff --git a/Add_Metadata.ipynb b/Add_Metadata.ipynb new file mode 100644 index 0000000..e74743c --- /dev/null +++ b/Add_Metadata.ipynb @@ -0,0 +1,329 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Add_Metadata", + "provenance": [], + "include_colab_link": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "laCUI2hzapKv" + }, + "source": [ + "- Reference: https://github.com/margaretmz/selfie2anime-with-tflite/blob/master/ml/add-meta-data-Colab/Add%20metadata%20to%20selfie2anime.ipynb. \n", + "- TensorFlow Lite meatdata: https://www.tensorflow.org/lite/convert/metadata.\n", + "- Authored by: Sayak.\n", + "- Updated on: December 05 2020." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "lIYdn1woOS1n", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "bea2e589-410d-4957-f388-3b37ad7733d8" + }, + "source": [ + "!pip install -q tflite-support" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "\u001b[K |████████████████████████████████| 1.0MB 5.6MB/s \n", + "\u001b[K |████████████████████████████████| 194kB 12.5MB/s \n", + "\u001b[?25h" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "hSBlosV-Wim7", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "d73a0356-b4f9-44b1-a310-610aa81a4e0d" + }, + "source": [ + "import os\n", + "import tensorflow as tf\n", + "from absl import flags\n", + "\n", + "from tflite_support import flatbuffers\n", + "from tflite_support import metadata as _metadata\n", + "from tflite_support import metadata_schema_py_generated as _metadata_fb\n", + "\n", + "print(tf.__version__)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "2.3.0\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "v_czMsARW1SW" + }, + "source": [ + "!mkdir model_without_metadata\n", + "!mkdir model_with_metadata" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "8_Qa9lo4W6zV" + }, + "source": [ + "!wget -q https://github.com/sayakpaul/MIRNet-TFLite/releases/download/v0.1.0/fixed_shape.zip\n", + "!unzip -q fixed_shape.zip\n", + "\n", + "!mv fixed_shape/*.tflite model_without_metadata/" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "InH13DM9cdqP" + }, + "source": [ + "# This is where we will export a new .tflite model file with metadata, and a .json file with metadata info\n", + "EXPORT_DIR = \"model_with_metadata\"" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "E4ta4ByLYST3" + }, + "source": [ + "class MetadataPopulatorForMIRNet(object):\n", + " \"\"\"Populates the metadata for the MIRNet model.\"\"\"\n", + "\n", + " def __init__(self, model_file):\n", + " self.model_file = model_file\n", + " self.metadata_buf = None\n", + "\n", + " def populate(self):\n", + " \"\"\"Creates metadata and then populates it for a low-light image enhancement model.\"\"\"\n", + " self._create_metadata()\n", + " self._populate_metadata()\n", + " \n", + " def _create_metadata(self):\n", + " \"\"\"Creates the metadata for the CartoonGAN model.\"\"\"\n", + "\n", + " # Creates model info.\n", + " model_meta = _metadata_fb.ModelMetadataT()\n", + " model_meta.name = \"MIRNet\" \n", + " model_meta.description = (\"Enhances a low-light image. Reference: https://arxiv.org/pdf/2003.06792v2.pdf. TFLiteConverter used from TensorFlow 2.3.0.\")\n", + " model_meta.version = \"v1\"\n", + " model_meta.author = \"Sayak\"\n", + " model_meta.license = (\"Apache License. Version 2.0 \"\n", + " \"http://www.apache.org/licenses/LICENSE-2.0.\")\n", + "\n", + " # Creates info for the input, normal image.\n", + " input_image_meta = _metadata_fb.TensorMetadataT()\n", + " input_image_meta.name = \"low_light_image\"\n", + " # if self.model_type==\"other\":\n", + " input_image_meta.description = (\n", + " \"The expected image is 400 x 400, with three channels \"\n", + " \"(red, blue, and green) per pixel. Each value in the tensor is between\"\n", + " \" 0 and 1.\")\n", + " input_image_meta.content = _metadata_fb.ContentT()\n", + " input_image_meta.content.contentProperties = (\n", + " _metadata_fb.ImagePropertiesT())\n", + " input_image_meta.content.contentProperties.colorSpace = (\n", + " _metadata_fb.ColorSpaceType.RGB)\n", + " input_image_meta.content.contentPropertiesType = (\n", + " _metadata_fb.ContentProperties.ImageProperties)\n", + " input_image_normalization = _metadata_fb.ProcessUnitT()\n", + " input_image_normalization.optionsType = (\n", + " _metadata_fb.ProcessUnitOptions.NormalizationOptions)\n", + " input_image_normalization.options = _metadata_fb.NormalizationOptionsT()\n", + " input_image_normalization.options.mean = [0.0]\n", + " input_image_normalization.options.std = [255.0]\n", + " input_image_meta.processUnits = [input_image_normalization]\n", + " input_image_stats = _metadata_fb.StatsT()\n", + " input_image_stats.max = [1.0]\n", + " input_image_stats.min = [0.0]\n", + " input_image_meta.stats = input_image_stats\n", + "\n", + "\n", + " # Creates output info, cartoonized image\n", + " output_image_meta = _metadata_fb.TensorMetadataT()\n", + " output_image_meta.name = \"enhanced_image\"\n", + " output_image_meta.description = \"Image enhanced.\"\n", + " output_image_meta.content = _metadata_fb.ContentT()\n", + " output_image_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()\n", + " output_image_meta.content.contentProperties.colorSpace = (\n", + " _metadata_fb.ColorSpaceType.RGB)\n", + " output_image_meta.content.contentPropertiesType = (\n", + " _metadata_fb.ContentProperties.ImageProperties)\n", + " output_image_normalization = _metadata_fb.ProcessUnitT()\n", + " output_image_normalization.optionsType = (\n", + " _metadata_fb.ProcessUnitOptions.NormalizationOptions)\n", + " output_image_normalization.options = _metadata_fb.NormalizationOptionsT()\n", + " output_image_normalization.options.mean = [0.0]\n", + " output_image_normalization.options.std = [1.0] \n", + " output_image_meta.processUnits = [output_image_normalization]\n", + " output_image_stats = _metadata_fb.StatsT()\n", + " output_image_stats.max = [255.0]\n", + " output_image_stats.min = [0.0]\n", + " output_image_meta.stats = output_image_stats\n", + "\n", + " # Creates subgraph info.\n", + " subgraph = _metadata_fb.SubGraphMetadataT()\n", + " subgraph.inputTensorMetadata = [input_image_meta] \n", + " subgraph.outputTensorMetadata = [output_image_meta] \n", + " model_meta.subgraphMetadata = [subgraph]\n", + "\n", + " b = flatbuffers.Builder(0)\n", + " b.Finish(\n", + " model_meta.Pack(b),\n", + " _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)\n", + " self.metadata_buf = b.Output()\n", + "\n", + " def _populate_metadata(self):\n", + " \"\"\"Populates metadata to the model file.\"\"\"\n", + " populator = _metadata.MetadataPopulator.with_model_file(self.model_file)\n", + " populator.load_metadata_buffer(self.metadata_buf)\n", + " populator.populate()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "vRlUqyFMcNyy" + }, + "source": [ + "def populate_metadata(model_file):\n", + " \"\"\"Populates the metadata using the populator specified.\n", + " Args:\n", + " model_file: valid path to the model file.\n", + " \"\"\"\n", + "\n", + " # Populates metadata for the model.\n", + " model_file_basename = os.path.basename(model_file)\n", + " export_path = os.path.join(EXPORT_DIR, model_file_basename)\n", + " tf.io.gfile.copy(model_file, export_path, overwrite=True)\n", + "\n", + " populator = MetadataPopulatorForMIRNet(export_path) \n", + " populator.populate()\n", + "\n", + " # Displays the metadata that was just populated into the tflite model.\n", + " displayer = _metadata.MetadataDisplayer.with_model_file(export_path)\n", + " export_json_file = os.path.join(\n", + " EXPORT_DIR,\n", + " os.path.splitext(model_file_basename)[0] + \".json\")\n", + " json_file = displayer.get_metadata_json()\n", + " with open(export_json_file, \"w\") as f:\n", + " f.write(json_file)\n", + " print(\"Finished populating metadata and associated file to the model:\")\n", + " print(export_path)\n", + " print(\"The metadata json file has been saved to:\")\n", + " print(os.path.join(EXPORT_DIR,\n", + " os.path.splitext(model_file_basename)[0] + \".json\"))" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "YVSJFR4HcUuI", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "a2cbe140-abf7-4541-e663-329b156d50c5" + }, + "source": [ + "quantization = \"int8\" #@param [\"dr\", \"int8\", \"fp16\"]\n", + "tflite_model_path = f\"mirnet_{quantization}.tflite\" \n", + "MODEL_FILE = \"/content/model_without_metadata/{}\".format(tflite_model_path)\n", + "populate_metadata(MODEL_FILE)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Finished populating metadata and associated file to the model:\n", + "model_with_metadata/mirnet_int8.tflite\n", + "The metadata json file has been saved to:\n", + "model_with_metadata/mirnet_int8.json\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Bcn0-gbW3ep9", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "8d3e9e86-f5ce-41b1-c6b1-5e55baf2fedb" + }, + "source": [ + "!tar cvf model_with_metadata.tar.gz model_with_metadata" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "model_with_metadata/\n", + "model_with_metadata/mirnet_int8.tflite\n", + "model_with_metadata/mirnet_int8.json\n", + "model_with_metadata/mirnet_fp16.json\n", + "model_with_metadata/mirnet_fp16.tflite\n", + "model_with_metadata/mirnet_dr.json\n", + "model_with_metadata/mirnet_dr.tflite\n" + ], + "name": "stdout" + } + ] + } + ] +} \ No newline at end of file