From cade740a3c202aff39b0c62f0b06e77ed30bbb50 Mon Sep 17 00:00:00 2001 From: Philip Colangelo Date: Wed, 18 Dec 2024 12:12:03 -0500 Subject: [PATCH] further support for ingesting pytorch --- examples/analysis.py | 2 +- src/digest/main.py | 38 +++- .../model_class/digest_pytorch_model.py | 21 ++- src/digest/multi_model_selection_page.py | 2 +- src/digest/pytorch_ingest.py | 173 +++++++++--------- src/digest/ui/pytorchingest.ui | 11 +- src/digest/ui/pytorchingest_ui.py | 4 +- test/test_gui.py | 8 +- test/test_reports.py | 2 +- 9 files changed, 153 insertions(+), 108 deletions(-) diff --git a/examples/analysis.py b/examples/analysis.py index da89068..0910637 100644 --- a/examples/analysis.py +++ b/examples/analysis.py @@ -73,7 +73,7 @@ def main(onnx_files: str, output_dir: str): print(f"dim: {dynamic_shape}") digest_model = DigestOnnxModel( - model_proto, onnx_filepath=onnx_file, model_name=model_name + model_proto, onnx_file_path=onnx_file, model_name=model_name ) # Update the global model dictionary diff --git a/src/digest/main.py b/src/digest/main.py index 333620c..b11850d 100644 --- a/src/digest/main.py +++ b/src/digest/main.py @@ -287,7 +287,7 @@ def closeTab(self, index): # delete the digest model to free up used memory if unique_id in self.digest_models: - del self.digest_models[unique_id] + self.digest_models.pop(unique_id) self.ui.tabWidget.removeTab(index) if self.ui.tabWidget.count() == 0: @@ -486,20 +486,41 @@ def load_onnx(self, file_path: str): # Every time an onnx is loaded we should emulate a model summary button click self.summary_clicked() - # Before opening the file, check to see if it is already opened. + model_proto = None + + # Before opening the ONNX file, check to see if it is already opened. for index in range(self.ui.tabWidget.count()): widget = self.ui.tabWidget.widget(index) - if isinstance(widget, modelSummary) and file_path == widget.file: - self.ui.tabWidget.setCurrentIndex(index) - return + if ( + isinstance(widget, modelSummary) + and isinstance(widget.digest_model, DigestOnnxModel) + and file_path == widget.file + ): + # Check if the model proto is different + if widget.digest_model.model_proto: + model_proto = onnx_utils.load_onnx( + file_path, load_external_data=False + ) + # If they are equivalent, set the GUI to show the existing + # report and return + if model_proto == widget.digest_model.model_proto: + self.ui.tabWidget.setCurrentIndex(index) + return + # If they aren't equivalent, then the proto has been modified. In this case, + # we close the tab associated with the stale model, remove from the model list, + # then go through the standard process of adding it to the tabWidget. In the + # future, it may be slightly better to have an update tab function. + else: + self.closeTab(index) try: progress = ProgressDialog("Loading & Optimizing ONNX Model...", 8, self) QApplication.processEvents() # Process pending events - model = onnx_utils.load_onnx(file_path, load_external_data=False) - opt_model, opt_passed = onnx_utils.optimize_onnx_model(model) + if not model_proto: + model_proto = onnx_utils.load_onnx(file_path, load_external_data=False) + opt_model, opt_passed = onnx_utils.optimize_onnx_model(model_proto) progress.step() basename = os.path.splitext(os.path.basename(file_path)) @@ -918,6 +939,9 @@ def load_pytorch(self, file_path: str): basename = os.path.splitext(os.path.basename(file_path)) model_name = basename[0] + # The current support for PyTorch includes exporting it to ONNX. In this case, + # an ingest window will pop up giving the user options to export. This window + # will block the main GUI until the ingest window is closed self.pytorch_ingest = PyTorchIngest(file_path, model_name) self.pytorch_ingest_window = PopupDialog( self.pytorch_ingest, "PyTorch Ingest", self diff --git a/src/digest/model_class/digest_pytorch_model.py b/src/digest/model_class/digest_pytorch_model.py index 68b1a76..9f159e9 100644 --- a/src/digest/model_class/digest_pytorch_model.py +++ b/src/digest/model_class/digest_pytorch_model.py @@ -2,7 +2,7 @@ import os from collections import OrderedDict -from typing import List, Tuple, Optional, Any, Union +from typing import List, Tuple, Optional, Union import inspect import onnx import torch @@ -37,7 +37,9 @@ def __init__( # Input dictionary to contain the names and shapes # required for exporting the ONNX model - self.input_tensor_info: OrderedDict[str, List[Any]] = OrderedDict() + self.input_tensor_info: OrderedDict[ + str, Tuple[torch.dtype, List[Union[str, int]]] + ] = OrderedDict() self.pytorch_model = torch.load(pytorch_file_path) @@ -58,21 +60,24 @@ def save_yaml_report(self, file_path: str) -> None: def save_text_report(self, file_path: str) -> None: """This will be done in the DigestOnnxModel""" - def generate_random_tensor(self, shape: List[Union[str, int]]): + def generate_random_tensor(self, dtype: torch.dtype, shape: List[Union[str, int]]): static_shape = [dim if isinstance(dim, int) else 1 for dim in shape] - return torch.rand(static_shape) + if dtype in (torch.float16, torch.float32, torch.float64): + return torch.rand(static_shape, dtype=dtype) + else: + return torch.randint(0, 100, static_shape, dtype=dtype) def export_to_onnx(self, output_onnx_path: str) -> Union[onnx.ModelProto, None]: dummy_input_names: List[str] = list(self.input_tensor_info.keys()) dummy_inputs: List[torch.Tensor] = [] - for shape in self.input_tensor_info.values(): - dummy_inputs.append(self.generate_random_tensor(shape)) + for dtype, shape in self.input_tensor_info.values(): + dummy_inputs.append(self.generate_random_tensor(dtype, shape)) dynamic_axes = { name: {i: dim for i, dim in enumerate(shape) if isinstance(dim, str)} - for name, shape in self.input_tensor_info.items() + for name, (_, shape) in self.input_tensor_info.items() } try: @@ -92,7 +97,7 @@ def export_to_onnx(self, output_onnx_path: str) -> Union[onnx.ModelProto, None]: return onnx.load(output_onnx_path) - except (TypeError, RuntimeError) as err: + except (ValueError, TypeError, RuntimeError) as err: print(f"Failed to export ONNX: {err}") raise diff --git a/src/digest/multi_model_selection_page.py b/src/digest/multi_model_selection_page.py index e9d5c2b..3290083 100644 --- a/src/digest/multi_model_selection_page.py +++ b/src/digest/multi_model_selection_page.py @@ -56,7 +56,7 @@ def run(self): model_proto = onnx_utils.load_onnx(file, False) self.model_dict[file] = DigestOnnxModel( model_proto, - onnx_filepath=file, + onnx_file_path=file, model_name=model_name, save_proto=False, ) diff --git a/src/digest/pytorch_ingest.py b/src/digest/pytorch_ingest.py index 0ddd802..4f3d8cf 100644 --- a/src/digest/pytorch_ingest.py +++ b/src/digest/pytorch_ingest.py @@ -2,8 +2,9 @@ import os from collections import OrderedDict -from typing import Optional, Callable, Union +from typing import Optional, Callable, Union, List from platformdirs import user_cache_dir +import torch # pylint: disable=no-name-in-module from PySide6.QtWidgets import ( @@ -14,6 +15,7 @@ QFormLayout, QFileDialog, QHBoxLayout, + QComboBox, ) from PySide6.QtGui import QFont from PySide6.QtCore import Qt, Signal @@ -25,8 +27,23 @@ DigestPyTorchModel, ) - -class UserInputFormWithInfo: +torch_tensor_types = { + "torch.float16": torch.float16, + "torch.float32": torch.float32, + "torch.float64": torch.float64, + "torch.uint8": torch.uint8, + "torch.uint16": torch.uint16, + "torch.uint32": torch.uint32, + "torch.uint64": torch.uint64, + "torch.int8": torch.int8, + "torch.int16": torch.int16, + "torch.int32": torch.int32, + "torch.int64": torch.int64, + "torch.bool": torch.bool, +} + + +class UserModelInputsForm: def __init__(self, form_layout: QFormLayout): self.form_layout = form_layout self.num_rows = 0 @@ -34,79 +51,89 @@ def __init__(self, form_layout: QFormLayout): def add_row( self, label_text: str, - edit_text: str, text_width: int, - info_text: str, edit_finished_fnc: Optional[Callable] = None, ) -> int: + # The label displays the tensor name font = QFont("Inter", 10) label = QLabel(f"{label_text}:") label.setContentsMargins(0, 0, 0, 0) label.setFont(font) + # The combo box enables users to specify the tensor data type + dtype_combo_box = QComboBox() + for tensor_type in torch_tensor_types.keys(): + dtype_combo_box.addItem(tensor_type) + dtype_combo_box.setCurrentIndex(1) # float32 by default + dtype_combo_box.currentIndexChanged.connect(edit_finished_fnc) + + # Line edit is where the user specifies the tensor shape line_edit = QLineEdit() line_edit.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Fixed) line_edit.setMinimumWidth(text_width) line_edit.setMinimumHeight(20) - line_edit.setText(edit_text) + line_edit.setPlaceholderText("Set tensor shape here") if edit_finished_fnc: line_edit.editingFinished.connect(edit_finished_fnc) - info_label = QLabel() - info_label.setText(info_text) - font = QFont("Arial", 10, italic=True) - info_label.setFont(font) - info_label.setContentsMargins(10, 0, 0, 0) - row_layout = QHBoxLayout() row_layout.setAlignment(Qt.AlignmentFlag.AlignLeft) row_layout.setSpacing(5) row_layout.setObjectName(f"row{self.num_rows}_layout") row_layout.addWidget(label, alignment=Qt.AlignmentFlag.AlignHCenter) + row_layout.addWidget(dtype_combo_box, alignment=Qt.AlignmentFlag.AlignHCenter) row_layout.addWidget(line_edit, alignment=Qt.AlignmentFlag.AlignHCenter) - row_layout.addWidget(info_label, alignment=Qt.AlignmentFlag.AlignHCenter) self.num_rows += 1 self.form_layout.addRow(row_layout) return self.num_rows - def get_row_label(self, row_idx: int) -> str: + def get_row_tensor_name(self, row_idx: int) -> str: form_item = self.form_layout.itemAt(row_idx, QFormLayout.ItemRole.FieldRole) - if form_item: - row_layout = form_item.layout() - if isinstance(row_layout, QHBoxLayout): - line_edit_item = row_layout.itemAt(0) - if line_edit_item: - line_edit_widget = line_edit_item.widget() - if isinstance(line_edit_widget, QLabel): - return line_edit_widget.text() - return "" - - def get_row_line_edit(self, row_idx: int) -> str: + row_layout = form_item.layout() + assert isinstance(row_layout, QHBoxLayout) + line_edit_item = row_layout.itemAt(0) + line_edit_widget = line_edit_item.widget() + assert isinstance(line_edit_widget, QLabel) + return line_edit_widget.text().split(":")[0] + + def get_row_tensor_dtype(self, row_idx: int) -> torch.dtype: form_item = self.form_layout.itemAt(row_idx, QFormLayout.ItemRole.FieldRole) - if form_item: - row_layout = form_item.layout() - if isinstance(row_layout, QHBoxLayout): - line_edit_item = row_layout.itemAt(1) - if line_edit_item: - line_edit_widget = line_edit_item.widget() - if isinstance(line_edit_widget, QLineEdit): - return line_edit_widget.text() - return "" - - def get_row_line_edit_widget(self, row_idx: int) -> Union[QLineEdit, None]: + row_layout = form_item.layout() + combo_box = row_layout.itemAt(1) + assert combo_box, "The combo box was not found which is unexpected!" + combo_box_widget = combo_box.widget() + assert isinstance(combo_box_widget, QComboBox) + return torch_tensor_types[combo_box_widget.currentText()] + + def get_row_tensor_shape(self, row_idx: int) -> List[Union[str, int]]: + shape_widget = self.get_row_tensor_shape_widget(row_idx) + shape_str = shape_widget.text() + shape_list: List[Union[str, int]] = [] + if not shape_str: + return shape_list + shape_list_str = shape_str.split(",") + + for dim in shape_list_str: + dim = dim.strip() + # Integer based shape + if all(char.isdigit() for char in dim): + shape_list.append(int(dim)) + # Symbolic shape + else: + shape_list.append(dim) + return shape_list + + def get_row_tensor_shape_widget(self, row_idx: int) -> QLineEdit: form_item = self.form_layout.itemAt(row_idx, QFormLayout.ItemRole.FieldRole) - if form_item: - row_layout = form_item.layout() - if isinstance(row_layout, QHBoxLayout): - line_edit_item = row_layout.itemAt(1) - if line_edit_item: - line_edit_widget = line_edit_item.widget() - if isinstance(line_edit_widget, QLineEdit): - return line_edit_widget - return None + row_layout = form_item.layout() + line_edit_item = row_layout.itemAt(2) + assert line_edit_item + line_edit_widget = line_edit_item.widget() + assert isinstance(line_edit_widget, QLineEdit) + return line_edit_widget class PyTorchIngest(QWidget): @@ -156,7 +183,7 @@ def __init__( self.ui.exportParamsCheckBox.isChecked() ) - self.user_input_form = UserInputFormWithInfo(self.ui.inputsFormLayout) + self.user_input_form = UserModelInputsForm(self.ui.inputsFormLayout) # Set up the opset form self.lowest_supported_opset = 7 # this requirement came from pytorch @@ -173,10 +200,8 @@ def __init__( for val in self.fwd_parameters.values(): self.user_input_form.add_row( str(val), - "", 250, - "", - self.update_input_shape, + self.update_tensor_info, ) def set_widget_invalid(self, widget: QWidget): @@ -226,41 +251,25 @@ def update_opset_version(self): self.digest_pytorch_model.opset = opset_text_item self.set_widget_valid(self.ui.opsetLineEdit) - def update_input_shape(self): + def update_tensor_info(self): """Because this is an external function to the UserInputFormWithInfo class we go through each input everytime there is an update.""" for row_idx in range(self.user_input_form.form_layout.rowCount()): - label_text = self.user_input_form.get_row_label(row_idx) - line_edit_text = self.user_input_form.get_row_line_edit(row_idx) - if label_text and line_edit_text: - tensor_name = label_text.split(":")[0] - if tensor_name in self.digest_pytorch_model.input_tensor_info: - self.digest_pytorch_model.input_tensor_info[tensor_name].clear() - else: - self.digest_pytorch_model.input_tensor_info[tensor_name] = [] - shape_list = line_edit_text.split(",") - try: - for dim in shape_list: - dim = dim.strip() - # Integer based shape - if all(char.isdigit() for char in dim): - self.digest_pytorch_model.input_tensor_info[ - tensor_name - ].append(int(dim)) - # Symbolic shape - else: - self.digest_pytorch_model.input_tensor_info[ - tensor_name - ].append(dim) - except ValueError as err: - print(f"Malformed shape: {err}") - widget = self.user_input_form.get_row_line_edit_widget(row_idx) - if widget: - self.set_widget_invalid(widget) - else: - widget = self.user_input_form.get_row_line_edit_widget(row_idx) - if widget: - self.set_widget_valid(widget) + widget = self.user_input_form.get_row_tensor_shape_widget(row_idx) + tensor_name = self.user_input_form.get_row_tensor_name(row_idx) + tensor_dtype = self.user_input_form.get_row_tensor_dtype(row_idx) + try: + tensor_shape = self.user_input_form.get_row_tensor_shape(row_idx) + except ValueError as err: + print(f"Shape invalid: {err}") + self.set_widget_invalid(widget) + else: + if tensor_name and tensor_shape: + self.set_widget_valid(widget) + self.digest_pytorch_model.input_tensor_info[tensor_name] = ( + tensor_dtype, + tensor_shape, + ) def export_onnx(self): onnx_file_path = os.path.join( @@ -268,7 +277,7 @@ def export_onnx(self): ) try: self.digest_pytorch_model.export_to_onnx(onnx_file_path) - except (TypeError, RuntimeError) as err: + except (ValueError, TypeError, RuntimeError) as err: self.ui.exportWarningLabel.setText(f"Failed to export ONNX: {err}") self.ui.exportWarningLabel.show() else: diff --git a/src/digest/ui/pytorchingest.ui b/src/digest/ui/pytorchingest.ui index abc6a5e..6be230b 100644 --- a/src/digest/ui/pytorchingest.ui +++ b/src/digest/ui/pytorchingest.ui @@ -278,8 +278,7 @@ - - + Export Options @@ -466,7 +465,7 @@ color: lightgrey; - The following inputs were taken from the PyTorch model's forward function. Please set the dimensions for each input needed. Dimensions can be set by specifying a combination of symbolic and integer values separated by a comma, for example: batch_size, 3, 224, 244. + The following inputs were taken from the PyTorch model's forward function. Please set the type and dimensions for each required input. Shape dimensions can be set by specifying a combination of symbolic and integer values separated by a comma, for example: batch_size, 3, 224, 244. true @@ -478,6 +477,12 @@ + + 10 + + + 10 + 20 diff --git a/src/digest/ui/pytorchingest_ui.py b/src/digest/ui/pytorchingest_ui.py index c9a761e..f658051 100644 --- a/src/digest/ui/pytorchingest_ui.py +++ b/src/digest/ui/pytorchingest_ui.py @@ -286,6 +286,8 @@ def setupUi(self, pytorchIngest): self.inputsFormLayout = QFormLayout() self.inputsFormLayout.setObjectName(u"inputsFormLayout") + self.inputsFormLayout.setHorizontalSpacing(10) + self.inputsFormLayout.setVerticalSpacing(10) self.inputsFormLayout.setContentsMargins(20, -1, -1, -1) self.verticalLayout_3.addLayout(self.inputsFormLayout) @@ -351,7 +353,7 @@ def retranslateUi(self, pytorchIngest): self.opsetInfoLabel.setText(QCoreApplication.translate("pytorchIngest", u"(accepted range is 7 - 21):", None)) self.opsetLineEdit.setText(QCoreApplication.translate("pytorchIngest", u"17", None)) self.inputsGroupBox.setTitle(QCoreApplication.translate("pytorchIngest", u"Inputs", None)) - self.label.setText(QCoreApplication.translate("pytorchIngest", u"The following inputs were taken from the PyTorch model's forward function. Please set the dimensions for each input needed. Dimensions can be set by specifying a combination of symbolic and integer values separated by a comma, for example: batch_size, 3, 224, 244.", None)) + self.label.setText(QCoreApplication.translate("pytorchIngest", u"The following inputs were taken from the PyTorch model's forward function. Please set the type and dimensions for each required input. Shape dimensions can be set by specifying a combination of symbolic and integer values separated by a comma, for example: batch_size, 3, 224, 244.", None)) self.exportWarningLabel.setText(QCoreApplication.translate("pytorchIngest", u"

