From 62cb3a615d268b6607ec7bc11f2e53f03808cd40 Mon Sep 17 00:00:00 2001 From: Moritz Reiber Date: Tue, 10 Dec 2024 09:30:03 +0100 Subject: [PATCH] Revert last three commits --- hannah/conf/nas/aging_evolution_nas.yaml | 1 - hannah/nas/constraints/random_walk.py | 63 ++++++++++---------- hannah/nas/functional_operators/utils.py | 30 ---------- hannah/nas/search/sampler/aging_evolution.py | 13 ++-- hannah/nas/search/sampler/base_sampler.py | 2 - hannah/nas/search/search.py | 3 +- hannah/nas/test/test_active_parameters.py | 19 ------ 7 files changed, 36 insertions(+), 95 deletions(-) delete mode 100644 hannah/nas/functional_operators/utils.py delete mode 100644 hannah/nas/test/test_active_parameters.py diff --git a/hannah/conf/nas/aging_evolution_nas.yaml b/hannah/conf/nas/aging_evolution_nas.yaml index c4bef1ab..25372389 100644 --- a/hannah/conf/nas/aging_evolution_nas.yaml +++ b/hannah/conf/nas/aging_evolution_nas.yaml @@ -31,7 +31,6 @@ n_jobs: 10 presample: False total_candidates: 50 num_selected_candidates: 20 -constrained_sampling_on_search: True bounds: val_error: 0.1 # total_macs: 128000000 diff --git a/hannah/nas/constraints/random_walk.py b/hannah/nas/constraints/random_walk.py index 34d4bdb6..47fad65f 100644 --- a/hannah/nas/constraints/random_walk.py +++ b/hannah/nas/constraints/random_walk.py @@ -29,7 +29,6 @@ from hannah.nas.parameters.parameters import Parameter from hannah.nas.parameters.parametrize import set_parametrization from hannah.nas.search.utils import np_to_primitive -from hannah.nas.functional_operators.utils import get_active_parameters logger = logging.getLogger(__name__) @@ -73,36 +72,36 @@ def hierarchical_parameter_dict(parameter, include_empty=False, flatten=False): } -# def get_active_parameter(net): -# active_param_ids = [] -# queue = [net] -# visited = [net.id] - -# def extract_parameters(node): -# ids = [] -# for k, p in node._PARAMETERS.items(): -# if isinstance(p, Parameter): -# ids.append(p.id) -# return ids - -# while queue: -# current = queue.pop() -# if isinstance(current, ChoiceOp): -# # handle choices -# active_param_ids.append(current.switch.id) -# chosen_path = current.options[lazy(current.switch)] -# if chosen_path.id not in visited: -# queue.append(chosen_path) -# visited.append(chosen_path.id) -# else: -# # handle all other operators & tensors -# active_param_ids.extend(extract_parameters(current)) -# for operand in current.operands: -# if operand.id not in visited: -# queue.append(operand) -# visited.append(operand.id) - -# return active_param_ids +def get_active_parameter(net): + active_param_ids = [] + queue = [net] + visited = [net.id] + + def extract_parameters(node): + ids = [] + for k, p in node._PARAMETERS.items(): + if isinstance(p, Parameter): + ids.append(p.id) + return ids + + while queue: + current = queue.pop() + if isinstance(current, ChoiceOp): + # handle choices + active_param_ids.append(current.switch.id) + chosen_path = current.options[lazy(current.switch)] + if chosen_path.id not in visited: + queue.append(chosen_path) + visited.append(chosen_path.id) + else: + # handle all other operators & tensors + active_param_ids.extend(extract_parameters(current)) + for operand in current.operands: + if operand.id not in visited: + queue.append(operand) + visited.append(operand.id) + + return active_param_ids class RandomWalkConstraintSolver: @@ -194,7 +193,7 @@ def solve(self, module, parameters, fix_vars=[]): ct = 0 while ct < self.max_iterations: # active_params = get_active_parameter(params) - active_params = list(get_active_parameters(mod).keys()) + active_params = get_active_parameter(mod) param_keys = [p for p in all_param_keys if p in active_params] current = con.lhs.evaluate() diff --git a/hannah/nas/functional_operators/utils.py b/hannah/nas/functional_operators/utils.py deleted file mode 100644 index da422789..00000000 --- a/hannah/nas/functional_operators/utils.py +++ /dev/null @@ -1,30 +0,0 @@ -from hannah.nas.functional_operators.op import ChoiceOp, Tensor -from hannah.models.embedded_vision_net.models import embedded_vision_net -from hannah.nas.parameters.parameters import Parameter -from hannah.nas.core.expression import Expression - - -def get_active_parameters(space, parametrization=None): - if parametrization is None: - parametrization = space.parametrization() - - queue = [space] - visited = [space.id] - active_params = {} - - while queue: - node = queue.pop(0) - for k, p in node._PARAMETERS.items(): - if isinstance(p, Parameter): - active_params[p.id] = parametrization[p.id] - for operand in node.operands: - while isinstance(operand, ChoiceOp): - for k, p in operand._PARAMETERS.items(): - if isinstance(p, Parameter): - active_params[p.id] = parametrization[p.id] - active_op_index = operand.switch.evaluate() - operand = operand.operands[active_op_index] - if operand.id not in visited: - queue.append(operand) - visited.append(operand.id) - return active_params diff --git a/hannah/nas/search/sampler/aging_evolution.py b/hannah/nas/search/sampler/aging_evolution.py index 8cc75dfc..70e97804 100644 --- a/hannah/nas/search/sampler/aging_evolution.py +++ b/hannah/nas/search/sampler/aging_evolution.py @@ -30,9 +30,9 @@ from hannah.nas.search.sampler.mutator import ParameterMutator from hannah.nas.search.utils import np_to_primitive +from ...parametrization import SearchSpace from ...utils import is_pareto from .base_sampler import Sampler, SearchResult -from hannah.nas.functional_operators.utils import get_active_parameters class FitnessFunction: @@ -59,16 +59,14 @@ class AgingEvolutionSampler(Sampler): def __init__( self, parent_config, - search_space, parametrization: dict, population_size: int = 50, random_state = None, sample_size: int = 10, - mutation_rate: float = 0.01, eps: float = 0.1, output_folder=".", ): - super().__init__(parent_config, search_space=search_space, output_folder=output_folder) + super().__init__(parent_config, output_folder=output_folder) self.bounds = self.parent_config.nas.bounds self.parametrization = parametrization @@ -81,7 +79,7 @@ def __init__( self.population_size = population_size self.sample_size = sample_size self.eps = eps - self.mutator = ParameterMutator(mutation_rate) + self.mutator = ParameterMutator(0.1) self.history = [] self.population = [] @@ -120,11 +118,8 @@ def next_parameters(self): parent = sample[np.argmin(fitness)] parent_parametrization = set_parametrization(parent.parameters, self.parametrization) - parametrization = {key: param.current_value for key, param in parent_parametrization.items()} - active_parameters = get_active_parameters(self.search_space, parent_parametrization) - mutated_parameters, mutated_keys = self.mutator.mutate(active_parameters) - parametrization.update(mutated_parameters) + parametrization, mutated_keys = self.mutator.mutate(parent_parametrization) return parametrization, mutated_keys diff --git a/hannah/nas/search/sampler/base_sampler.py b/hannah/nas/search/sampler/base_sampler.py index b29b4c4b..749598e4 100644 --- a/hannah/nas/search/sampler/base_sampler.py +++ b/hannah/nas/search/sampler/base_sampler.py @@ -27,10 +27,8 @@ def costs(self): class Sampler(ABC): def __init__(self, parent_config, - search_space, output_folder=".") -> None: self.history = [] - self.search_space = search_space self.output_folder = Path(output_folder) self.parent_config = parent_config diff --git a/hannah/nas/search/search.py b/hannah/nas/search/search.py index 3605f2ee..a1f8c011 100644 --- a/hannah/nas/search/search.py +++ b/hannah/nas/search/search.py @@ -129,7 +129,6 @@ def before_search(self): parametrization = self.search_space.parametrization(flatten=True) self.sampler = instantiate( self.config.nas.sampler, - search_space=self.search_space, parametrization=parametrization, parent_config=self.config, _recursive_=False, @@ -248,7 +247,7 @@ def sample_candidates( num_candidates=None, sort_key="val_error", presample=False, - constrain=True, + constrain=False, ): candidates = [] skip_ct = 0 diff --git a/hannah/nas/test/test_active_parameters.py b/hannah/nas/test/test_active_parameters.py deleted file mode 100644 index b61beb8f..00000000 --- a/hannah/nas/test/test_active_parameters.py +++ /dev/null @@ -1,19 +0,0 @@ -from hannah.nas.functional_operators.op import Tensor -from hannah.models.embedded_vision_net.models import embedded_vision_net -from hannah.nas.functional_operators.utils import get_active_parameters - - -def test_active_parameters(): - input = Tensor(name="input", shape=(1, 3, 32, 32), axis=("N", "C", "H", "W")) - space = embedded_vision_net("space", input, num_classes=10) - space.parametrization()["embedded_vision_net_0.ChoiceOp_0.num_blocks"].set_current(1) - space.parametrization()["embedded_vision_net_0.block_0.pattern_0.ChoiceOp_0.choice"].set_current(4) - space.parametrization()["embedded_vision_net_0.block_0.pattern_0.sandglass_block_0.expansion_0.ChoiceOp_0.choice"].set_current(1) - active_params = get_active_parameters(space) - - space.parametrization() - print() - - -if __name__ == "__main__": - test_active_parameters() \ No newline at end of file