Skip to content

Commit

Permalink
update pytests for interface changes and fix merge errors
Browse files Browse the repository at this point in the history
  • Loading branch information
JanFSchulte committed Jan 10, 2025
1 parent 7e2fdf7 commit 10d77b6
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 18 deletions.
8 changes: 3 additions & 5 deletions hls4ml/converters/pytorch_to_hls.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,12 @@ def parse_pytorch_model(config, verbose=True):
for arg in node.args:
if np.isscalar(arg):
# add an input node with the constant value
new_node = traced_model.graph.placeholder(
name='const_' + str(i), type_expr=torch.Tensor, default_value=arg
)
new_node = traced_model.placeholder(name='const_' + str(i), type_expr=torch.Tensor, default_value=arg)
node.prepend(new_node)
node.update_arg(1, new_node)
i += 1

traced_model.graph.lint()
traced_model.lint()

for node in traced_model.nodes:
if node.op == 'call_module':
Expand Down Expand Up @@ -258,7 +256,7 @@ def parse_pytorch_model(config, verbose=True):
input_shapes = [output_shapes[str(node.args[0])]]
# if a 'getitem' is the input to a node, step back in the graph to find the real source of the input
elif "getitem" in node.args[0].name:
for tmp_node in traced_model.graph.nodes:
for tmp_node in traced_model.nodes:
if tmp_node.name == node.args[0].name:
if "getitem" in tmp_node.args[0].name:
raise Exception('Nested getitem calles not resolved at the moment.')
Expand Down
28 changes: 15 additions & 13 deletions test/pytest/test_brevitas_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,11 @@ def test_quantlinear(backend, io_type):
x = torch.tensor([1.0, 2.0, 3.0, 4.0])

pytorch_prediction = model(x).detach().numpy()
config = config_from_pytorch_model(model)
config = config_from_pytorch_model(model, input_shape=(None, 4))
output_dir = str(test_root_path / f'hls4mlprj_brevitas_linear_{backend}_{io_type}')

hls_model = convert_from_pytorch_model(
model,
(None, 4),
hls_config=config,
output_dir=output_dir,
backend=backend,
Expand All @@ -87,9 +86,13 @@ def test_quantconv1d(backend, io_type):
pytorch_prediction = model(x).detach().numpy()
if io_type == 'io_stream':
x = np.ascontiguousarray(x.permute(0, 2, 1))
config = config_from_pytorch_model(model, inputs_channel_last=True, transpose_outputs=False)
config = config_from_pytorch_model(
model, (None, n_in, size_in), channels_last_conversion="internal", transpose_outputs=False
)
else:
config = config_from_pytorch_model(model, inputs_channel_last=False, transpose_outputs=True)
config = config_from_pytorch_model(
model, (None, n_in, size_in), channels_last_conversion="full", transpose_outputs=True
)

output_dir = str(test_root_path / f'hls4mlprj_brevitas_conv1d_{backend}_{io_type}')

Expand Down Expand Up @@ -118,9 +121,7 @@ def test_quantconv1d(backend, io_type):
+ 1
) # following https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html

hls_model = convert_from_pytorch_model(
model, (None, n_in, size_in), hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type
)
hls_model = convert_from_pytorch_model(model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type)
hls_model.compile()

if io_type == 'io_stream':
Expand All @@ -144,12 +145,15 @@ def test_quantconv2d(backend, io_type):
x = torch.randn(1, n_in, size_in_height, size_in_width)

pytorch_prediction = model(x).detach().numpy()
config = config_from_pytorch_model(model, inputs_channel_last=False, transpose_outputs=True)
if io_type == 'io_stream':
x = np.ascontiguousarray(x.permute(0, 2, 3, 1))
config = config_from_pytorch_model(model, inputs_channel_last=True, transpose_outputs=False)
config = config_from_pytorch_model(
model, (None, n_in, size_in_height, size_in_width), channels_last_conversion="internal", transpose_outputs=False
)
else:
config = config_from_pytorch_model(model, inputs_channel_last=False, transpose_outputs=True)
config = config_from_pytorch_model(
model, (None, n_in, size_in_height, size_in_width), channels_last_conversion="full", transpose_outputs=True
)

output_dir = str(test_root_path / f'hls4mlprj_brevitas_conv2d_{backend}_{io_type}')

Expand Down Expand Up @@ -191,7 +195,6 @@ def test_quantconv2d(backend, io_type):

hls_model = convert_from_pytorch_model(
model,
(None, n_in, size_in_height, size_in_width),
hls_config=config,
output_dir=output_dir,
backend=backend,
Expand Down Expand Up @@ -249,12 +252,11 @@ def test_pooling(pooling, backend):

pytorch_prediction = model(x).tensor.detach().numpy()

config = config_from_pytorch_model(model)
config = config_from_pytorch_model(model, input_shape_forHLS, transpose_outputs=True)
output_dir = str(test_root_path / f'hls4mlprj_brevitas_{pooling.__name__}_{backend}')

hls_model = convert_from_pytorch_model(
model,
input_shape_forHLS,
hls_config=config,
output_dir=output_dir,
backend=backend,
Expand Down

0 comments on commit 10d77b6

Please sign in to comment.