Skip to content

Commit

Permalink
Merge branch 'fix/unnamed_parameters' into 'main'
Browse files Browse the repository at this point in the history
fix unnamed parameters

See merge request es/ai/hannah/hannah!344
  • Loading branch information
moreib committed Nov 23, 2023
2 parents f05474a + 476af49 commit 2b314b2
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 12 deletions.
11 changes: 9 additions & 2 deletions hannah/nas/parameters/parametrize.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,11 @@ def get_parameters(
while queue:
current = queue.pop(-1)
visited.append(current.id)
params[current.id] = current
if current.id is None:
name = current.name
else:
name = current.id
params[name] = current

if hasattr(current, "_PARAMETERS"):
for param in current._PARAMETERS.values():
Expand Down Expand Up @@ -246,7 +250,10 @@ def hierarchical_parameter_dict(parameter, include_empty=False, flatten=False):

if k == key_list[-1] and isinstance(param, Expression):
if flatten:
hierarchical_params[param.id] = param
if param.id is None:
hierarchical_params[param.name] = param
else:
hierarchical_params[param.id] = param
else:
current_param_branch[index] = param
else:
Expand Down
43 changes: 33 additions & 10 deletions hannah/nas/test/test_functional_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,28 @@ def test_functional_ops():
net = Relu()(net)


def test_functional_ops_chained():
kernel_size = CategoricalParameter([1, 3, 5], name='kernel_size')
out_channels = IntScalarParameter(min=4, max=64, name='out_channels')
stride = CategoricalParameter([1, 2], name='stride')

input = Tensor(name='input', shape=(1, 3, 32, 32), axis=('N', 'C', 'H', 'W'))
weight0 = Tensor(name='weight',
shape=(out_channels, 3, kernel_size, kernel_size),
axis=('O', 'I', 'kH', 'kW'))
net = Conv2d(stride=stride)(input, weight0)
net = Relu()(net)

weight1 = Tensor(name='weight',
shape=(out_channels, 3, kernel_size, kernel_size),
axis=('O', 'I', 'kH', 'kW'))

net = Conv2d(stride=stride.new())(net, weight1)
net = Relu()(net)

# print(net.parametrization(flatten=True))


def test_shape_propagation():
kernel_size = CategoricalParameter([1, 3, 5], name='kernel_size')
out_channels = IntScalarParameter(min=4, max=64, name='out_channels')
Expand Down Expand Up @@ -217,14 +239,15 @@ def dynamic_depth_block(input, depth):


if __name__ == '__main__':
test_functional_ops()
test_shape_propagation()
test_blocks()
test_operators()
test_multibranches()
test_scoping()
test_parametrization()
test_choice()
test_optional_op()
test_dynamic_depth()
# test_functional_ops()
test_functional_ops_chained()
# test_shape_propagation()
# test_blocks()
# test_operators()
# test_multibranches()
# test_scoping()
# test_parametrization()
# test_choice()
# test_optional_op()
# test_dynamic_depth()
# test_visualization()

0 comments on commit 2b314b2

Please sign in to comment.