Skip to content

Commit

Permalink
Metadata notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
sayakpaul committed Dec 5, 2020
1 parent 0166d9d commit 34e3cc7
Showing 1 changed file with 329 additions and 0 deletions.
329 changes: 329 additions & 0 deletions Add_Metadata.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,329 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Add_Metadata",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/sayakpaul/MIRNet-TFLite/blob/main/Add_Metadata.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "laCUI2hzapKv"
},
"source": [
"- Reference: https://github.com/margaretmz/selfie2anime-with-tflite/blob/master/ml/add-meta-data-Colab/Add%20metadata%20to%20selfie2anime.ipynb. \n",
"- TensorFlow Lite meatdata: https://www.tensorflow.org/lite/convert/metadata.\n",
"- Authored by: Sayak.\n",
"- Updated on: December 05 2020."
]
},
{
"cell_type": "code",
"metadata": {
"id": "lIYdn1woOS1n",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "bea2e589-410d-4957-f388-3b37ad7733d8"
},
"source": [
"!pip install -q tflite-support"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"\u001b[K |████████████████████████████████| 1.0MB 5.6MB/s \n",
"\u001b[K |████████████████████████████████| 194kB 12.5MB/s \n",
"\u001b[?25h"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "hSBlosV-Wim7",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "d73a0356-b4f9-44b1-a310-610aa81a4e0d"
},
"source": [
"import os\n",
"import tensorflow as tf\n",
"from absl import flags\n",
"\n",
"from tflite_support import flatbuffers\n",
"from tflite_support import metadata as _metadata\n",
"from tflite_support import metadata_schema_py_generated as _metadata_fb\n",
"\n",
"print(tf.__version__)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"2.3.0\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "v_czMsARW1SW"
},
"source": [
"!mkdir model_without_metadata\n",
"!mkdir model_with_metadata"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "8_Qa9lo4W6zV"
},
"source": [
"!wget -q https://github.com/sayakpaul/MIRNet-TFLite/releases/download/v0.1.0/fixed_shape.zip\n",
"!unzip -q fixed_shape.zip\n",
"\n",
"!mv fixed_shape/*.tflite model_without_metadata/"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "InH13DM9cdqP"
},
"source": [
"# This is where we will export a new .tflite model file with metadata, and a .json file with metadata info\n",
"EXPORT_DIR = \"model_with_metadata\""
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "E4ta4ByLYST3"
},
"source": [
"class MetadataPopulatorForMIRNet(object):\n",
" \"\"\"Populates the metadata for the MIRNet model.\"\"\"\n",
"\n",
" def __init__(self, model_file):\n",
" self.model_file = model_file\n",
" self.metadata_buf = None\n",
"\n",
" def populate(self):\n",
" \"\"\"Creates metadata and then populates it for a low-light image enhancement model.\"\"\"\n",
" self._create_metadata()\n",
" self._populate_metadata()\n",
" \n",
" def _create_metadata(self):\n",
" \"\"\"Creates the metadata for the CartoonGAN model.\"\"\"\n",
"\n",
" # Creates model info.\n",
" model_meta = _metadata_fb.ModelMetadataT()\n",
" model_meta.name = \"MIRNet\" \n",
" model_meta.description = (\"Enhances a low-light image. Reference: https://arxiv.org/pdf/2003.06792v2.pdf. TFLiteConverter used from TensorFlow 2.3.0.\")\n",
" model_meta.version = \"v1\"\n",
" model_meta.author = \"Sayak\"\n",
" model_meta.license = (\"Apache License. Version 2.0 \"\n",
" \"http://www.apache.org/licenses/LICENSE-2.0.\")\n",
"\n",
" # Creates info for the input, normal image.\n",
" input_image_meta = _metadata_fb.TensorMetadataT()\n",
" input_image_meta.name = \"low_light_image\"\n",
" # if self.model_type==\"other\":\n",
" input_image_meta.description = (\n",
" \"The expected image is 400 x 400, with three channels \"\n",
" \"(red, blue, and green) per pixel. Each value in the tensor is between\"\n",
" \" 0 and 1.\")\n",
" input_image_meta.content = _metadata_fb.ContentT()\n",
" input_image_meta.content.contentProperties = (\n",
" _metadata_fb.ImagePropertiesT())\n",
" input_image_meta.content.contentProperties.colorSpace = (\n",
" _metadata_fb.ColorSpaceType.RGB)\n",
" input_image_meta.content.contentPropertiesType = (\n",
" _metadata_fb.ContentProperties.ImageProperties)\n",
" input_image_normalization = _metadata_fb.ProcessUnitT()\n",
" input_image_normalization.optionsType = (\n",
" _metadata_fb.ProcessUnitOptions.NormalizationOptions)\n",
" input_image_normalization.options = _metadata_fb.NormalizationOptionsT()\n",
" input_image_normalization.options.mean = [0.0]\n",
" input_image_normalization.options.std = [255.0]\n",
" input_image_meta.processUnits = [input_image_normalization]\n",
" input_image_stats = _metadata_fb.StatsT()\n",
" input_image_stats.max = [1.0]\n",
" input_image_stats.min = [0.0]\n",
" input_image_meta.stats = input_image_stats\n",
"\n",
"\n",
" # Creates output info, cartoonized image\n",
" output_image_meta = _metadata_fb.TensorMetadataT()\n",
" output_image_meta.name = \"enhanced_image\"\n",
" output_image_meta.description = \"Image enhanced.\"\n",
" output_image_meta.content = _metadata_fb.ContentT()\n",
" output_image_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()\n",
" output_image_meta.content.contentProperties.colorSpace = (\n",
" _metadata_fb.ColorSpaceType.RGB)\n",
" output_image_meta.content.contentPropertiesType = (\n",
" _metadata_fb.ContentProperties.ImageProperties)\n",
" output_image_normalization = _metadata_fb.ProcessUnitT()\n",
" output_image_normalization.optionsType = (\n",
" _metadata_fb.ProcessUnitOptions.NormalizationOptions)\n",
" output_image_normalization.options = _metadata_fb.NormalizationOptionsT()\n",
" output_image_normalization.options.mean = [0.0]\n",
" output_image_normalization.options.std = [1.0] \n",
" output_image_meta.processUnits = [output_image_normalization]\n",
" output_image_stats = _metadata_fb.StatsT()\n",
" output_image_stats.max = [255.0]\n",
" output_image_stats.min = [0.0]\n",
" output_image_meta.stats = output_image_stats\n",
"\n",
" # Creates subgraph info.\n",
" subgraph = _metadata_fb.SubGraphMetadataT()\n",
" subgraph.inputTensorMetadata = [input_image_meta] \n",
" subgraph.outputTensorMetadata = [output_image_meta] \n",
" model_meta.subgraphMetadata = [subgraph]\n",
"\n",
" b = flatbuffers.Builder(0)\n",
" b.Finish(\n",
" model_meta.Pack(b),\n",
" _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)\n",
" self.metadata_buf = b.Output()\n",
"\n",
" def _populate_metadata(self):\n",
" \"\"\"Populates metadata to the model file.\"\"\"\n",
" populator = _metadata.MetadataPopulator.with_model_file(self.model_file)\n",
" populator.load_metadata_buffer(self.metadata_buf)\n",
" populator.populate()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "vRlUqyFMcNyy"
},
"source": [
"def populate_metadata(model_file):\n",
" \"\"\"Populates the metadata using the populator specified.\n",
" Args:\n",
" model_file: valid path to the model file.\n",
" \"\"\"\n",
"\n",
" # Populates metadata for the model.\n",
" model_file_basename = os.path.basename(model_file)\n",
" export_path = os.path.join(EXPORT_DIR, model_file_basename)\n",
" tf.io.gfile.copy(model_file, export_path, overwrite=True)\n",
"\n",
" populator = MetadataPopulatorForMIRNet(export_path) \n",
" populator.populate()\n",
"\n",
" # Displays the metadata that was just populated into the tflite model.\n",
" displayer = _metadata.MetadataDisplayer.with_model_file(export_path)\n",
" export_json_file = os.path.join(\n",
" EXPORT_DIR,\n",
" os.path.splitext(model_file_basename)[0] + \".json\")\n",
" json_file = displayer.get_metadata_json()\n",
" with open(export_json_file, \"w\") as f:\n",
" f.write(json_file)\n",
" print(\"Finished populating metadata and associated file to the model:\")\n",
" print(export_path)\n",
" print(\"The metadata json file has been saved to:\")\n",
" print(os.path.join(EXPORT_DIR,\n",
" os.path.splitext(model_file_basename)[0] + \".json\"))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "YVSJFR4HcUuI",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "a2cbe140-abf7-4541-e663-329b156d50c5"
},
"source": [
"quantization = \"int8\" #@param [\"dr\", \"int8\", \"fp16\"]\n",
"tflite_model_path = f\"mirnet_{quantization}.tflite\" \n",
"MODEL_FILE = \"/content/model_without_metadata/{}\".format(tflite_model_path)\n",
"populate_metadata(MODEL_FILE)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Finished populating metadata and associated file to the model:\n",
"model_with_metadata/mirnet_int8.tflite\n",
"The metadata json file has been saved to:\n",
"model_with_metadata/mirnet_int8.json\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Bcn0-gbW3ep9",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "8d3e9e86-f5ce-41b1-c6b1-5e55baf2fedb"
},
"source": [
"!tar cvf model_with_metadata.tar.gz model_with_metadata"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"model_with_metadata/\n",
"model_with_metadata/mirnet_int8.tflite\n",
"model_with_metadata/mirnet_int8.json\n",
"model_with_metadata/mirnet_fp16.json\n",
"model_with_metadata/mirnet_fp16.tflite\n",
"model_with_metadata/mirnet_dr.json\n",
"model_with_metadata/mirnet_dr.tflite\n"
],
"name": "stdout"
}
]
}
]
}

0 comments on commit 34e3cc7

Please sign in to comment.