diff --git a/environment.yml b/environment.yml deleted file mode 100644 index e22f2bf8..00000000 --- a/environment.yml +++ /dev/null @@ -1,33 +0,0 @@ -## -## Copyright (c) 2022 University of Tübingen. -## -## This file is part of hannah. -## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -## -## Licensed under the Apache License, Version 2.0 (the "License"); -## you may not use this file except in compliance with the License. -## You may obtain a copy of the License at -## -## http://www.apache.org/licenses/LICENSE-2.0 -## -## Unless required by applicable law or agreed to in writing, software -## distributed under the License is distributed on an "AS IS" BASIS, -## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -## See the License for the specific language governing permissions and -## limitations under the License. -## -name: hannah - -channels: - - pytorch - - conda-forge - - defaults - -dependencies: - - python=3.9 - - cudatoolkit=10.2 - - libsndfile - - poetry=1.* - - nccl - - setuptools=59.5.0 - - gxx_linux-64=7 diff --git a/experiments/kws/README.md b/experiments/kws/README.md new file mode 100644 index 00000000..dd1db5a0 --- /dev/null +++ b/experiments/kws/README.md @@ -0,0 +1,32 @@ +# Example of using new FlexKI hannah search spaces for new models + +The new search spaces for hannah are supposed to allow an easy expression of very large variability of neural network searches. + +This is a very simple example of using a it for searching a 1D-Convolutional Network + +There are 3 options to run the training in this folder. + +### Normal Neural Network Training + +```bash +hannah-train +``` + +this will train a single neural network + + +### NAS on a flexible search space definition + +```bash +hannah-train +experiment=ae_nas +``` + +This will use the flexible search space definition. Running for a direct nas on an aging evolution based optmizer. + +### Legacy NAS using fixed search spaces + +```bash +hannah-train +experiment=legacy_nas +``` + +This uses the legacy/orginal Hannah search spaces as defined in the Paper. \ No newline at end of file diff --git a/experiments/kws/config.yaml b/experiments/kws/config.yaml index 93f84d89..c67a2e57 100644 --- a/experiments/kws/config.yaml +++ b/experiments/kws/config.yaml @@ -19,11 +19,15 @@ defaults: - - base_config - - _self_ + - base_config # Base configuration uses a single neural network training and kws dataset + - _self_ # This is a special value that specifies that values defined in this file take precedence over values from the other files -module: - num_workers: 8 +module: # The module encapsulate the target task for neural network training in this case we use the default task which is classification on 1D signals + num_workers: 32 # Number of workers gives the number of parallel processes used to load data + batch_size: 1024 -trainer: - max_epochs: 30 +trainer: # Trainer arguments set hyperparameters for all trainings + max_epochs: 30 + +dataset: + data_folder: ${oc.env:HANNAH_DATA_FOLDER,${hydra:runtime.cwd}/../../datasets/} # Set the location for dataset files in this case we wille use the value of the environment variable HANNAH_DATA_FOLDER or the folder ../../datasets/ relative to the location of the directory where hannah-train is run, usually the folder containing this file \ No newline at end of file diff --git a/experiments/kws/experiment/ae_nas.yaml b/experiments/kws/experiment/ae_nas.yaml new file mode 100644 index 00000000..b988bfa2 --- /dev/null +++ b/experiments/kws/experiment/ae_nas.yaml @@ -0,0 +1,16 @@ +# @package _global_ +# The preciding line specifies that the following configuration changes global configuration settings instead of setting in the experiment namespace + +defaults: + - override /nas: aging_evolution_nas + - override /model: 1d_space + +experiment_id: ae_nas # The experiment id is used to identify the experiment it especially defines the subfolder under /trained_models where the results will be saved + +nas: + predictor: null + bounds: + val_error: 0.08 + total_macs: 250000 + budget: 250 + input_shape: [40,101] \ No newline at end of file diff --git a/experiments/kws/experiment/legacy_nas.yaml b/experiments/kws/experiment/legacy_nas.yaml new file mode 100644 index 00000000..12d11f82 --- /dev/null +++ b/experiments/kws/experiment/legacy_nas.yaml @@ -0,0 +1,37 @@ +# @package _global_ +# The preciding line specifies that the following configuration changes global configuration settings instead of setting in the experiment namespace + +defaults: + - override /nas: aging_evolution_nas_legacy + +experiment_id: legacy_nas # The experiment id is used to identify the experiment it especially defines the subfolder under /trained_models where the results will be saved + + +nas: + parametrization: + model: + qconfig: + config: + bw_f: [4,8] + bw_w: [2,4,8] + conv: + min: 1 + max: 2 + + choices: + - target: forward + stride: [1,2] + blocks: + min: 1 + max: 4 + choices: + - target: conv1d + kernel_size: [3,5,7] + act: true + norm: true + out_channels: [8,16,32,64] + + bounds: + val_error: 0.08 + total_macs: 250000 + budget: 250 \ No newline at end of file diff --git a/experiments/kws/model/1d_space.yaml b/experiments/kws/model/1d_space.yaml new file mode 100644 index 00000000..f1b3c9fa --- /dev/null +++ b/experiments/kws/model/1d_space.yaml @@ -0,0 +1,5 @@ +_target_: hannah.models.simple1d.space +name: simple1d_searchspace +num_classes: 12 +max_channels: 256 +max_depth: 4 \ No newline at end of file diff --git a/experiments/progressive_shrinking/config.yaml b/experiments/progressive_shrinking/config.yaml deleted file mode 100644 index 5b5164c0..00000000 --- a/experiments/progressive_shrinking/config.yaml +++ /dev/null @@ -1,22 +0,0 @@ -defaults: - - base_config - - override dataset: kws - - override features: mfcc - - override model: ofa - - override scheduler: null - - override optimizer: adamw - - override normalizer: fixedpoint - - override module: stream_classifier - - override trainer: default - - override nas: ofa_nas - - _self_ - -experiment_id: progressive_shrinking -module: - shuffle_all_dataloaders: True - -optimizer: - lr: 0.0001 - -dataset: - data_folder: ${hydra:runtime.cwd}/../../datasets/ diff --git a/experiments/progressive_shrinking/experiment/devel.yaml b/experiments/progressive_shrinking/experiment/devel.yaml deleted file mode 100644 index 0dfd03a3..00000000 --- a/experiments/progressive_shrinking/experiment/devel.yaml +++ /dev/null @@ -1,22 +0,0 @@ -# @package _global_ -nas: - _target_: hannah.nas.OFANasTrainer - epochs_warmup: 1 - epochs_kernel_step: 1 - epochs_depth_step: 1 - epochs_width_step: 1 - epochs_dilation_step: 1 - epochs_tuning_step: 1 - elastic_kernels_allowed: true - elastic_depth_allowed: true - elastic_width_allowed: true - elastic_dilation_allowed: false - evaluate: true - random_evaluate: true - random_eval_number: 2 - extract_model_config: false - warmup_model_path: '' - -experiment_id: devel -trainer: - overfit_batches: 1 diff --git a/experiments/progressive_shrinking/experiment/finetune_float.yaml b/experiments/progressive_shrinking/experiment/finetune_float.yaml deleted file mode 100644 index 0ed2ce49..00000000 --- a/experiments/progressive_shrinking/experiment/finetune_float.yaml +++ /dev/null @@ -1,20 +0,0 @@ -# @package _global_ -nas: - _target_: hannah.nas.OFANasTrainer - epochs_warmup: 35 - epochs_kernel_step: 35 - epochs_depth_step: 35 - epochs_width_step: 35 - epochs_dilation_step: 35 - epochs_tuning_step: 5 - elastic_kernels_allowed: true - elastic_depth_allowed: true - elastic_width_allowed: true - elastic_dilation_allowed: false - evaluate: true - random_evaluate: true - random_eval_number: 1000 - extract_model_config: false - warmup_model_path: '' - -experiment_id: finetune_float diff --git a/experiments/progressive_shrinking/experiment/overfit.yaml b/experiments/progressive_shrinking/experiment/overfit.yaml deleted file mode 100644 index bee9a713..00000000 --- a/experiments/progressive_shrinking/experiment/overfit.yaml +++ /dev/null @@ -1,23 +0,0 @@ -# @package _global_ - -# - -nas: - _target_: hannah.nas.OFANasTrainer - epochs_warmup: 5 - epochs_kernel_step: 5 - epochs_depth_step: 0 - epochs_width_step: 0 - epochs_dilation_step: 0 - epochs_tuning_step: 0 - elastic_kernels_allowed: false - elastic_depth_allowed: false - elastic_width_allowed: false - elastic_dilation_allowed: false - evaluate: true - random_evaluate: true - random_eval_number: 2 - extract_model_config: false - warmup_model_path: '' - -experiment_id: overfit diff --git a/experiments/progressive_shrinking/experiment/shrink_float.yaml b/experiments/progressive_shrinking/experiment/shrink_float.yaml deleted file mode 100644 index 7e1bfae9..00000000 --- a/experiments/progressive_shrinking/experiment/shrink_float.yaml +++ /dev/null @@ -1,19 +0,0 @@ -# @package _global_ -nas: - _target_: hannah.nas.OFANasTrainer - epochs_warmup: 35 - epochs_kernel_step: 35 - epochs_depth_step: 35 - epochs_width_step: 35 - epochs_dilation_step: 35 - elastic_kernels_allowed: true - elastic_depth_allowed: true - elastic_width_allowed: true - elastic_dilation_allowed: false - evaluate: true - random_evaluate: true - random_eval_number: 1000 - extract_model_config: false - warmup_model_path: '' - -experiment_id: shrink_float diff --git a/hannah/callbacks/summaries.py b/hannah/callbacks/summaries.py index db493e7c..80c1ded8 100644 --- a/hannah/callbacks/summaries.py +++ b/hannah/callbacks/summaries.py @@ -29,7 +29,7 @@ from tabulate import tabulate from torch.fx.graph_module import GraphModule -from hannah.nas.functional_operators.operators import add, conv2d, linear +from hannah.nas.functional_operators.operators import add, conv2d, linear, conv1d from hannah.nas.graph_conversion import GraphConversionTracer from ..models.factory import qat @@ -337,7 +337,7 @@ def _do_summary(self, pl_module, input=None, print_log=True): msglogger.info( "Estimated Activations: " + "{:,}".format(estimated_acts) ) - except RuntimeError as e: + except (RuntimeError, KeyError) as e: msglogger.warning("Could not create performance summary: %s", str(e)) return OrderedDict() @@ -494,6 +494,7 @@ def __init__(self, module: torch.nn.Module): super().__init__(gm) self.count_function = { + conv1d: get_conv, conv2d: get_conv, linear: get_linear, add: get_zero_op, @@ -514,6 +515,7 @@ def __init__(self, module: torch.nn.Module): def run_node(self, n: torch.fx.Node): try: out = super().run_node(n) + print(out.shape, n) except Exception as e: print(str(e)) if n.op == "call_function": @@ -533,6 +535,7 @@ def run_node(self, n: torch.fx.Node): self.data["MACs"] += [int(macs)] except Exception as e: msglogger.warning("Summary of node %s failed: %s", n.name, str(e)) + print(traceback.format_exc()) return out @@ -540,6 +543,7 @@ class FxMACSummaryCallback(MacSummaryCallback): def _do_summary(self, pl_module, input=None, print_log=True): interpreter = MACSummaryInterpreter(pl_module.model) dummy_input = input + if dummy_input is None: dummy_input = pl_module.example_feature_array dummy_input = dummy_input.to(pl_module.device) diff --git a/hannah/models/embedded_vision_net/blocks.py b/hannah/models/embedded_vision_net/blocks.py index eab2e61c..8cda2605 100644 --- a/hannah/models/embedded_vision_net/blocks.py +++ b/hannah/models/embedded_vision_net/blocks.py @@ -28,10 +28,10 @@ def grouped_pointwise(input, out_channels): @scope def expansion(input, expanded_channels): - pw = partial(pointwise_conv2d, out_channels=expanded_channels) - grouped_pw = partial(grouped_pointwise, out_channels=expanded_channels) - return choice(input, pw, grouped_pw) - # return pointwise_conv2d(input, out_channels=expanded_channels) + #pw = partial(pointwise_conv2d, out_channels=expanded_channels) + #grouped_pw = partial(grouped_pointwise, out_channels=expanded_channels) + #return choice(input, pw, grouped_pw) + return pointwise_conv2d(input, out_channels=expanded_channels) @scope @@ -41,10 +41,10 @@ def spatial_correlation(input, out_channels, kernel_size, stride=1): @scope def reduction(input, out_channels): - pw = partial(pointwise_conv2d, out_channels=out_channels) - grouped_pw = partial(grouped_pointwise, out_channels=out_channels) - return choice(input, pw, grouped_pw) - # return pointwise_conv2d(input, out_channels=out_channels) + #pw = partial(pointwise_conv2d, out_channels=out_channels) + #grouped_pw = partial(grouped_pointwise, out_channels=out_channels) + #return choice(input, pw, grouped_pw) + return pointwise_conv2d(input, out_channels=out_channels) @scope @@ -86,9 +86,9 @@ def pattern(input, stride, out_channels, kernel_size, expand_ratio, reduce_ratio convolution = partial(conv_relu, stride=stride, kernel_size=kernel_size, out_channels=out_channels) red_exp = partial(reduce_expand, out_channels=out_channels, reduce_ratio=reduce_ratio, kernel_size=kernel_size, stride=stride) exp_red = partial(expand_reduce, out_channels=out_channels, expand_ratio=expand_ratio, kernel_size=kernel_size, stride=stride) - pool = partial(pooling, kernel_size=kernel_size, stride=stride) + #pool = partial(pooling, kernel_size=kernel_size, stride=stride) - out = choice(input, convolution, exp_red, red_exp, pool) + out = choice(input, convolution, exp_red, red_exp) return out diff --git a/hannah/models/simple1d.py b/hannah/models/simple1d.py new file mode 100644 index 00000000..f6e55a4b --- /dev/null +++ b/hannah/models/simple1d.py @@ -0,0 +1,73 @@ +from typing import Any + +from hannah.nas.functional_operators.operators import Relu, Conv1d, Linear, AdaptiveAvgPooling +from hannah.nas.parameters import CategoricalParameter, IntScalarParameter, parametrize +from hannah.nas.functional_operators.op import Tensor, Op, scope, ChoiceOp +from hannah.nas.functional_operators.shapes import conv_shape, padding_expression +from hannah.nas.functional_operators.lazy import lazy + + +import torch + + + +def conv1d(input, out_channels, kernel_size, stride): + in_channels = input.shape()[1] + weight = Tensor(name='weight', + shape=(out_channels, in_channels, kernel_size), + axis=('O', 'I', 'k'), + grad=True) + + conv = Conv1d(stride=stride)(input, weight) + return conv + +def relu(input): + return Relu()(input) + +def adaptive_avg_pooling(input): + return AdaptiveAvgPooling(output_size=1)(input) + +def linear(input, num_classes): + in_features = input.shape()[1] + weight = Tensor(name='weight', + shape=(in_features, num_classes), + axis=('I', 'O'), + grad=True) + return Linear()(input, weight) + +@scope +def conv_relu(input, out_channels, kernel_size, stride): + out = conv1d(input, out_channels=out_channels, stride=stride, kernel_size=kernel_size) + out = relu(out) + return out + +@scope +def classifier_head(input, num_classes): + out = adaptive_avg_pooling(input) + out = linear(out, num_classes) + return out + + +def dynamic_depth(*exits, switch): + return ChoiceOp(*exits, switch=switch)() + +def space(name: str, input, num_classes: int, max_channels=512, max_depth=9): + num_blocks = IntScalarParameter(0, max_depth, name='num_blocks') + exits = [] + + out = input + + for i in range(num_blocks.max+1): + kernel_size = CategoricalParameter([3, 5, 7, 9], name='kernel_size') + stride = CategoricalParameter([1, 2], name='stride') + out_channels = IntScalarParameter(16, max_channels, step_size=8, name='out_channels') + + out = conv_relu(out, out_channels=out_channels, kernel_size=kernel_size, stride=stride) + exits.append(out) + + out = dynamic_depth(*exits, switch=num_blocks) + + out = classifier_head(out, num_classes=num_classes) + + + return out diff --git a/hannah/nas/functional_operators/operators.py b/hannah/nas/functional_operators/operators.py index c009ae22..29d1bbc4 100644 --- a/hannah/nas/functional_operators/operators.py +++ b/hannah/nas/functional_operators/operators.py @@ -9,11 +9,20 @@ from hannah.nas.core.parametrized import is_parametrized from hannah.nas.functional_operators.lazy import lazy from hannah.nas.functional_operators.op import Choice, Op, Tensor -from hannah.nas.functional_operators.shapes import adaptive_average_pooling2d_shape, conv_shape, identity_shape, linear_shape, padding_expression, pool_shape +from hannah.nas.functional_operators.shapes import adaptive_average_pooling_shape, conv_shape, identity_shape, linear_shape, padding_expression, pool_shape from hannah.nas.parameters.parametrize import parametrize from hannah.nas.parameters.parameters import IntScalarParameter, CategoricalParameter +@torch.fx.wrap +def conv1d(input, weight, stride, padding, dilation, groups, *, id): + return F.conv1d(input=input, + weight=weight, + stride=lazy(stride), + padding=lazy(padding), + dilation=lazy(dilation), + groups=lazy(groups)) + @torch.fx.wrap def conv2d(input, weight, stride, padding, dilation, groups, *, id): @@ -53,10 +62,15 @@ def add(input, other, *, id): @torch.fx.wrap -def adaptive_avg_pooling(input, output_size=(1, 1), *, id): +def adaptive_avg_pooling2d(input, output_size=(1, 1), *, id): return F.adaptive_avg_pool2d(input, output_size=output_size) +@torch.fx.wrap +def adaptive_avg_pooling1d(input, output_size=(1, 1), *, id): + return F.adaptive_avg_pool1d(input, output_size=output_size) + + @torch.fx.wrap def max_pool(input, kernel_size, stride, padding, dilation): return F.max_pool2d(input, kernel_size, stride, padding, dilation) @@ -75,17 +89,14 @@ def interleave(input, step_size): @parametrize class Conv1d(Op): - def __init__(self, out_channels, kernel_size=1, stride=1, dilation=1) -> None: + def __init__(self, kernel_size=1, stride=1, dilation=1, groups=1) -> None: super().__init__(name='Conv1d') - self.out_channels = out_channels self.kernel_size = kernel_size self.stride = stride self.dilation = dilation + self.groups = groups self.padding = padding_expression(self.kernel_size, self.stride, self.dilation) - # def _verify_operands(self, operands): - # assert len(operands) == 2 - def __call__(self, *operands) -> Any: new_conv = super().__call__(*operands) input_shape = operands[0].shape() @@ -98,14 +109,13 @@ def __call__(self, *operands) -> Any: new_conv.padding = padding_expression(new_conv.kernel_size, new_conv.stride, new_conv.dilation) return new_conv - # FIXME: Use wrapped implementation def _forward_implementation(self, *operands): x = operands[0] weight = operands[1] - return F.conv1d(input=x, weight=weight, stride=lazy(self.stride), padding=lazy(self.padding), dilation=lazy(self.dilation)) + return conv1d(x, weight, stride=lazy(self.stride), padding=lazy(self.padding), dilation=lazy(self.dilation), groups=lazy(self.groups), id=self.id) def shape_fun(self): - return conv_shape(*self.operands, dims=2, stride=self.stride, padding=self.padding, dilation=self.dilation) + return conv_shape(*self.operands, dims=1, stride=self.stride, padding=self.padding, dilation=self.dilation) @parametrize @@ -117,9 +127,6 @@ def __init__(self, stride=1, dilation=1, groups=1, padding=None) -> None: self.groups = groups self.padding = padding - # def _verify_operands(self, operands): - # assert len(operands) == 2 - def __call__(self, *operands) -> Any: new_conv = super().__call__(*operands) input_shape = operands[0].shape() @@ -131,15 +138,7 @@ def __call__(self, *operands) -> Any: new_conv.kernel_size = weight_shape[2] if self.padding is None: new_conv.padding = padding_expression(new_conv.kernel_size, new_conv.stride, new_conv.dilation) - - # new_conv.weight = Tensor(name=self.id + '.weight', - # shape=(self.out_channels, new_conv.in_channels, self.kernel_size, self.kernel_size), - # axis=('O', 'I', 'kH', 'kW')) - - # new_conv.operands.append(new_conv.weight) - # if is_parametrized(new_conv.weight): - # new_conv._PARAMETERS[new_conv.weight.name] = new_conv.weight - # new_conv._verify_operands(new_conv.operands) + return new_conv def _forward_implementation(self, x, weight): @@ -293,12 +292,19 @@ class AdaptiveAvgPooling(Op): def __init__(self, output_size=(1, 1)) -> None: super().__init__(name='AvgPooling') self.output_size = output_size + if isinstance(output_size, int): + self.dim = 1 + else: + self.dim = len(output_size) def shape_fun(self): - return adaptive_average_pooling2d_shape(*self.operands, output_size=self.output_size) + return adaptive_average_pooling_shape(*self.operands, output_size=self.output_size) def _forward_implementation(self, *operands): - return adaptive_avg_pooling(operands[0], output_size=self.output_size, id=self.id) + if self.dim == 1: + return adaptive_avg_pooling1d(operands[0], output_size=self.output_size, id=self.id) + else: + return adaptive_avg_pooling2d(operands[0], output_size=self.output_size, id=self.id) @parametrize diff --git a/hannah/nas/functional_operators/shapes.py b/hannah/nas/functional_operators/shapes.py index a098ecd2..07e1abc8 100644 --- a/hannah/nas/functional_operators/shapes.py +++ b/hannah/nas/functional_operators/shapes.py @@ -77,11 +77,14 @@ def linear_shape(*operands): return (batch, out_features) -def adaptive_average_pooling2d_shape(*operands, output_size): +def adaptive_average_pooling_shape(*operands, output_size): dims = operands[0].shape() # NOTE: dims might be SymbolicSequence. Symbolic sequence has in its symbolic state no fixed length, making it # necessary to know and define which dimensions hold values. new_dims = [dims[0], dims[1]] - new_dims += [output_size[0], output_size[1]] + if isinstance(output_size, int): + output_size = [output_size] + + new_dims.extend(output_size) return tuple(new_dims) diff --git a/hannah/nas/graph_conversion.py b/hannah/nas/graph_conversion.py index 294166b3..a0bcbd96 100644 --- a/hannah/nas/graph_conversion.py +++ b/hannah/nas/graph_conversion.py @@ -112,6 +112,7 @@ def __init__(self, module, garbage_collect_values=True): torch.nn.MaxPool2d: self.add_nodes_pooling, torch.nn.AvgPool2d: self.add_nodes_pooling, "add": self.add_nodes_add, + "conv1d": self.add_nodes_conv_fun, "conv2d": self.add_nodes_conv_fun, "linear": self.add_nodes_linear_fun, "relu": self.add_nodes_relu, @@ -564,7 +565,8 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any: target_name, None, args, kwargs, output_tensor ) else: - output = output_tensor + output = NamedTensor(target_name, output_tensor) + return output diff --git a/hannah/nas/search/model_trainer/simple_model_trainer.py b/hannah/nas/search/model_trainer/simple_model_trainer.py index 258b1123..de54f13a 100644 --- a/hannah/nas/search/model_trainer/simple_model_trainer.py +++ b/hannah/nas/search/model_trainer/simple_model_trainer.py @@ -92,7 +92,7 @@ def run_training(self, model, num, global_num, config): except Exception as e: msglogger.critical("Training failed with exception") msglogger.critical(str(e)) - # print(traceback.format_exc()) + print(traceback.format_exc()) # sys.exit(1) res = {} diff --git a/hannah/nas/search/search.py b/hannah/nas/search/search.py index 3b849faa..353c4e4e 100644 --- a/hannah/nas/search/search.py +++ b/hannah/nas/search/search.py @@ -55,6 +55,9 @@ def __init__( constraint_model=None, parent_config=None, random_state=None, + input_shape = None, + *args, + **kwargs, ) -> None: self.budget = budget self.n_jobs = n_jobs @@ -68,6 +71,10 @@ def __init__( self.random_state = np.random.RandomState() else: self.random_state = random_state + + self.example_input_array = None + if input_shape is not None: + self.example_input_array = torch.rand([1] + list(input_shape)) def run(self): self.before_search() @@ -128,7 +135,7 @@ def before_search(self): ) self.mac_predictor = MACPredictor(predictor="fx") self.model_trainer = instantiate(self.config.nas.model_trainer) - if "predictor" in self.config.nas: + if "predictor" in self.config.nas and self.config.nas.predictor is not None: self.predictor = instantiate(self.config.nas.predictor, _recursive_=False) if os.path.exists("performance_data"): self.predictor.load("performance_data") @@ -249,11 +256,12 @@ def build_model(self, parameters): return module def build_search_space(self): - # FIXME: In the future, get num_labels also from dataset - # search_space = instantiate(self.config.model, input_shape=self.example_input_array.shape, _recursive_=True) + + input = Tensor( "input", shape=self.example_input_array.shape, axis=("N", "C", "H", "W") ) + search_space = instantiate(self.config.model, input=input, _recursive_=True) return search_space @@ -288,7 +296,8 @@ def initialize_dataset(self): self.val_set = val_set self.unlabeled_set = unlabeled_set self.test_set = test_set - self.example_input_array = torch.rand([1] + list(train_set.size())) + if self.example_input_array is None: + self.example_input_array = torch.rand([1] + list(train_set.size())) def train_model(self, model): trainer = instantiate(self.config.trainer, callbacks=self.callbacks)