From 426d24fdee2c37098a37db26c03de71ac256e0c8 Mon Sep 17 00:00:00 2001 From: Philip Colangelo Date: Wed, 11 Dec 2024 18:00:52 -0500 Subject: [PATCH] [fix] analysis script --- examples/analysis.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/examples/analysis.py b/examples/analysis.py index 0cd6344..da89068 100644 --- a/examples/analysis.py +++ b/examples/analysis.py @@ -84,8 +84,8 @@ def main(onnx_files: str, output_dir: str): global_model_data[model_name] = { "opset": digest_model.opset, - "parameters": digest_model.model_parameters, - "flops": digest_model.model_flops, + "parameters": digest_model.parameters, + "flops": digest_model.flops, } # Model summary text report @@ -108,7 +108,7 @@ def main(onnx_files: str, output_dir: str): digest_model.save_node_type_counts_csv_report(node_type_filepath) # Update global data structure for node type counter - global_node_type_counter.update(digest_model.get_node_type_counts()) + global_node_type_counter.update(digest_model.node_type_counts) # Save csv containing node shape counts per op_type node_shape_filepath = os.path.join( @@ -122,10 +122,8 @@ def main(onnx_files: str, output_dir: str): if len(onnx_file_list) > 1: global_filepath = os.path.join(output_dir, "global_node_type_counts.csv") - global_node_type_counter = NodeTypeCounts( - global_node_type_counter.most_common() - ) - save_node_type_counts_csv_report(global_node_type_counter, global_filepath) + global_node_type_counts = NodeTypeCounts(global_node_type_counter.most_common()) + save_node_type_counts_csv_report(global_node_type_counts, global_filepath) global_filepath = os.path.join(output_dir, "global_node_shape_counts.csv") save_node_shape_counts_csv_report(global_node_shape_counter, global_filepath)