This is a warning message that we can use for now to prompt the user.

", None)) self.exportOnnxBtn.setText(QCoreApplication.translate("pytorchIngest", u"Export ONNX", None)) # retranslateUi diff --git a/test/test_gui.py b/test/test_gui.py index 0308ec7..9a06f3e 100644 --- a/test/test_gui.py +++ b/test/test_gui.py @@ -10,7 +10,7 @@ # pylint: disable=no-name-in-module from PySide6.QtTest import QTest -from PySide6.QtCore import Qt, QTimer, QEventLoop +from PySide6.QtCore import Qt from PySide6.QtWidgets import QApplication import digest.main @@ -128,13 +128,13 @@ def test_open_valid_pytorch(self): pytorch_ingest = PyTorchIngest(pt_file_path, digest_model.model_name) pytorch_ingest.show() - input_shape_edit = pytorch_ingest.user_input_form.get_row_line_edit_widget( - 0 + input_shape_edit = ( + pytorch_ingest.user_input_form.get_row_tensor_shape_widget(0) ) assert input_shape_edit input_shape_edit.setText("batch_size, 3, 224, 224") - pytorch_ingest.update_input_shape() + pytorch_ingest.update_tensor_info() with patch( "PySide6.QtWidgets.QFileDialog.getExistingDirectory" diff --git a/test/test_reports.py b/test/test_reports.py index ae99ab9..e4d327e 100644 --- a/test/test_reports.py +++ b/test/test_reports.py @@ -56,7 +56,7 @@ def test_against_example_reports(self): opt_model, _ = onnx_utils.optimize_onnx_model(model_proto) digest_model = DigestOnnxModel( opt_model, - onnx_filepath=TEST_ONNX, + onnx_file_path=TEST_ONNX, model_name=model_name, save_proto=False, )