Skip to content

Commit

Permalink
fixed numerical condition for pytorch models and updates to type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
jicampos committed Dec 20, 2024
1 parent f60acc5 commit d1a6b35
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions hls4ml/model/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ def _register_layer(self, name: str) -> None:
def _is_parameterized(self, module: torch.nn.Module) -> bool:
return any(p.requires_grad for p in module.parameters())

def _get_weights(self) -> pandas.DataFrame:
def _get_weights(self) -> pandas.DataFrame | list[dict]:
suffix = ['w', 'b']
if self.fmt == 'longform':
data = {'x': [], 'layer': [], 'weight': []}
Expand Down Expand Up @@ -446,9 +446,12 @@ def _get_weights(self) -> pandas.DataFrame:
data = pandas.DataFrame(data)
return data

def get_weights(self) -> dict:
def get_weights(self) -> pandas.DataFrame | list[dict]:
return self._get_weights()

def get_layers(self) -> list[str]:
return self.registered_layers

def _get_layer(self, layer_name: str, module: torch.nn.Module) -> torch.nn.Module:
for name in layer_name.split('.')[1:]:
module = getattr(module, name)
Expand Down Expand Up @@ -526,11 +529,11 @@ def numerical(model=None, hls_model=None, X=None, plot='boxplot'):
elif model_present:
if __tf_profiling_enabled__ and isinstance(model, keras.Model):
data = weights_keras(model, fmt='summary', plot=plot)
elif __torch_profiling_enabled__ and isinstance(model, torch.nn.Sequential):
elif __torch_profiling_enabled__ and isinstance(model, torch.nn.Module):
data = weights_torch(model, fmt='summary', plot=plot)

if data is None:
print("Only keras, PyTorch (Sequential) and ModelGraph models " + "can currently be profiled")
print("Only keras, PyTorch and ModelGraph models " + "can currently be profiled")

if hls_model_present and os.path.exists(tmp_output_dir):
shutil.rmtree(tmp_output_dir)
Expand Down

0 comments on commit d1a6b35

Please sign in to comment.