Merge pull request #1142 from fastmachinelearning/qonnx_warnings
Qonnx warnings
JanFSchulte authored Dec 5, 2024
2 parents ce7f1f1 + 915d2e1 commit c8e1857
Showing 3 changed files with 33 additions and 3 deletions.
4 changes: 2 additions & 2 deletions hls4ml/model/optimizer/passes/batchnorm_opt.py
@@ -166,7 +166,7 @@ class FuseConsecutiveBatchNormalization(OptimizerPass):
"""

def match(self, node):
- prev_node = node.get_input_node(node.inputs[0])
+ prev_node = node.get_input_node()
basic_match = (
isinstance(node, BatchNormalization)
and isinstance(prev_node, BatchNormalization)
@@ -194,7 +194,7 @@ def match(self, node):
return False

def transform(self, model, node):
- prev_node = node.get_input_node(node.inputs[0])
+ prev_node = node.get_input_node()

prev_map = prev_node.get_output_use_map()
if len(prev_map[prev_node.outputs[0]]) > 1:
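For context, the call-site change above drops the explicit node.inputs[0] argument. The snippet below is a minimal sketch of the assumed behaviour that get_input_node() falls back to the node's first input when no name is given; _FakeNode is a hypothetical stand-in, not the real hls4ml node class.

# Hypothetical stand-in for an hls4ml graph node, used only to illustrate the
# assumed default-argument behaviour of get_input_node(); not the library API.
class _FakeNode:
    def __init__(self, name, inputs, graph):
        self.name = name
        self.inputs = inputs    # names of upstream nodes feeding this node
        self._graph = graph     # shared name -> node lookup table

    def get_input_node(self, input_name=None):
        # Assumption: with no argument, fall back to the first input.
        if input_name is None:
            input_name = self.inputs[0]
        return self._graph.get(input_name)

graph = {}
graph['bn0'] = _FakeNode('bn0', [], graph)
graph['bn1'] = _FakeNode('bn1', ['bn0'], graph)

# Under that assumption the old and new call sites resolve to the same predecessor.
node = graph['bn1']
assert node.get_input_node(node.inputs[0]) is node.get_input_node()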
2 changes: 1 addition & 1 deletion hls4ml/model/optimizer/passes/bn_fuse.py
@@ -18,7 +18,7 @@ class FuseBatchNormalization(OptimizerPass):
"""

def match(self, node):
- prev_node = node.get_input_node(node.inputs[0])
+ prev_node = node.get_input_node()
basic_match = (
isinstance(node, BatchNormalization)
and isinstance(prev_node, (Dense, Conv1D, Conv2D))
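The pass touched here, FuseBatchNormalization, folds a BatchNormalization into the preceding Dense/Conv1D/Conv2D layer. The snippet below is a minimal numpy sketch of that algebra, assuming the batchnorm has already been reduced to a per-channel scale and shift; the variable names are illustrative, not hls4ml internals.

# Minimal numpy sketch of batchnorm fusion into a preceding Dense layer:
# scale * (x @ W + b) + shift  ==  x @ (W * scale) + (b * scale + shift)
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 3))    # batch of inputs
W = rng.normal(size=(3, 5))    # Dense weights
b = rng.normal(size=5)         # Dense bias
scale = rng.normal(size=5)     # per-channel BN scale (gamma / sqrt(var + eps))
shift = rng.normal(size=5)     # per-channel BN shift (beta - mean * scale)

y_separate = (x @ W + b) * scale + shift         # Dense followed by BatchNormalization
y_fused = x @ (W * scale) + (b * scale + shift)  # single fused Dense
assert np.allclose(y_separate, y_fused)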
30 changes: 30 additions & 0 deletions hls4ml/model/optimizer/passes/move_scales.py
@@ -5,6 +5,8 @@
'''

+ import warnings
+
import numpy as np

from hls4ml.model.layers import ApplyAlpha, Constant, Conv, MatMul, Merge
@@ -85,6 +87,9 @@ def transform(self, model, node):
can_propagate = False

if not can_propagate:
+ warnings.warn(
+     'Failed to propagate quantization scales down MatMul node; model probably not supported.', stacklevel=1
+ )
return False

model.remove_node(apply_alpha)
@@ -124,6 +129,9 @@ def transform(self, model, node):
try:
bias = bias0 + bias1
except ValueError:
+ warnings.warn(
+     'Failed to propagate quantization scales down Add node; model probably not supported.', stacklevel=1
+ )
return False

model.remove_node(in0)
@@ -169,6 +177,7 @@ def transform(self, model, node):
model.insert_node(new_node)
return True
else:
+ warnings.warn('Failed to propagate quantization bias down Add node; model probably not supported.', stacklevel=1)
return False


@@ -243,6 +252,9 @@ def transform(self, model, node):
except ValueError:
can_propagate = False
if not can_propagate:
+ warnings.warn(
+     'Failed to propagate quantization scales down Conv node; model probably not supported.', stacklevel=1
+ )
return False

# to remove warning, since these get set again
@@ -287,6 +299,9 @@ def transform(self, model, node):
except ValueError:
can_propagate = False
if not can_propagate:
+ warnings.warn(
+     'Failed to propagate quantization scales down Conv node; model probably not supported.', stacklevel=1
+ )
return False

# to remove warning, since these get set again
@@ -308,6 +323,9 @@ def transform(self, model, node):
can_propagate = False

if not can_propagate:
+ warnings.warn(
+     'Failed to propagate quantization scales down Conv node; model probably not supported.', stacklevel=1
+ )
return False

# to remove warning, since these get set again
@@ -367,6 +385,9 @@ def transform(self, model, node):
except ValueError:
can_propagate = False
if not can_propagate:
+ warnings.warn(
+     'Failed to propagate quantization scales down Conv node; model probably not supported.', stacklevel=1
+ )
return False

# to remove warning, since these get set again
@@ -388,6 +409,9 @@ def transform(self, model, node):
except ValueError:
can_propagate = False
if not can_propagate:
+ warnings.warn(
+     'Failed to propagate quantization scales down Conv node; model probably not supported.', stacklevel=1
+ )
return False

# to remove warning, since these get set again
@@ -412,6 +436,9 @@ def transform(self, model, node):
except ValueError:
can_propagate = False
if not can_propagate:
+ warnings.warn(
+     'Failed to propagate quantization scales down Conv node; model probably not supported.', stacklevel=1
+ )
return False

# to remove warning, since these get set again
@@ -445,6 +472,9 @@ def transform(self, model, node):
except ValueError:
can_propagate = False
if not can_propagate:
+ warnings.warn(
+     'Failed to propagate quantization scales down Conv node; model probably not supported.', stacklevel=1
+ )
return False

# to remove warning, since these get set again
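Every move_scales.py hunk adds the same pattern: when an ApplyAlpha scale or bias cannot be pushed through the downstream node, emit a UserWarning and return False instead of bailing out silently. The standalone sketch below illustrates that pattern; try_propagate_scale and its propagation condition are illustrative stand-ins, not the actual pass logic.

# Standalone sketch of the warn-and-bail pattern added throughout move_scales.py.
# try_propagate_scale and its condition are illustrative, not the real pass logic.
import warnings

import numpy as np


def try_propagate_scale(scale, bias, output_shape):
    can_propagate = False
    if not bias.shape and bias == 0:
        try:
            np.broadcast_to(scale, output_shape)  # shape-compatibility check
            can_propagate = True
        except ValueError:
            can_propagate = False
    if not can_propagate:
        warnings.warn(
            'Failed to propagate quantization scales down MatMul node; model probably not supported.', stacklevel=1
        )
        return False
    return True


try_propagate_scale(np.ones((2, 3)), np.array(0.0), (4, 5))  # shapes incompatible: warns, returns False
try_propagate_scale(np.array(2.0), np.array(0.0), (4, 5))    # scalar scale: propagates, returns True

With stacklevel=1 the reported location is the warn() call itself; a user who wants these messages to always surface can run Python with -W always or call warnings.simplefilter('always').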
