Add support for parsing simple brevitas layers as part of pytorch models #1019

Draft · wants to merge 33 commits into main

Commits (33)

e077161  playing with brevitas (JanFSchulte, Nov 28, 2023)
c946a03  add brevitas quantizer (JanFSchulte, Dec 15, 2023)
52bb225  latest brevitas developments (JanFSchulte, Feb 9, 2024)
b212f8e  Avoid Y2K22 Xilinx bug (fgias, Mar 1, 2024)
037c751  Merge pull request #2 from fgias/fgiasemi/y2k22-bug (JanFSchulte, Mar 1, 2024)
bf37179  state of the art brevitas parsing and add pytest (JanFSchulte, May 20, 2024)
df452f0  Merge branch 'brevitas' of https://github.com/JanFSchulte/hls4ml into… (JanFSchulte, May 20, 2024)
5316a48  fix some compilation errors (JanFSchulte, May 21, 2024)
6e47e9a  fix another trivial error in pytests (JanFSchulte, May 21, 2024)
f2201b0  Delete test_brevitas.py (simon71701, May 22, 2024)
f64fab6  Delete test_brevitas_conv.py (simon71701, May 22, 2024)
45954f9  fix dimensions in Conv2D pytest for brevitas parsing (JanFSchulte, Jun 7, 2024)
3eb759c  Merge pull request #3 from simon71701/brevitas (JanFSchulte, Jun 7, 2024)
73af4c1  trigger pre-commit (JanFSchulte, Jun 7, 2024)
44a6927  Merge branch 'brevitas' of https://github.com/JanFSchulte/hls4ml into… (JanFSchulte, Jun 7, 2024)
272c418  Merge branch 'main' into brevitas (JanFSchulte, Jun 7, 2024)
4f401ed  move quantizer to new file (JanFSchulte, Jun 7, 2024)
9c1740f  reduce diff and update access to tensors to latest version (JanFSchulte, Jun 7, 2024)
c769fef  [pre-commit.ci] auto fixes from pre-commit hooks (pre-commit-ci[bot], Jun 7, 2024)
0bb09f0  add brevitas to the requirements for tests (JanFSchulte, Jul 9, 2024)
e23b2ed  Merge branch 'brevitas' of https://github.com/JanFSchulte/hls4ml into… (JanFSchulte, Jul 9, 2024)
cda36b6  adjust required precision in brevitas pytests (JanFSchulte, Jul 22, 2024)
ef380ad  Add conv1d tests, fix output dir and tolerances (Jul 22, 2024)
528b659  Merge branch 'main' into brevitas (JanFSchulte, Jul 22, 2024)
dffa379  [pre-commit.ci] auto fixes from pre-commit hooks (pre-commit-ci[bot], Jul 22, 2024)
399613e  Test QuantMaxPool and ignore QuantDropout (Jul 25, 2024)
73590b7  Merge branch 'main' into brevitas (JanFSchulte, Jan 10, 2025)
d13bf52  [pre-commit.ci] auto fixes from pre-commit hooks (pre-commit-ci[bot], Jan 10, 2025)
22dd2cb  merge with master (JanFSchulte, Jan 10, 2025)
1f17845  restore accidental change (JanFSchulte, Jan 10, 2025)
787c4e1  Merge branch 'main' into brevitas (JanFSchulte, Jan 10, 2025)
7e2fdf7  [pre-commit.ci] auto fixes from pre-commit hooks (pre-commit-ci[bot], Jan 10, 2025)
10d77b6  update pytests for interface changes and fix merge errors (JanFSchulte, Jan 10, 2025)
67 changes: 55 additions & 12 deletions hls4ml/converters/pytorch/convolution.py
@@ -1,8 +1,10 @@
from hls4ml.converters.pytorch_to_hls import pytorch_handler
from hls4ml.converters.pytorch_to_hls import convert_uaq_to_apfixed, pytorch_handler
from hls4ml.converters.utils import compute_padding_1d_pytorch, compute_padding_2d_pytorch, parse_data_format
from hls4ml.model.quantizers import BrevitasQuantizer
from hls4ml.model.types import FixedPrecisionType


@pytorch_handler('Conv1d')
@pytorch_handler('Conv1d', 'QuantConv1d')
def parse_conv1d_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
assert 'Conv1d' in operation

@@ -13,12 +15,32 @@ def parse_conv1d_layer(operation, layer_name, input_names, input_shapes, node, c
layer['class_name'] = 'Conv1D'
layer['data_format'] = 'channels_first' # Pytorch default (can't change)

layer['weight_data'] = class_object.weight.data.numpy()
if class_object.bias is not None:
layer['bias_data'] = class_object.bias.data.numpy()
if "Quant" in operation:
if class_object.is_weight_quant_enabled:
width = int(class_object.quant_weight().bit_width)
ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_weight().scale))
layer['weight_data'] = class_object.quant_weight().detach().value.numpy()
layer['weight_quantizer'] = BrevitasQuantizer(
width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True)
)
else:
layer['weight_data'] = class_object.weight.data.numpy()

if class_object.is_bias_quant_enabled:
width = int(class_object.quant_bias().bit_width)
ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_bias().scale))
layer['bias_data'] = class_object.quant_bias().detach().value.numpy()
layer['bias_quantizer'] = BrevitasQuantizer(
width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True)
)
else:
layer['bias_data'] = class_object.bias.data.numpy()
else:
layer['bias_data'] = None

layer['weight_data'] = class_object.weight.data.numpy()
if class_object.bias is not None:
layer['bias_data'] = class_object.bias.data.numpy()
else:
layer['bias_data'] = None
# Input info
(layer['in_width'], layer['n_chan']) = parse_data_format(
input_shapes[0], 'channels_first'
@@ -47,7 +69,7 @@ def parse_conv1d_layer(operation, layer_name, input_names, input_shapes, node, c
return layer, output_shape


@pytorch_handler('Conv2d')
@pytorch_handler('Conv2d', 'QuantConv2d')
def parse_conv2d_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
assert 'Conv2d' in operation

@@ -58,11 +80,32 @@ def parse_conv2d_layer(operation, layer_name, input_names, input_shapes, node, c
layer['class_name'] = 'Conv2D'
layer['data_format'] = 'channels_first' # Pytorch default (can't change)

layer['weight_data'] = class_object.weight.data.numpy()
if class_object.bias is not None:
layer['bias_data'] = class_object.bias.data.numpy()
if "Quant" in operation:
if class_object.is_weight_quant_enabled:
width = int(class_object.quant_weight().bit_width)
ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_weight().scale))
layer['weight_data'] = class_object.quant_weight().detach().value.numpy()
layer['weight_quantizer'] = BrevitasQuantizer(
width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True)
)
else:
layer['weight_data'] = class_object.weight.data.numpy()

if class_object.is_bias_quant_enabled:
width = int(class_object.quant_bias().bit_width)
ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_bias().scale))
layer['bias_data'] = class_object.quant_bias().detach().value.numpy()
layer['bias_quantizer'] = BrevitasQuantizer(
width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True)
)
else:
layer['bias_data'] = class_object.bias.data.numpy()
else:
layer['bias_data'] = None
layer['weight_data'] = class_object.weight.data.numpy()
if class_object.bias is not None:
layer['bias_data'] = class_object.bias.data.numpy()
else:
layer['bias_data'] = None

# Input info
(layer['in_height'], layer['in_width'], layer['n_chan']) = parse_data_format(
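For reference, a minimal sketch (not part of this diff) of the Brevitas interface the two handlers above read. The layer construction is illustrative: Int8WeightPerTensorFixedPoint is assumed to be importable from brevitas.quant and uses a power-of-two scale, which is what keeps -log2(scale) integral when convert_uaq_to_apfixed is applied.

import brevitas.nn as qnn
from brevitas.quant import Int8WeightPerTensorFixedPoint  # assumed: power-of-two weight scales

conv = qnn.QuantConv1d(3, 8, kernel_size=3, weight_quant=Int8WeightPerTensorFixedPoint, bias=False)

qw = conv.quant_weight()            # QuantTensor, exactly what the handler queries
width = int(qw.bit_width)           # 8
scale = float(qw.scale)             # per-tensor scale, a power of two here
values = qw.detach().value.numpy()  # already-quantized weights, shape (8, 3, 3)

With a scale that is not a power of two, -log2(scale) is fractional and the int() cast on ap_fixed_params[1] in the handlers truncates, so the emitted ap_fixed type is only exact for power-of-two scales.
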
56 changes: 53 additions & 3 deletions hls4ml/converters/pytorch/core.py
@@ -1,6 +1,8 @@
import numpy as np

from hls4ml.converters.pytorch_to_hls import pytorch_handler
from hls4ml.converters.pytorch_to_hls import convert_uaq_to_apfixed, pytorch_handler
from hls4ml.model.quantizers import BrevitasQuantizer
from hls4ml.model.types import FixedPrecisionType


@pytorch_handler('Constant')
@@ -20,7 +22,7 @@ def parse_constant_layer(operation, layer_name, node):
return layer, output_shape


@pytorch_handler('Linear')
@pytorch_handler('Linear', 'QuantLinear')
def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
assert 'Linear' in operation

@@ -36,6 +38,33 @@ def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, c
else:
layer['bias_data'] = None

if "Quant" in operation:
if class_object.is_weight_quant_enabled:
width = int(class_object.quant_weight().bit_width)
ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_weight().scale))
layer['weight_data'] = class_object.quant_weight().detach().value.numpy()
layer['weight_quantizer'] = BrevitasQuantizer(
width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True)
)
else:
layer['weight_data'] = class_object.weight.data.numpy()

if class_object.is_bias_quant_enabled:
width = int(class_object.quant_bias().bit_width)
ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_bias().scale))
layer['bias_data'] = class_object.quant_bias().detach().value.numpy()
layer['bias_quantizer'] = BrevitasQuantizer(
width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True)
)
else:
layer['bias_data'] = class_object.bias.data.numpy()
else:
layer['weight_data'] = class_object.weight.data.numpy()
if class_object.bias is not None:
layer['bias_data'] = class_object.bias.data.numpy()
else:
layer['bias_data'] = None

if class_object is not None:
layer['n_in'] = class_object.in_features
layer['n_out'] = class_object.out_features
@@ -54,7 +83,19 @@ def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, c
return layer, output_shape


activation_layers = ['Softmax', 'ReLU', 'LeakyReLU', 'Threshold', 'ELU', 'PReLU', 'Sigmoid', 'Tanh']
activation_layers = [
'Softmax',
'ReLU',
'LeakyReLU',
'Threshold',
'ELU',
'PReLU',
'Sigmoid',
'Tanh',
'QuantReLU',
'QuantSigmoid',
'QuantTanh',
]


@pytorch_handler(*activation_layers)
@@ -66,6 +107,15 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, nod
layer['name'] = layer_name
layer['inputs'] = input_names

if "Quant" in operation:
layer['class_name'] = operation.split('Quant')[-1]
layer['activation'] = layer['class_name']
bit_width = class_object.quant_act_bit_width()
ap_fixed_params = convert_uaq_to_apfixed(bit_width, class_object.quant_act_scale())
layer['activation_quantizer'] = BrevitasQuantizer(
bit_width, FixedPrecisionType(width=bit_width, integer=ap_fixed_params[1], signed=False)
)

if node.op == 'call_module':
if layer['class_name'] in ['ReLU', 'Sigmoid', 'Tanh']:
layer['class_name'] = 'Activation'
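The activation branch above can be exercised with a sketch like the following. This is hedged: it assumes brevitas's default activation quantizer, and the concrete scale value depends on the quantizer's state (calibration or training).

import brevitas.nn as qnn

act = qnn.QuantReLU(bit_width=4)

bit_width = int(act.quant_act_bit_width())  # 4, as read by the handler
scale = act.quant_act_scale()               # output scale, fed to convert_uaq_to_apfixed
# For a power-of-two scale 2**-f the handler builds
# FixedPrecisionType(width=4, integer=4 - f, signed=False),
# i.e. an unsigned ap_fixed<4, 4-f> for the ReLU output.
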
13 changes: 10 additions & 3 deletions hls4ml/converters/pytorch/pooling.py
@@ -1,7 +1,14 @@
from hls4ml.converters.pytorch_to_hls import pytorch_handler
from hls4ml.converters.utils import compute_padding_1d_pytorch, compute_padding_2d_pytorch, parse_data_format

pooling_layers = ['MaxPool1d', 'MaxPool2d', 'AvgPool1d', 'AvgPool2d']
pooling_layers = [
'MaxPool1d',
'MaxPool2d',
'AvgPool1d',
'AvgPool2d',
'QuantMaxPool1d',
'QuantMaxPool2d',
] # TODO add support for special quantized average pool layers


@pytorch_handler(*pooling_layers)
@@ -10,9 +17,9 @@ def parse_pooling_layer(operation, layer_name, input_names, input_shapes, node,

layer = {}

if operation == 'MaxPool1d':
if 'MaxPool1d' in operation:
layer['class_name'] = 'MaxPooling1D'
if operation == 'MaxPool2d':
if 'MaxPool2d' in operation:
layer['class_name'] = 'MaxPooling2D'
if operation == 'AvgPool1d':
layer['class_name'] = 'AveragePooling1D'
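The TODO above exists because max and average pooling interact differently with quantized inputs; a short illustration in plain torch (not from this PR): a max of already-quantized values is always one of those values, while an average can land between grid points and would need re-quantization.

import torch
import torch.nn.functional as F

x = torch.tensor([[[0.25, 0.50, -0.75, 0.50]]])  # values on a 2**-2 grid

print(F.max_pool1d(x, kernel_size=2))  # [[[0.50, 0.50]]]: still on the grid
print(F.avg_pool1d(x, kernel_size=2))  # [[[0.375, -0.125]]]: both averages fall off the grid

This is why QuantMaxPool1d/2d can be routed through the existing MaxPooling handlers unchanged, with no quantizer attached.
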
50 changes: 40 additions & 10 deletions hls4ml/converters/pytorch_to_hls.py
@@ -1,9 +1,24 @@
import math

import numpy as np
import torch

from hls4ml.model import ModelGraph


class CustomFXTracer(torch.fx.Tracer):

def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
"""
Custom Tracer class for hls4ml to define brevitas modules as leaf modules so they are not traced through by torch.FX
"""
return (
m.__module__.startswith("torch.nn")
or m.__module__.startswith("torch.ao.nn")
or m.__module__.startswith("brevitas.nn")
) and not isinstance(m, torch.nn.Sequential)


class PyTorchModelReader:
"""
PyTorch reader to extract weights data.
@@ -59,6 +74,23 @@ def get_weights_data(data_reader, layer_name, var_name):
return (*data,)


def convert_uaq_to_apfixed(bitwidth, scale_factor):
"""
parameters:
bitwidth: int
scale_factor: float
zero_point: float

return:
int_bitwidth: int
fract_bitwidth: int
"""
fract_bitwidth = -math.log2(scale_factor)
int_bitwidth = bitwidth - fract_bitwidth

return (fract_bitwidth, int_bitwidth)


# ----------------------Layer handling--------------------- #
layer_handlers = {}

@@ -135,11 +167,11 @@ def parse_pytorch_model(config, verbose=True):
# dict of layer objects in non-traced form for access later on
children = {c[0]: c[1] for c in model.named_children()}
# use symbolic_trace to get a full graph of the model
from torch.fx import symbolic_trace

traced_model = symbolic_trace(model)
tracer = CustomFXTracer()
traced_model = tracer.trace(model)
# Define layers to skip for conversion to HLS
skip_layers = ['Dropout', 'Sequential']
skip_layers = ['Dropout', 'QuantDropout', 'Sequential']

# All supported layers
supported_layers = get_supported_pytorch_layers() + skip_layers
@@ -163,21 +195,19 @@ def parse_pytorch_model(config, verbose=True):
# check for constant nodes
merge_layers = ['add', 'mul', 'sub', 'fmin', 'fmax']
i = 0 # count number of consts and use it in the name
for node in traced_model.graph.nodes:
for node in traced_model.nodes:
if node.name.split('_')[0] in merge_layers:
for arg in node.args:
if np.isscalar(arg):
# add an input node with the constant value
new_node = traced_model.graph.placeholder(
name='const_' + str(i), type_expr=torch.Tensor, default_value=arg
)
new_node = traced_model.placeholder(name='const_' + str(i), type_expr=torch.Tensor, default_value=arg)
node.prepend(new_node)
node.update_arg(1, new_node)
i += 1

traced_model.graph.lint()
traced_model.lint()

for node in traced_model.graph.nodes:
for node in traced_model.nodes:
if node.op == 'call_module':
# modules that are part of a torch.nn.Sequential with name 'name' have target names 'name.x',
# where x is an integer numbering the elements of the Sequential
@@ -226,7 +256,7 @@ def parse_pytorch_model(config, verbose=True):
input_shapes = [output_shapes[str(node.args[0])]]
# if a 'getitem' is the input to a node, step back in the graph to find the real source of the input
elif "getitem" in node.args[0].name:
for tmp_node in traced_model.graph.nodes:
for tmp_node in traced_model.nodes:
if tmp_node.name == node.args[0].name:
if "getitem" in tmp_node.args[0].name:
raise Exception('Nested getitem calls not resolved at the moment.')
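A worked check of convert_uaq_to_apfixed (the function body is copied from the hunk above so the snippet is self-contained): a uniform quantizer with bit_width=8 and a power-of-two scale 2**-4 spends 4 bits on the fraction, leaving 8 - 4 = 4 integer bits, i.e. ap_fixed<8,4>.

import math

def convert_uaq_to_apfixed(bitwidth, scale_factor):  # copied from the diff above
    fract_bitwidth = -math.log2(scale_factor)
    int_bitwidth = bitwidth - fract_bitwidth
    return (fract_bitwidth, int_bitwidth)

print(convert_uaq_to_apfixed(8, 2**-4))  # (4.0, 4.0) -> FixedPrecisionType(width=8, integer=4)
print(convert_uaq_to_apfixed(8, 0.1))    # (~3.32, ~4.68): non-power-of-two scales map only approximately

The CustomFXTracer is used the way the hunk shows: tracer = CustomFXTracer(); traced_model = tracer.trace(model). After that, brevitas layers survive as single call_module nodes instead of being traced through, so their class names reach the Quant-aware handlers above.
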
16 changes: 16 additions & 0 deletions hls4ml/model/quantizers.py
@@ -167,6 +167,22 @@ def __call__(self, data):
return y


class BrevitasQuantizer(Quantizer):
"""Wrapper around brevitas quantizers. Since we can get the already quantized tensors
directly from the brevitas QuantTensor objects, nothing needs to be done

Args:
bits: bitwidth of the quantized tensor
hls_type: hls_type of the quantized tensor
"""

def __init__(self, bits, hls_type):
super().__init__(bits, hls_type)

def __call__(self, data):
return data


class QuantNodeQuantizer(Quantizer):
"""
This implements a quantizer for a FixedPrecisionType with width==integer
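A usage sketch for the new class (names follow this PR; the import paths exist once this branch is checked out): the quantizer records width and precision for downstream type propagation, but __call__ passes the data through untouched, because quant_weight().value is already quantized on the brevitas side.

import numpy as np

from hls4ml.model.quantizers import BrevitasQuantizer
from hls4ml.model.types import FixedPrecisionType

q = BrevitasQuantizer(4, FixedPrecisionType(width=4, integer=1, signed=True))
data = np.array([-0.5, -0.25, 0.0, 0.25])  # already on the 2**-2 grid
assert (q(data) == data).all()  # identity pass-through
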
1 change: 1 addition & 0 deletions hls4ml/templates/vivado/build_prj.tcl
@@ -229,6 +229,7 @@ if {$opt(validation)} {
if {$opt(export)} {
puts "***** EXPORT IP *****"
set time_start [clock clicks -milliseconds]

export_design -format ip_catalog -version $version
set time_end [clock clicks -milliseconds]
report_time "EXPORT IP" $time_start $time_end
1 change: 1 addition & 0 deletions setup.cfg
@@ -55,6 +55,7 @@ sr =
sympy
testing =
HGQ~=0.2.0
brevitas
pytest
pytest-cov
pytest-randomly