From fabbab170af00203c60c1b5bea89768506830de9 Mon Sep 17 00:00:00 2001 From: Javier Campos Date: Wed, 18 Dec 2024 17:41:18 -0600 Subject: [PATCH 1/8] updated pytorch weight profiler --- hls4ml/model/profiling.py | 70 +++++++++++++++++++++++++++++++-------- 1 file changed, 56 insertions(+), 14 deletions(-) diff --git a/hls4ml/model/profiling.py b/hls4ml/model/profiling.py index 84a83de23e..0f6f7d395f 100644 --- a/hls4ml/model/profiling.py +++ b/hls4ml/model/profiling.py @@ -78,7 +78,7 @@ def boxplot(data, fmt='longform'): medianprops = dict(linestyle='-', color='k') f, ax = plt.subplots(1, 1) - data.reverse() + # data.reverse() colors = sb.color_palette("Blues", len(data)) bp = ax.bxp(data, showfliers=False, vert=False, medianprops=medianprops) # add colored boxes @@ -381,13 +381,47 @@ def activations_keras(model, X, fmt='longform', plot='boxplot'): def weights_torch(model, fmt='longform', plot='boxplot'): - suffix = ['w', 'b'] - if fmt == 'longform': - data = {'x': [], 'layer': [], 'weight': []} - elif fmt == 'summary': - data = [] - for layer in model.children(): - if isinstance(layer, torch.nn.Linear): + wt = WeightsTorch(model, fmt, plot) + return wt.get_weights() + + +class WeightsTorch: + def __init__(self, model : torch.nn.Module, fmt : str ='longform', plot : str='boxplot') -> None: + self.model = model + self.fmt = fmt + self.plot = plot + self.registered_layers = list() + self._find_layers(self.model, self.model.__class__.__name__) + + def _find_layers(self, model, module_name): + for name, module in model.named_children(): + if isinstance(module, (torch.nn.Sequential, torch.nn.ModuleList)): + self._find_layers(module, module_name+"."+name) + elif isinstance(module, (torch.nn.Module)) and self._is_parameterized(module): + if len(list(module.named_children())) != 0: + # custom nn.Module, continue search + self._find_layers(module, module_name+"."+name) + else: + self._register_layer(module_name+"."+name) + + def _is_registered(self, name : str) -> bool: + return name in self.registered_layers + + def _register_layer(self, name : str) -> None: + if self._is_registered(name) == False: + self.registered_layers.append(name) + + def _is_parameterized(self, module: torch.nn.Module) -> bool: + return any(p.requires_grad for p in module.parameters()) + + def _get_weights(self) -> pandas.DataFrame: + suffix = ['w', 'b'] + if self.fmt == 'longform': + data = {'x': [], 'layer': [], 'weight': []} + elif self.fmt == 'summary': + data = [] + for layer_name in self.registered_layers: + layer = self._get_layer(layer_name, self.model) name = layer.__class__.__name__ weights = list(layer.parameters()) for i, w in enumerate(weights): @@ -399,18 +433,26 @@ def weights_torch(model, fmt='longform', plot='boxplot'): if n == 0: print(f'Weights for {name} are only zeros, ignoring.') break - if fmt == 'longform': + if self.fmt == 'longform': data['x'].extend(w.tolist()) data['layer'].extend([name] * n) data['weight'].extend([label] * n) - elif fmt == 'summary': - data.append(array_to_summary(w, fmt=plot)) + elif self.fmt == 'summary': + data.append(array_to_summary(w, fmt=self.plot)) data[-1]['layer'] = name data[-1]['weight'] = label - if fmt == 'longform': - data = pandas.DataFrame(data) - return data + if self.fmt == 'longform': + data = pandas.DataFrame(data) + return data + + def get_weights(self) -> dict: + return self._get_weights() + + def _get_layer(self, layer_name : str, module : torch.nn.Module) -> torch.nn.Module: + for name in layer_name.split('.')[1:]: + module = getattr(module, name) + return module def activations_torch(model, X, fmt='longform', plot='boxplot'): From 855b1380656b3d7fe7da3806c5df37958ef9a133 Mon Sep 17 00:00:00 2001 From: Javier Campos Date: Wed, 18 Dec 2024 17:43:09 -0600 Subject: [PATCH 2/8] fix type --- hls4ml/model/profiling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hls4ml/model/profiling.py b/hls4ml/model/profiling.py index 0f6f7d395f..55035df02b 100644 --- a/hls4ml/model/profiling.py +++ b/hls4ml/model/profiling.py @@ -78,7 +78,7 @@ def boxplot(data, fmt='longform'): medianprops = dict(linestyle='-', color='k') f, ax = plt.subplots(1, 1) - # data.reverse() + data.reverse() colors = sb.color_palette("Blues", len(data)) bp = ax.bxp(data, showfliers=False, vert=False, medianprops=medianprops) # add colored boxes From a314242ac2863118bb76aabbb7e84064f4ca62b9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 19 Dec 2024 13:08:53 +0000 Subject: [PATCH 3/8] [pre-commit.ci] auto fixes from pre-commit hooks --- hls4ml/model/profiling.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/hls4ml/model/profiling.py b/hls4ml/model/profiling.py index 55035df02b..0375cda919 100644 --- a/hls4ml/model/profiling.py +++ b/hls4ml/model/profiling.py @@ -386,7 +386,7 @@ def weights_torch(model, fmt='longform', plot='boxplot'): class WeightsTorch: - def __init__(self, model : torch.nn.Module, fmt : str ='longform', plot : str='boxplot') -> None: + def __init__(self, model: torch.nn.Module, fmt: str = 'longform', plot: str = 'boxplot') -> None: self.model = model self.fmt = fmt self.plot = plot @@ -396,18 +396,18 @@ def __init__(self, model : torch.nn.Module, fmt : str ='longform', plot : str='b def _find_layers(self, model, module_name): for name, module in model.named_children(): if isinstance(module, (torch.nn.Sequential, torch.nn.ModuleList)): - self._find_layers(module, module_name+"."+name) - elif isinstance(module, (torch.nn.Module)) and self._is_parameterized(module): - if len(list(module.named_children())) != 0: + self._find_layers(module, module_name + "." + name) + elif isinstance(module, (torch.nn.Module)) and self._is_parameterized(module): + if len(list(module.named_children())) != 0: # custom nn.Module, continue search - self._find_layers(module, module_name+"."+name) + self._find_layers(module, module_name + "." + name) else: - self._register_layer(module_name+"."+name) + self._register_layer(module_name + "." + name) - def _is_registered(self, name : str) -> bool: + def _is_registered(self, name: str) -> bool: return name in self.registered_layers - def _register_layer(self, name : str) -> None: + def _register_layer(self, name: str) -> None: if self._is_registered(name) == False: self.registered_layers.append(name) @@ -448,8 +448,8 @@ def _get_weights(self) -> pandas.DataFrame: def get_weights(self) -> dict: return self._get_weights() - - def _get_layer(self, layer_name : str, module : torch.nn.Module) -> torch.nn.Module: + + def _get_layer(self, layer_name: str, module: torch.nn.Module) -> torch.nn.Module: for name in layer_name.split('.')[1:]: module = getattr(module, name) return module From f60acc5c367d4a7027a0897cc98cccb44fe28ab6 Mon Sep 17 00:00:00 2001 From: Javier Campos Date: Thu, 19 Dec 2024 08:54:16 -0600 Subject: [PATCH 4/8] update comparison to false --- hls4ml/model/profiling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hls4ml/model/profiling.py b/hls4ml/model/profiling.py index 0375cda919..ff2dd7437f 100644 --- a/hls4ml/model/profiling.py +++ b/hls4ml/model/profiling.py @@ -408,7 +408,7 @@ def _is_registered(self, name: str) -> bool: return name in self.registered_layers def _register_layer(self, name: str) -> None: - if self._is_registered(name) == False: + if self._is_registered(name) is False: self.registered_layers.append(name) def _is_parameterized(self, module: torch.nn.Module) -> bool: From d1a6b35af1a1fbaef56ade5a854becf69c0cf8a5 Mon Sep 17 00:00:00 2001 From: Javier Campos Date: Fri, 20 Dec 2024 15:50:12 -0600 Subject: [PATCH 5/8] fixed numerical condition for pytorch models and updates to type hints --- hls4ml/model/profiling.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/hls4ml/model/profiling.py b/hls4ml/model/profiling.py index ff2dd7437f..18a715d38d 100644 --- a/hls4ml/model/profiling.py +++ b/hls4ml/model/profiling.py @@ -414,7 +414,7 @@ def _register_layer(self, name: str) -> None: def _is_parameterized(self, module: torch.nn.Module) -> bool: return any(p.requires_grad for p in module.parameters()) - def _get_weights(self) -> pandas.DataFrame: + def _get_weights(self) -> pandas.DataFrame | list[dict]: suffix = ['w', 'b'] if self.fmt == 'longform': data = {'x': [], 'layer': [], 'weight': []} @@ -446,9 +446,12 @@ def _get_weights(self) -> pandas.DataFrame: data = pandas.DataFrame(data) return data - def get_weights(self) -> dict: + def get_weights(self) -> pandas.DataFrame | list[dict]: return self._get_weights() + def get_layers(self) -> list[str]: + return self.registered_layers + def _get_layer(self, layer_name: str, module: torch.nn.Module) -> torch.nn.Module: for name in layer_name.split('.')[1:]: module = getattr(module, name) @@ -526,11 +529,11 @@ def numerical(model=None, hls_model=None, X=None, plot='boxplot'): elif model_present: if __tf_profiling_enabled__ and isinstance(model, keras.Model): data = weights_keras(model, fmt='summary', plot=plot) - elif __torch_profiling_enabled__ and isinstance(model, torch.nn.Sequential): + elif __torch_profiling_enabled__ and isinstance(model, torch.nn.Module): data = weights_torch(model, fmt='summary', plot=plot) if data is None: - print("Only keras, PyTorch (Sequential) and ModelGraph models " + "can currently be profiled") + print("Only keras, PyTorch and ModelGraph models " + "can currently be profiled") if hls_model_present and os.path.exists(tmp_output_dir): shutil.rmtree(tmp_output_dir) From ba83436295fde1a7df9b3d5556e093dd607419a6 Mon Sep 17 00:00:00 2001 From: Javier Campos Date: Tue, 14 Jan 2025 11:09:01 -0600 Subject: [PATCH 6/8] Create test_pytorch_profiler.py --- test/pytest/test_pytorch_profiler.py | 85 ++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 test/pytest/test_pytorch_profiler.py diff --git a/test/pytest/test_pytorch_profiler.py b/test/pytest/test_pytorch_profiler.py new file mode 100644 index 0000000000..8a994c6fe0 --- /dev/null +++ b/test/pytest/test_pytorch_profiler.py @@ -0,0 +1,85 @@ +import pytest + +import hls4ml + +try: + import torch + import torch.nn as nn + + __torch_profiling_enabled__ = True +except ImportError: + __torch_profiling_enabled__ = False + + +class SubClassModel(torch.nn.Module): + def __init__(self, layers) -> None: + super().__init__() + for idx, layer in enumerate(layers): + setattr(self, f'layer_{idx}', layer) + + +class ModuleListModel(torch.nn.Module): + def __init__(self, layers) -> None: + super().__init__() + self.layer = torch.nn.ModuleList(layers) + + +class NestedSequentialModel(torch.nn.Module): + def __init__(self, layers) -> None: + super().__init__() + self.model = torch.nn.Sequential(*layers) + + +def count_bars_in_figure(fig): + count = 0 + for ax in fig.get_axes(): + count += len(ax.patches) + return count + + +# Reusable parameter list +test_layers = [ + (4, [nn.Linear(10, 20), nn.Linear(20, 5)]), + (6, [nn.Linear(10, 20), nn.Linear(20, 5), nn.Conv1d(3, 16, kernel_size=3)]), + (6, [nn.Linear(15, 30), nn.Linear(30, 15), nn.Conv2d(1, 32, kernel_size=3)]), + (6, [nn.RNN(64, 128), nn.Linear(128, 10)]), + (6, [nn.LSTM(64, 128), nn.Linear(128, 10)]), + (6, [nn.GRU(64, 128), nn.Linear(128, 10)]), +] + + +@pytest.mark.parametrize("layers", test_layers) +def test_sequential_model(layers): + if __torch_profiling_enabled__: + param_count, layers = layers + model = torch.nn.Sequential(*layers) + wp, _, _, _ = hls4ml.model.profiling.numerical(model) + wp.savefig('test.png') + assert count_bars_in_figure(wp) == param_count + + +@pytest.mark.parametrize("layers", test_layers) +def test_subclass_model(layers): + if __torch_profiling_enabled__: + param_count, layers = layers + model = SubClassModel(layers) + wp, _, _, _ = hls4ml.model.profiling.numerical(model) + assert count_bars_in_figure(wp) == param_count + + +@pytest.mark.parametrize("layers", test_layers) +def test_modulelist_model(layers): + if __torch_profiling_enabled__: + param_count, layers = layers + model = ModuleListModel(layers) + wp, _, _, _ = hls4ml.model.profiling.numerical(model) + assert count_bars_in_figure(wp) == param_count + + +@pytest.mark.parametrize("layers", test_layers) +def test_nested_model(layers): + if __torch_profiling_enabled__: + param_count, layers = layers + model = NestedSequentialModel(layers) + wp, _, _, _ = hls4ml.model.profiling.numerical(model) + assert count_bars_in_figure(wp) == param_count From c6a414fed411343735ed8d3a424b9de2939fcadb Mon Sep 17 00:00:00 2001 From: Javier Campos Date: Tue, 14 Jan 2025 11:56:07 -0600 Subject: [PATCH 7/8] Update layer processing and add batchnorm testing --- hls4ml/model/profiling.py | 42 ++++++++++++++++++++++++++-- test/pytest/test_pytorch_profiler.py | 7 +++-- 2 files changed, 44 insertions(+), 5 deletions(-) diff --git a/hls4ml/model/profiling.py b/hls4ml/model/profiling.py index 18a715d38d..f30088b51d 100644 --- a/hls4ml/model/profiling.py +++ b/hls4ml/model/profiling.py @@ -385,6 +385,45 @@ def weights_torch(model, fmt='longform', plot='boxplot'): return wt.get_weights() +def _torch_batchnorm(layer): + weights = list(layer.parameters()) + epsilon = layer.eps + + gamma = weights[0] + beta = weights[1] + if layer.track_running_stats: + mean = layer.running_mean + var = layer.running_var + else: + mean = torch.tensor(np.ones(20)) + var = torch.tensor(np.zeros(20)) + + scale = gamma / np.sqrt(var + epsilon) + bias = beta - gamma * mean / np.sqrt(var + epsilon) + + return [scale, bias], ['s', 'b'] + + +def _torch_layer(layer): + return list(layer.parameters()), ['w', 'b'] + + +def _torch_rnn(layer): + return list(layer.parameters()), ['w_ih_l0', 'w_hh_l0', 'b_ih_l0', 'b_hh_l0'] + + +torch_process_layer_map = defaultdict( + lambda: _torch_layer, + { + 'BatchNorm1d': _torch_batchnorm, + 'BatchNorm2d': _torch_batchnorm, + 'RNN': _torch_rnn, + 'LSTM': _torch_rnn, + 'GRU': _torch_rnn, + }, +) + + class WeightsTorch: def __init__(self, model: torch.nn.Module, fmt: str = 'longform', plot: str = 'boxplot') -> None: self.model = model @@ -415,7 +454,6 @@ def _is_parameterized(self, module: torch.nn.Module) -> bool: return any(p.requires_grad for p in module.parameters()) def _get_weights(self) -> pandas.DataFrame | list[dict]: - suffix = ['w', 'b'] if self.fmt == 'longform': data = {'x': [], 'layer': [], 'weight': []} elif self.fmt == 'summary': @@ -423,7 +461,7 @@ def _get_weights(self) -> pandas.DataFrame | list[dict]: for layer_name in self.registered_layers: layer = self._get_layer(layer_name, self.model) name = layer.__class__.__name__ - weights = list(layer.parameters()) + weights, suffix = torch_process_layer_map[layer.__class__.__name__](layer) for i, w in enumerate(weights): label = f'{name}/{suffix[i]}' w = weights[i].detach().numpy() diff --git a/test/pytest/test_pytorch_profiler.py b/test/pytest/test_pytorch_profiler.py index 8a994c6fe0..c8b09fe35c 100644 --- a/test/pytest/test_pytorch_profiler.py +++ b/test/pytest/test_pytorch_profiler.py @@ -32,7 +32,7 @@ def __init__(self, layers) -> None: def count_bars_in_figure(fig): count = 0 - for ax in fig.get_axes(): + for ax in fig.get_axes(): count += len(ax.patches) return count @@ -40,6 +40,7 @@ def count_bars_in_figure(fig): # Reusable parameter list test_layers = [ (4, [nn.Linear(10, 20), nn.Linear(20, 5)]), + (3, [nn.Linear(10, 20), nn.BatchNorm1d(20)]), (6, [nn.Linear(10, 20), nn.Linear(20, 5), nn.Conv1d(3, 16, kernel_size=3)]), (6, [nn.Linear(15, 30), nn.Linear(30, 15), nn.Conv2d(1, 32, kernel_size=3)]), (6, [nn.RNN(64, 128), nn.Linear(128, 10)]), @@ -73,7 +74,7 @@ def test_modulelist_model(layers): param_count, layers = layers model = ModuleListModel(layers) wp, _, _, _ = hls4ml.model.profiling.numerical(model) - assert count_bars_in_figure(wp) == param_count + assert count_bars_in_figure(wp) == param_count @pytest.mark.parametrize("layers", test_layers) @@ -82,4 +83,4 @@ def test_nested_model(layers): param_count, layers = layers model = NestedSequentialModel(layers) wp, _, _, _ = hls4ml.model.profiling.numerical(model) - assert count_bars_in_figure(wp) == param_count + assert count_bars_in_figure(wp) == param_count From 4cfc2e5cfc86a82ad701283ece214657616b5466 Mon Sep 17 00:00:00 2001 From: Javier Campos Date: Tue, 14 Jan 2025 11:57:43 -0600 Subject: [PATCH 8/8] Remove typo --- test/pytest/test_pytorch_profiler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/pytest/test_pytorch_profiler.py b/test/pytest/test_pytorch_profiler.py index c8b09fe35c..746bfc9455 100644 --- a/test/pytest/test_pytorch_profiler.py +++ b/test/pytest/test_pytorch_profiler.py @@ -55,7 +55,6 @@ def test_sequential_model(layers): param_count, layers = layers model = torch.nn.Sequential(*layers) wp, _, _, _ = hls4ml.model.profiling.numerical(model) - wp.savefig('test.png') assert count_bars_in_figure(wp) == param_count