@pytorch_handler('Constant')
def parse_constant_layer(operation, layer_name, node):
    """Build the layer dictionary describing a constant node.

    Args:
        operation: Operation name; it must contain ``'Constant'``.
        layer_name: Name stored in the resulting layer dictionary.
        node: Node object whose ``_args`` attribute holds the constant value(s).

    Returns:
        A ``(layer, output_shape)`` pair. ``layer`` is a dict with the keys
        ``'inputs'``, ``'class_name'``, ``'name'`` and ``'value'``;
        ``output_shape`` is the shape of the NumPy array built from
        ``node._args``.

    Raises:
        AssertionError: If ``'Constant'`` is not contained in ``operation``.
    """
    assert 'Constant' in operation

    # The node's arguments are wrapped in an array as-is; the array's shape
    # is what gets reported as the output shape.
    value = np.array(node._args)

    layer = {
        'inputs': [],  # no input layers are recorded for a constant
        'class_name': 'Constant',
        'name': layer_name,
        'value': value,
    }

    return layer, value.shape
node.name - input_layer['class_name'] = 'InputLayer' - input_layer['input_shape'] = list(input_shapes[n_inputs][1:]) - layer_list.insert(n_inputs, input_layer) - output_shapes[input_layer['name']] = list(input_shapes[n_inputs]) - input_layers.append(input_layer['name']) - n_inputs += 1 + if 'const' in node.name: + pytorch_class = 'Constant' + layer, output_shape = layer_handlers[pytorch_class](pytorch_class, node.name, node) + + layer_list.append(layer) + + assert output_shape is not None + output_shapes[layer['name']] = output_shape + + else: + + input_layer['class_name'] = 'InputLayer' + input_layer['input_shape'] = list(input_shapes[n_inputs][1:]) + layer_list.insert(n_inputs, input_layer) + + output_shapes[input_layer['name']] = list(input_shapes[n_inputs]) + + input_layers.append(input_layer['name']) + n_inputs += 1 layer_counter += 1