Skip to content

Commit

Permalink
[fix] analysis script
Browse files Browse the repository at this point in the history
  • Loading branch information
Philip Colangelo committed Dec 11, 2024
1 parent c86564f commit 426d24f
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions examples/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def main(onnx_files: str, output_dir: str):

global_model_data[model_name] = {
"opset": digest_model.opset,
"parameters": digest_model.model_parameters,
"flops": digest_model.model_flops,
"parameters": digest_model.parameters,
"flops": digest_model.flops,
}

# Model summary text report
Expand All @@ -108,7 +108,7 @@ def main(onnx_files: str, output_dir: str):
digest_model.save_node_type_counts_csv_report(node_type_filepath)

# Update global data structure for node type counter
global_node_type_counter.update(digest_model.get_node_type_counts())
global_node_type_counter.update(digest_model.node_type_counts)

# Save csv containing node shape counts per op_type
node_shape_filepath = os.path.join(
Expand All @@ -122,10 +122,8 @@ def main(onnx_files: str, output_dir: str):

if len(onnx_file_list) > 1:
global_filepath = os.path.join(output_dir, "global_node_type_counts.csv")
global_node_type_counter = NodeTypeCounts(
global_node_type_counter.most_common()
)
save_node_type_counts_csv_report(global_node_type_counter, global_filepath)
global_node_type_counts = NodeTypeCounts(global_node_type_counter.most_common())
save_node_type_counts_csv_report(global_node_type_counts, global_filepath)

global_filepath = os.path.join(output_dir, "global_node_shape_counts.csv")
save_node_shape_counts_csv_report(global_node_shape_counter, global_filepath)
Expand Down

0 comments on commit 426d24f

Please sign in to comment.