-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
329 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,329 @@ | ||
{ | ||
"nbformat": 4, | ||
"nbformat_minor": 0, | ||
"metadata": { | ||
"colab": { | ||
"name": "Add_Metadata", | ||
"provenance": [], | ||
"include_colab_link": true | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"name": "python3" | ||
} | ||
}, | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "view-in-github", | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"<a href=\"https://colab.research.google.com/github/sayakpaul/MIRNet-TFLite/blob/main/Add_Metadata.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "laCUI2hzapKv" | ||
}, | ||
"source": [ | ||
"- Reference: https://github.com/margaretmz/selfie2anime-with-tflite/blob/master/ml/add-meta-data-Colab/Add%20metadata%20to%20selfie2anime.ipynb. \n", | ||
"- TensorFlow Lite meatdata: https://www.tensorflow.org/lite/convert/metadata.\n", | ||
"- Authored by: Sayak.\n", | ||
"- Updated on: December 05 2020." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "lIYdn1woOS1n", | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"outputId": "bea2e589-410d-4957-f388-3b37ad7733d8" | ||
}, | ||
"source": [ | ||
"!pip install -q tflite-support" | ||
], | ||
"execution_count": null, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"text": [ | ||
"\u001b[K |████████████████████████████████| 1.0MB 5.6MB/s \n", | ||
"\u001b[K |████████████████████████████████| 194kB 12.5MB/s \n", | ||
"\u001b[?25h" | ||
], | ||
"name": "stdout" | ||
} | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "hSBlosV-Wim7", | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"outputId": "d73a0356-b4f9-44b1-a310-610aa81a4e0d" | ||
}, | ||
"source": [ | ||
"import os\n", | ||
"import tensorflow as tf\n", | ||
"from absl import flags\n", | ||
"\n", | ||
"from tflite_support import flatbuffers\n", | ||
"from tflite_support import metadata as _metadata\n", | ||
"from tflite_support import metadata_schema_py_generated as _metadata_fb\n", | ||
"\n", | ||
"print(tf.__version__)" | ||
], | ||
"execution_count": null, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"text": [ | ||
"2.3.0\n" | ||
], | ||
"name": "stdout" | ||
} | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "v_czMsARW1SW" | ||
}, | ||
"source": [ | ||
"!mkdir model_without_metadata\n", | ||
"!mkdir model_with_metadata" | ||
], | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "8_Qa9lo4W6zV" | ||
}, | ||
"source": [ | ||
"!wget -q https://github.com/sayakpaul/MIRNet-TFLite/releases/download/v0.1.0/fixed_shape.zip\n", | ||
"!unzip -q fixed_shape.zip\n", | ||
"\n", | ||
"!mv fixed_shape/*.tflite model_without_metadata/" | ||
], | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "InH13DM9cdqP" | ||
}, | ||
"source": [ | ||
"# This is where we will export a new .tflite model file with metadata, and a .json file with metadata info\n", | ||
"EXPORT_DIR = \"model_with_metadata\"" | ||
], | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "E4ta4ByLYST3" | ||
}, | ||
"source": [ | ||
"class MetadataPopulatorForMIRNet(object):\n", | ||
" \"\"\"Populates the metadata for the MIRNet model.\"\"\"\n", | ||
"\n", | ||
" def __init__(self, model_file):\n", | ||
" self.model_file = model_file\n", | ||
" self.metadata_buf = None\n", | ||
"\n", | ||
" def populate(self):\n", | ||
" \"\"\"Creates metadata and then populates it for a low-light image enhancement model.\"\"\"\n", | ||
" self._create_metadata()\n", | ||
" self._populate_metadata()\n", | ||
" \n", | ||
" def _create_metadata(self):\n", | ||
" \"\"\"Creates the metadata for the CartoonGAN model.\"\"\"\n", | ||
"\n", | ||
" # Creates model info.\n", | ||
" model_meta = _metadata_fb.ModelMetadataT()\n", | ||
" model_meta.name = \"MIRNet\" \n", | ||
" model_meta.description = (\"Enhances a low-light image. Reference: https://arxiv.org/pdf/2003.06792v2.pdf. TFLiteConverter used from TensorFlow 2.3.0.\")\n", | ||
" model_meta.version = \"v1\"\n", | ||
" model_meta.author = \"Sayak\"\n", | ||
" model_meta.license = (\"Apache License. Version 2.0 \"\n", | ||
" \"http://www.apache.org/licenses/LICENSE-2.0.\")\n", | ||
"\n", | ||
" # Creates info for the input, normal image.\n", | ||
" input_image_meta = _metadata_fb.TensorMetadataT()\n", | ||
" input_image_meta.name = \"low_light_image\"\n", | ||
" # if self.model_type==\"other\":\n", | ||
" input_image_meta.description = (\n", | ||
" \"The expected image is 400 x 400, with three channels \"\n", | ||
" \"(red, blue, and green) per pixel. Each value in the tensor is between\"\n", | ||
" \" 0 and 1.\")\n", | ||
" input_image_meta.content = _metadata_fb.ContentT()\n", | ||
" input_image_meta.content.contentProperties = (\n", | ||
" _metadata_fb.ImagePropertiesT())\n", | ||
" input_image_meta.content.contentProperties.colorSpace = (\n", | ||
" _metadata_fb.ColorSpaceType.RGB)\n", | ||
" input_image_meta.content.contentPropertiesType = (\n", | ||
" _metadata_fb.ContentProperties.ImageProperties)\n", | ||
" input_image_normalization = _metadata_fb.ProcessUnitT()\n", | ||
" input_image_normalization.optionsType = (\n", | ||
" _metadata_fb.ProcessUnitOptions.NormalizationOptions)\n", | ||
" input_image_normalization.options = _metadata_fb.NormalizationOptionsT()\n", | ||
" input_image_normalization.options.mean = [0.0]\n", | ||
" input_image_normalization.options.std = [255.0]\n", | ||
" input_image_meta.processUnits = [input_image_normalization]\n", | ||
" input_image_stats = _metadata_fb.StatsT()\n", | ||
" input_image_stats.max = [1.0]\n", | ||
" input_image_stats.min = [0.0]\n", | ||
" input_image_meta.stats = input_image_stats\n", | ||
"\n", | ||
"\n", | ||
" # Creates output info, cartoonized image\n", | ||
" output_image_meta = _metadata_fb.TensorMetadataT()\n", | ||
" output_image_meta.name = \"enhanced_image\"\n", | ||
" output_image_meta.description = \"Image enhanced.\"\n", | ||
" output_image_meta.content = _metadata_fb.ContentT()\n", | ||
" output_image_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()\n", | ||
" output_image_meta.content.contentProperties.colorSpace = (\n", | ||
" _metadata_fb.ColorSpaceType.RGB)\n", | ||
" output_image_meta.content.contentPropertiesType = (\n", | ||
" _metadata_fb.ContentProperties.ImageProperties)\n", | ||
" output_image_normalization = _metadata_fb.ProcessUnitT()\n", | ||
" output_image_normalization.optionsType = (\n", | ||
" _metadata_fb.ProcessUnitOptions.NormalizationOptions)\n", | ||
" output_image_normalization.options = _metadata_fb.NormalizationOptionsT()\n", | ||
" output_image_normalization.options.mean = [0.0]\n", | ||
" output_image_normalization.options.std = [1.0] \n", | ||
" output_image_meta.processUnits = [output_image_normalization]\n", | ||
" output_image_stats = _metadata_fb.StatsT()\n", | ||
" output_image_stats.max = [255.0]\n", | ||
" output_image_stats.min = [0.0]\n", | ||
" output_image_meta.stats = output_image_stats\n", | ||
"\n", | ||
" # Creates subgraph info.\n", | ||
" subgraph = _metadata_fb.SubGraphMetadataT()\n", | ||
" subgraph.inputTensorMetadata = [input_image_meta] \n", | ||
" subgraph.outputTensorMetadata = [output_image_meta] \n", | ||
" model_meta.subgraphMetadata = [subgraph]\n", | ||
"\n", | ||
" b = flatbuffers.Builder(0)\n", | ||
" b.Finish(\n", | ||
" model_meta.Pack(b),\n", | ||
" _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)\n", | ||
" self.metadata_buf = b.Output()\n", | ||
"\n", | ||
" def _populate_metadata(self):\n", | ||
" \"\"\"Populates metadata to the model file.\"\"\"\n", | ||
" populator = _metadata.MetadataPopulator.with_model_file(self.model_file)\n", | ||
" populator.load_metadata_buffer(self.metadata_buf)\n", | ||
" populator.populate()" | ||
], | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "vRlUqyFMcNyy" | ||
}, | ||
"source": [ | ||
"def populate_metadata(model_file):\n", | ||
" \"\"\"Populates the metadata using the populator specified.\n", | ||
" Args:\n", | ||
" model_file: valid path to the model file.\n", | ||
" \"\"\"\n", | ||
"\n", | ||
" # Populates metadata for the model.\n", | ||
" model_file_basename = os.path.basename(model_file)\n", | ||
" export_path = os.path.join(EXPORT_DIR, model_file_basename)\n", | ||
" tf.io.gfile.copy(model_file, export_path, overwrite=True)\n", | ||
"\n", | ||
" populator = MetadataPopulatorForMIRNet(export_path) \n", | ||
" populator.populate()\n", | ||
"\n", | ||
" # Displays the metadata that was just populated into the tflite model.\n", | ||
" displayer = _metadata.MetadataDisplayer.with_model_file(export_path)\n", | ||
" export_json_file = os.path.join(\n", | ||
" EXPORT_DIR,\n", | ||
" os.path.splitext(model_file_basename)[0] + \".json\")\n", | ||
" json_file = displayer.get_metadata_json()\n", | ||
" with open(export_json_file, \"w\") as f:\n", | ||
" f.write(json_file)\n", | ||
" print(\"Finished populating metadata and associated file to the model:\")\n", | ||
" print(export_path)\n", | ||
" print(\"The metadata json file has been saved to:\")\n", | ||
" print(os.path.join(EXPORT_DIR,\n", | ||
" os.path.splitext(model_file_basename)[0] + \".json\"))" | ||
], | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "YVSJFR4HcUuI", | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"outputId": "a2cbe140-abf7-4541-e663-329b156d50c5" | ||
}, | ||
"source": [ | ||
"quantization = \"int8\" #@param [\"dr\", \"int8\", \"fp16\"]\n", | ||
"tflite_model_path = f\"mirnet_{quantization}.tflite\" \n", | ||
"MODEL_FILE = \"/content/model_without_metadata/{}\".format(tflite_model_path)\n", | ||
"populate_metadata(MODEL_FILE)" | ||
], | ||
"execution_count": null, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"text": [ | ||
"Finished populating metadata and associated file to the model:\n", | ||
"model_with_metadata/mirnet_int8.tflite\n", | ||
"The metadata json file has been saved to:\n", | ||
"model_with_metadata/mirnet_int8.json\n" | ||
], | ||
"name": "stdout" | ||
} | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "Bcn0-gbW3ep9", | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"outputId": "8d3e9e86-f5ce-41b1-c6b1-5e55baf2fedb" | ||
}, | ||
"source": [ | ||
"!tar cvf model_with_metadata.tar.gz model_with_metadata" | ||
], | ||
"execution_count": null, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"text": [ | ||
"model_with_metadata/\n", | ||
"model_with_metadata/mirnet_int8.tflite\n", | ||
"model_with_metadata/mirnet_int8.json\n", | ||
"model_with_metadata/mirnet_fp16.json\n", | ||
"model_with_metadata/mirnet_fp16.tflite\n", | ||
"model_with_metadata/mirnet_dr.json\n", | ||
"model_with_metadata/mirnet_dr.tflite\n" | ||
], | ||
"name": "stdout" | ||
} | ||
] | ||
} | ||
] | ||
} |