further support for ingesting pytorch
Philip Colangelo committed Jan 3, 2025
1 parent 1573a71 commit cade740
Showing 9 changed files with 153 additions and 108 deletions.
2 changes: 1 addition & 1 deletion examples/analysis.py
@@ -73,7 +73,7 @@ def main(onnx_files: str, output_dir: str):
             print(f"dim: {dynamic_shape}")
 
         digest_model = DigestOnnxModel(
-            model_proto, onnx_filepath=onnx_file, model_name=model_name
+            model_proto, onnx_file_path=onnx_file, model_name=model_name
         )
 
         # Update the global model dictionary
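This hunk is one half of a keyword rename on the DigestOnnxModel constructor, from onnx_filepath to onnx_file_path; the other call site is updated in src/digest/multi_model_selection_page.py below. Because the argument is passed by keyword, any caller left on the old spelling fails at call time. A small sketch of the failure mode, using a hypothetical stand-in whose keyword names merely mirror this diff:

    # Hypothetical stand-in: only the keyword names are taken from the diff.
    class DigestOnnxModel:
        def __init__(self, model_proto, onnx_file_path=None, model_name=None, save_proto=True):
            self.model_proto = model_proto
            self.onnx_file_path = onnx_file_path
            self.model_name = model_name
            self.save_proto = save_proto

    DigestOnnxModel(None, onnx_file_path="a.onnx", model_name="a")  # new spelling: fine
    DigestOnnxModel(None, onnx_filepath="a.onnx", model_name="a")   # TypeError: unexpected keyword argument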
38 changes: 31 additions & 7 deletions src/digest/main.py
@@ -287,7 +287,7 @@ def closeTab(self, index):
 
         # delete the digest model to free up used memory
         if unique_id in self.digest_models:
-            del self.digest_models[unique_id]
+            self.digest_models.pop(unique_id)
 
         self.ui.tabWidget.removeTab(index)
         if self.ui.tabWidget.count() == 0:
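Since the membership check already guards the key, del and pop behave identically here; pop additionally returns the removed value and, given a default, allows a guard-free variant. A small illustrative sketch (the dictionary is a stand-in, not digest's model registry):

    models = {"uid-1": "model-a"}

    # Equivalent under a guarding membership test:
    if "uid-1" in models:
        del models["uid-1"]

    # pop returns the removed value and, with a default, needs no guard:
    removed = models.pop("uid-1", None)  # None when the key is absent, no KeyError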
@@ -486,20 +486,41 @@ def load_onnx(self, file_path: str):
         # Every time an onnx is loaded we should emulate a model summary button click
         self.summary_clicked()
 
-        # Before opening the file, check to see if it is already opened.
+        model_proto = None
+
+        # Before opening the ONNX file, check to see if it is already opened.
         for index in range(self.ui.tabWidget.count()):
             widget = self.ui.tabWidget.widget(index)
-            if isinstance(widget, modelSummary) and file_path == widget.file:
-                self.ui.tabWidget.setCurrentIndex(index)
-                return
+            if (
+                isinstance(widget, modelSummary)
+                and isinstance(widget.digest_model, DigestOnnxModel)
+                and file_path == widget.file
+            ):
+                # Check if the model proto is different
+                if widget.digest_model.model_proto:
+                    model_proto = onnx_utils.load_onnx(
+                        file_path, load_external_data=False
+                    )
+                    # If they are equivalent, set the GUI to show the existing
+                    # report and return
+                    if model_proto == widget.digest_model.model_proto:
+                        self.ui.tabWidget.setCurrentIndex(index)
+                        return
+                    # If they aren't equivalent, the proto has been modified. In this
+                    # case, we close the tab associated with the stale model, remove it
+                    # from the model list, then go through the standard process of
+                    # adding it to the tabWidget. In the future, it may be slightly
+                    # better to have an update tab function.
+                    else:
+                        self.closeTab(index)
 
         try:
 
             progress = ProgressDialog("Loading & Optimizing ONNX Model...", 8, self)
             QApplication.processEvents()  # Process pending events
 
-            model = onnx_utils.load_onnx(file_path, load_external_data=False)
-            opt_model, opt_passed = onnx_utils.optimize_onnx_model(model)
+            if not model_proto:
+                model_proto = onnx_utils.load_onnx(file_path, load_external_data=False)
+            opt_model, opt_passed = onnx_utils.optimize_onnx_model(model_proto)
             progress.step()
 
             basename = os.path.splitext(os.path.basename(file_path))
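The dedup check above leans on onnx.ModelProto being a protobuf message, so == compares message contents field by field rather than object identity. A quick hedged sketch of that property using plain onnx APIs (the file path is illustrative):

    import onnx

    # Two independent loads of the same file are distinct Python objects...
    a = onnx.load("model.onnx", load_external_data=False)
    b = onnx.load("model.onnx", load_external_data=False)
    assert a is not b

    # ...but protobuf equality compares contents, so the stale-tab branch
    # only fires when the graph on disk actually changed.
    assert a == b
    b.graph.name = "modified"
    assert a != b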
@@ -918,6 +939,9 @@ def load_pytorch(self, file_path: str):
         basename = os.path.splitext(os.path.basename(file_path))
         model_name = basename[0]
 
+        # The current support for PyTorch models works by exporting them to ONNX.
+        # In this case, an ingest window will pop up giving the user options to
+        # export. This window will block the main GUI until it is closed.
         self.pytorch_ingest = PyTorchIngest(file_path, model_name)
         self.pytorch_ingest_window = PopupDialog(
             self.pytorch_ingest, "PyTorch Ingest", self
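The blocking behavior described in the new comment is the standard Qt modal-dialog pattern: a QDialog shown with exec() runs its own event loop and suspends interaction with the parent window until it closes. A minimal PySide6 sketch of that mechanism — PopupDialog itself is not shown in this diff, so this stand-in only illustrates the assumed pattern:

    import sys
    from PySide6.QtWidgets import QApplication, QDialog, QLabel, QMainWindow, QVBoxLayout

    class IngestDialog(QDialog):  # stand-in for digest's PopupDialog/PyTorchIngest pair
        def __init__(self, parent=None):
            super().__init__(parent)
            self.setWindowTitle("PyTorch Ingest")
            layout = QVBoxLayout(self)
            layout.addWidget(QLabel("Export options would go here."))

    app = QApplication(sys.argv)
    main_window = QMainWindow()
    main_window.show()

    dialog = IngestDialog(main_window)
    dialog.exec()  # modal: blocks input to main_window until the dialog closes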
21 changes: 13 additions & 8 deletions src/digest/model_class/digest_pytorch_model.py
@@ -2,7 +2,7 @@
 
 import os
 from collections import OrderedDict
-from typing import List, Tuple, Optional, Any, Union
+from typing import List, Tuple, Optional, Union
 import inspect
 import onnx
 import torch
@@ -37,7 +37,9 @@ def __init__(
 
         # Input dictionary to contain the names and shapes
         # required for exporting the ONNX model
-        self.input_tensor_info: OrderedDict[str, List[Any]] = OrderedDict()
+        self.input_tensor_info: OrderedDict[
+            str, Tuple[torch.dtype, List[Union[str, int]]]
+        ] = OrderedDict()
 
         self.pytorch_model = torch.load(pytorch_file_path)
@@ -58,21 +60,24 @@ def save_yaml_report(self, file_path: str) -> None:
     def save_text_report(self, file_path: str) -> None:
         """This will be done in the DigestOnnxModel"""
 
-    def generate_random_tensor(self, shape: List[Union[str, int]]):
+    def generate_random_tensor(self, dtype: torch.dtype, shape: List[Union[str, int]]):
         static_shape = [dim if isinstance(dim, int) else 1 for dim in shape]
-        return torch.rand(static_shape)
+        if dtype in (torch.float16, torch.float32, torch.float64):
+            return torch.rand(static_shape, dtype=dtype)
+        else:
+            return torch.randint(0, 100, static_shape, dtype=dtype)
 
     def export_to_onnx(self, output_onnx_path: str) -> Union[onnx.ModelProto, None]:
 
         dummy_input_names: List[str] = list(self.input_tensor_info.keys())
         dummy_inputs: List[torch.Tensor] = []
 
-        for shape in self.input_tensor_info.values():
-            dummy_inputs.append(self.generate_random_tensor(shape))
+        for dtype, shape in self.input_tensor_info.values():
+            dummy_inputs.append(self.generate_random_tensor(dtype, shape))
 
         dynamic_axes = {
             name: {i: dim for i, dim in enumerate(shape) if isinstance(dim, str)}
-            for name, shape in self.input_tensor_info.items()
+            for name, (_, shape) in self.input_tensor_info.items()
         }
 
         try:
@@ -92,7 +97,7 @@ def export_to_onnx(self, output_onnx_path: str) -> Union[onnx.ModelProto, None]:
 
             return onnx.load(output_onnx_path)
 
-        except (TypeError, RuntimeError) as err:
+        except (ValueError, TypeError, RuntimeError) as err:
             print(f"Failed to export ONNX: {err}")
             raise
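Putting the pieces together: the export path walks input_tensor_info, builds a dtype-appropriate dummy tensor per input (symbolic dims pinned to 1), and hands torch.onnx.export a dynamic_axes mapping of axis index to symbolic name for every string dimension. A hedged end-to-end sketch mirroring that flow outside the class, with a toy module standing in for the loaded model:

    from collections import OrderedDict

    import torch

    model = torch.nn.Linear(16, 4)  # toy stand-in for the loaded PyTorch model
    input_tensor_info = OrderedDict([("x", (torch.float32, ["batch_size", 16]))])

    dummy_inputs = []
    for dtype, shape in input_tensor_info.values():
        # Pin each symbolic (string) dimension to 1 for the dummy tensor.
        static_shape = [dim if isinstance(dim, int) else 1 for dim in shape]
        if dtype in (torch.float16, torch.float32, torch.float64):
            dummy_inputs.append(torch.rand(static_shape, dtype=dtype))
        else:
            dummy_inputs.append(torch.randint(0, 100, static_shape, dtype=dtype))

    # Every symbolic dimension becomes a named dynamic axis.
    dynamic_axes = {
        name: {i: dim for i, dim in enumerate(shape) if isinstance(dim, str)}
        for name, (_, shape) in input_tensor_info.items()
    }  # -> {"x": {0: "batch_size"}}

    torch.onnx.export(
        model,
        tuple(dummy_inputs),
        "toy_model.onnx",
        input_names=list(input_tensor_info.keys()),
        dynamic_axes=dynamic_axes,
    )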
2 changes: 1 addition & 1 deletion src/digest/multi_model_selection_page.py
@@ -56,7 +56,7 @@ def run(self):
             model_proto = onnx_utils.load_onnx(file, False)
             self.model_dict[file] = DigestOnnxModel(
                 model_proto,
-                onnx_filepath=file,
+                onnx_file_path=file,
                 model_name=model_name,
                 save_proto=False,
             )
(Diffs for the remaining 5 changed files are not shown here.)
