Skip to content

Commit

Permalink
Code to validate report file
Browse files Browse the repository at this point in the history
- robust pixmap handling
- better pixmap quality
- copy png instead of grab()
- scale loading gif
- multimodel report support
- recompiled gui with pyside6.8.1
  • Loading branch information
Philip Colangelo committed Dec 11, 2024
1 parent 112403b commit c86564f
Show file tree
Hide file tree
Showing 23 changed files with 680 additions and 634 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

setup(
name="digestai",
version="1.0.0",
version="1.1.0",
description="Model analysis toolkit",
author="Philip Colangelo, Daniel Holanda",
packages=find_packages(where="src"),
Expand Down
14 changes: 12 additions & 2 deletions src/digest/dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,23 @@ class WarnDialog(QDialog):

def __init__(self, warning_message: str, parent=None):
super().__init__(parent)
self.setWindowTitle("Warning Message")

self.setWindowIcon(QIcon(":/assets/images/digest_logo_500.jpg"))

self.setWindowTitle("Warning Message")
self.setWindowFlags(Qt.WindowType.Dialog)
self.setMinimumWidth(300)

self.setWindowModality(Qt.WindowModality.WindowModal)

layout = QVBoxLayout()

# Application Version
layout.addWidget(QLabel("<b>Something went wrong</b>"))
layout.addWidget(QLabel("<b>Warning</b>"))
layout.addWidget(QLabel(warning_message))

ok_button = QPushButton("OK")
ok_button.clicked.connect(self.accept) # Close dialog when clicked
layout.addWidget(ok_button)

self.setLayout(layout)
2 changes: 1 addition & 1 deletion src/digest/histogramchartwidget.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(self, *args, **kwargs):
super(StackedHistogramWidget, self).__init__(*args, **kwargs)

self.plot_widget = pg.PlotWidget()
self.plot_widget.setMaximumHeight(150)
self.plot_widget.setMaximumHeight(200)
plot_item = self.plot_widget.getPlotItem()
if plot_item:
plot_item.setContentsMargins(0, 0, 0, 0)
Expand Down
61 changes: 38 additions & 23 deletions src/digest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import os
import sys
import shutil
import argparse
from datetime import datetime
from typing import Dict, Tuple, Optional, Union
Expand Down Expand Up @@ -33,7 +34,7 @@
QMenu,
)
from PySide6.QtGui import QDragEnterEvent, QDropEvent, QPixmap, QMovie, QIcon, QFont
from PySide6.QtCore import Qt
from PySide6.QtCore import Qt, QSize

from digest.dialog import StatusDialog, InfoDialog, WarnDialog, ProgressDialog
from digest.thread import StatsThread, SimilarityThread
Expand Down Expand Up @@ -309,9 +310,9 @@ def update_cards(
digest_model: DigestModel,
unique_id: str,
):
self.digest_models[unique_id].model_flops = digest_model.model_flops
self.digest_models[unique_id].flops = digest_model.flops
self.digest_models[unique_id].node_type_flops = digest_model.node_type_flops
self.digest_models[unique_id].model_parameters = digest_model.model_parameters
self.digest_models[unique_id].parameters = digest_model.parameters
self.digest_models[unique_id].node_type_parameters = (
digest_model.node_type_parameters
)
Expand All @@ -326,10 +327,10 @@ def update_cards(
isinstance(widget, modelSummary)
and widget.digest_model.unique_id == unique_id
):
if digest_model.model_flops is None:
if digest_model.flops is None:
flops_str = "--"
else:
flops_str = format(digest_model.model_flops, ",")
flops_str = format(digest_model.flops, ",")

