diff --git a/hannah/nas/parameters/parametrize.py b/hannah/nas/parameters/parametrize.py index 4a637cd6..ad99871e 100644 --- a/hannah/nas/parameters/parametrize.py +++ b/hannah/nas/parameters/parametrize.py @@ -193,7 +193,11 @@ def get_parameters( while queue: current = queue.pop(-1) visited.append(current.id) - params[current.id] = current + if current.id is None: + name = current.name + else: + name = current.id + params[name] = current if hasattr(current, "_PARAMETERS"): for param in current._PARAMETERS.values(): @@ -246,7 +250,10 @@ def hierarchical_parameter_dict(parameter, include_empty=False, flatten=False): if k == key_list[-1] and isinstance(param, Expression): if flatten: - hierarchical_params[param.id] = param + if param.id is None: + hierarchical_params[param.name] = param + else: + hierarchical_params[param.id] = param else: current_param_branch[index] = param else: diff --git a/hannah/nas/test/test_functional_ops.py b/hannah/nas/test/test_functional_ops.py index b1384ef7..d94f5b60 100644 --- a/hannah/nas/test/test_functional_ops.py +++ b/hannah/nas/test/test_functional_ops.py @@ -78,6 +78,28 @@ def test_functional_ops(): net = Relu()(net) +def test_functional_ops_chained(): + kernel_size = CategoricalParameter([1, 3, 5], name='kernel_size') + out_channels = IntScalarParameter(min=4, max=64, name='out_channels') + stride = CategoricalParameter([1, 2], name='stride') + + input = Tensor(name='input', shape=(1, 3, 32, 32), axis=('N', 'C', 'H', 'W')) + weight0 = Tensor(name='weight', + shape=(out_channels, 3, kernel_size, kernel_size), + axis=('O', 'I', 'kH', 'kW')) + net = Conv2d(stride=stride)(input, weight0) + net = Relu()(net) + + weight1 = Tensor(name='weight', + shape=(out_channels, 3, kernel_size, kernel_size), + axis=('O', 'I', 'kH', 'kW')) + + net = Conv2d(stride=stride.new())(net, weight1) + net = Relu()(net) + + # print(net.parametrization(flatten=True)) + + def test_shape_propagation(): kernel_size = CategoricalParameter([1, 3, 5], name='kernel_size') out_channels = IntScalarParameter(min=4, max=64, name='out_channels') @@ -217,14 +239,15 @@ def dynamic_depth_block(input, depth): if __name__ == '__main__': - test_functional_ops() - test_shape_propagation() - test_blocks() - test_operators() - test_multibranches() - test_scoping() - test_parametrization() - test_choice() - test_optional_op() - test_dynamic_depth() + # test_functional_ops() + test_functional_ops_chained() + # test_shape_propagation() + # test_blocks() + # test_operators() + # test_multibranches() + # test_scoping() + # test_parametrization() + # test_choice() + # test_optional_op() + # test_dynamic_depth() # test_visualization()