Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Torch profiler #1156

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 60 additions & 15 deletions hls4ml/model/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,13 +381,47 @@ def activations_keras(model, X, fmt='longform', plot='boxplot'):


def weights_torch(model, fmt='longform', plot='boxplot'):
suffix = ['w', 'b']
if fmt == 'longform':
data = {'x': [], 'layer': [], 'weight': []}
elif fmt == 'summary':
data = []
for layer in model.children():
if isinstance(layer, torch.nn.Linear):
wt = WeightsTorch(model, fmt, plot)
return wt.get_weights()


class WeightsTorch:
def __init__(self, model: torch.nn.Module, fmt: str = 'longform', plot: str = 'boxplot') -> None:
self.model = model
self.fmt = fmt
self.plot = plot
self.registered_layers = list()
self._find_layers(self.model, self.model.__class__.__name__)

def _find_layers(self, model, module_name):
for name, module in model.named_children():
if isinstance(module, (torch.nn.Sequential, torch.nn.ModuleList)):
self._find_layers(module, module_name + "." + name)
elif isinstance(module, (torch.nn.Module)) and self._is_parameterized(module):
if len(list(module.named_children())) != 0:
# custom nn.Module, continue search
self._find_layers(module, module_name + "." + name)
else:
self._register_layer(module_name + "." + name)

def _is_registered(self, name: str) -> bool:
return name in self.registered_layers

def _register_layer(self, name: str) -> None:
if self._is_registered(name) is False:
self.registered_layers.append(name)

def _is_parameterized(self, module: torch.nn.Module) -> bool:
return any(p.requires_grad for p in module.parameters())

def _get_weights(self) -> pandas.DataFrame | list[dict]:
suffix = ['w', 'b']
if self.fmt == 'longform':
data = {'x': [], 'layer': [], 'weight': []}
elif self.fmt == 'summary':
data = []
for layer_name in self.registered_layers:
layer = self._get_layer(layer_name, self.model)
name = layer.__class__.__name__
weights = list(layer.parameters())
for i, w in enumerate(weights):
Expand All @@ -399,18 +433,29 @@ def weights_torch(model, fmt='longform', plot='boxplot'):
if n == 0:
print(f'Weights for {name} are only zeros, ignoring.')
break
if fmt == 'longform':
if self.fmt == 'longform':
data['x'].extend(w.tolist())
data['layer'].extend([name] * n)
data['weight'].extend([label] * n)
elif fmt == 'summary':
data.append(array_to_summary(w, fmt=plot))
elif self.fmt == 'summary':
data.append(array_to_summary(w, fmt=self.plot))
data[-1]['layer'] = name
data[-1]['weight'] = label

if fmt == 'longform':
data = pandas.DataFrame(data)
return data
if self.fmt == 'longform':
data = pandas.DataFrame(data)
return data

def get_weights(self) -> pandas.DataFrame | list[dict]:
return self._get_weights()

def get_layers(self) -> list[str]:
return self.registered_layers

def _get_layer(self, layer_name: str, module: torch.nn.Module) -> torch.nn.Module:
for name in layer_name.split('.')[1:]:
module = getattr(module, name)
return module


def activations_torch(model, X, fmt='longform', plot='boxplot'):
Expand Down Expand Up @@ -484,11 +529,11 @@ def numerical(model=None, hls_model=None, X=None, plot='boxplot'):
elif model_present:
if __tf_profiling_enabled__ and isinstance(model, keras.Model):
data = weights_keras(model, fmt='summary', plot=plot)
elif __torch_profiling_enabled__ and isinstance(model, torch.nn.Sequential):
elif __torch_profiling_enabled__ and isinstance(model, torch.nn.Module):
data = weights_torch(model, fmt='summary', plot=plot)

if data is None:
print("Only keras, PyTorch (Sequential) and ModelGraph models " + "can currently be profiled")
print("Only keras, PyTorch and ModelGraph models " + "can currently be profiled")

if hls_model_present and os.path.exists(tmp_output_dir):
shutil.rmtree(tmp_output_dir)
Expand Down
Loading