# Set up the pie chart
pie_chart_labels, pie_chart_data = zip(
Expand Down Expand Up @@ -390,10 +391,20 @@ def update_similarity_widget(
break

if completed_successfully and isinstance(widget, modelSummary) and png_filepath:
widget_width = widget.ui.similarityWidget.width()
widget.ui.similarityImg.setPixmap(
QPixmap(png_filepath).scaledToWidth(widget_width)
widget.load_gif.stop()
widget.ui.similarityImg.clear()
widget_width = widget.ui.similarityImg.width()

pixmap = QPixmap(png_filepath)
aspect_ratio = pixmap.width() / pixmap.height()
target_height = int(widget_width / aspect_ratio)
pixmap_scaled = pixmap.scaled(
QSize(widget_width, target_height),
Qt.AspectRatioMode.KeepAspectRatio,
Qt.TransformationMode.SmoothTransformation,
)

widget.ui.similarityImg.setPixmap(pixmap_scaled)
widget.ui.similarityImg.setText("")
widget.ui.similarityImg.setCursor(Qt.CursorShape.PointingHandCursor)

Expand Down Expand Up @@ -429,7 +440,8 @@ def update_similarity_widget(
widget.ui.similarityCorrelation.setText(text)
elif isinstance(widget, modelSummary):
# Remove animation and set text to failing message
widget.ui.similarityImg.setMovie(QMovie())
widget.load_gif.stop()
widget.ui.similarityImg.clear()
widget.ui.similarityImg.setText("Failed to perform similarity analysis")
else:
print(
Expand Down Expand Up @@ -666,10 +678,6 @@ def load_onnx(self, filepath: str):
self.ui.singleModelWidget.show()
progress.step()

movie = QMovie(":/assets/gifs/load.gif")
model_summary.ui.similarityImg.setMovie(movie)
movie.start()

# Start similarity Analysis
# Note: Should only be started after the model tab has been created
png_tmp_path = os.path.join(self.temp_dir.name, model_id)
Expand Down Expand Up @@ -716,6 +724,16 @@ def load_report(self, filepath: str):

digest_model = DigestReportModel(filepath)

if not digest_model.is_valid:
progress.close()
invalid_yaml_dialog = StatusDialog(
title="Warning",
status_message=f"YAML file {filepath} is not a valid digest report",
)
invalid_yaml_dialog.show()

return

model_id = digest_model.unique_id

# There is no sense in offering to save the report
Expand All @@ -739,9 +757,7 @@ def load_report(self, filepath: str):
model_summary.ui.modelFilename.setText(filepath)
model_summary.ui.generatedDate.setText(datetime.now().strftime("%B %d, %Y"))

model_summary.ui.parameters.setText(
format(digest_model.model_parameters, ",")
)
model_summary.ui.parameters.setText(format(digest_model.parameters, ","))

node_type_counts = digest_model.node_type_counts
if len(node_type_counts) < 15:
Expand All @@ -751,7 +767,6 @@ def load_report(self, filepath: str):

model_summary.ui.opHistogramChart.bar_spacing = bar_spacing
model_summary.ui.opHistogramChart.set_data(node_type_counts)

model_summary.ui.nodes.setText(str(sum(node_type_counts.values())))

progress.step()
Expand Down Expand Up @@ -962,13 +977,13 @@ def save_reports(self):
)
digest_model.save_node_type_counts_csv_report(node_type_filepath)

# Save the similarity image
similarity_png = self.model_similarity_report[
# Save (copy) the similarity image
png_file_path = self.model_similarity_thread[
digest_model.unique_id
].enlarged_image_label.grab()
similarity_png.save(
os.path.join(save_directory, f"{model_name}_heatmap.png"), "PNG"
)
].png_filepath
png_save_path = os.path.join(save_directory, f"{model_name}_heatmap.png")
if png_file_path and os.path.exists(png_file_path):
shutil.copy(png_file_path, png_save_path)

# Save the text report
txt_report_filepath = os.path.join(save_directory, f"{model_name}_report.txt")
Expand Down
10 changes: 5 additions & 5 deletions src/digest/model_class/digest_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,15 @@ def __init__(self, *args, **kwargs):


class DigestModel(ABC):
def __init__(self, filepath: str, model_name: str):
def __init__(self, filepath: str, model_name: str, model_type: SupportedModelTypes):
# Public members exposed to the API
self.unique_id: str = str(uuid4())
self.filepath: Optional[str] = filepath
self.model_name: str = model_name
self.model_type: Optional[SupportedModelTypes] = None
self.model_type: SupportedModelTypes = model_type
self.node_type_counts: NodeTypeCounts = NodeTypeCounts()
self.model_flops: Optional[int] = None
self.model_parameters: int = 0
self.flops: Optional[int] = None
self.parameters: int = 0
self.node_type_flops: Dict[str, int] = {}
self.node_type_parameters: Dict[str, int] = {}
self.node_data = NodeData()
Expand All @@ -118,7 +118,7 @@ def get_node_shape_counts(self) -> NodeShapeCounts:
return tensor_shape_counter

@abstractmethod
def parse_model_nodes(self, *args) -> None:
def parse_model_nodes(self, *args, **kwargs) -> None:
pass

@abstractmethod
Expand Down
Loading

0 comments on commit c86564f

Please sign in to comment.