Skip to content

Commit

Permalink
Revert last three commits
Browse files Browse the repository at this point in the history
  • Loading branch information
moreib committed Dec 10, 2024
1 parent e2e6555 commit 62cb3a6
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 95 deletions.
1 change: 0 additions & 1 deletion hannah/conf/nas/aging_evolution_nas.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ n_jobs: 10
presample: False
total_candidates: 50
num_selected_candidates: 20
constrained_sampling_on_search: True
bounds:
val_error: 0.1
# total_macs: 128000000
Expand Down
63 changes: 31 additions & 32 deletions hannah/nas/constraints/random_walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from hannah.nas.parameters.parameters import Parameter
from hannah.nas.parameters.parametrize import set_parametrization
from hannah.nas.search.utils import np_to_primitive
from hannah.nas.functional_operators.utils import get_active_parameters

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -73,36 +72,36 @@ def hierarchical_parameter_dict(parameter, include_empty=False, flatten=False):
}


# def get_active_parameter(net):
# active_param_ids = []
# queue = [net]
# visited = [net.id]

# def extract_parameters(node):
# ids = []
# for k, p in node._PARAMETERS.items():
# if isinstance(p, Parameter):
# ids.append(p.id)
# return ids

# while queue:
# current = queue.pop()
# if isinstance(current, ChoiceOp):
# # handle choices
# active_param_ids.append(current.switch.id)
# chosen_path = current.options[lazy(current.switch)]
# if chosen_path.id not in visited:
# queue.append(chosen_path)
# visited.append(chosen_path.id)
# else:
# # handle all other operators & tensors
# active_param_ids.extend(extract_parameters(current))
# for operand in current.operands:
# if operand.id not in visited:
# queue.append(operand)
# visited.append(operand.id)

# return active_param_ids
def get_active_parameter(net):
active_param_ids = []
queue = [net]
visited = [net.id]

def extract_parameters(node):
ids = []
for k, p in node._PARAMETERS.items():
if isinstance(p, Parameter):
ids.append(p.id)
return ids

while queue:
current = queue.pop()
if isinstance(current, ChoiceOp):
# handle choices
active_param_ids.append(current.switch.id)
chosen_path = current.options[lazy(current.switch)]
if chosen_path.id not in visited:
queue.append(chosen_path)
visited.append(chosen_path.id)
else:
# handle all other operators & tensors
active_param_ids.extend(extract_parameters(current))
for operand in current.operands:
if operand.id not in visited:
queue.append(operand)
visited.append(operand.id)

return active_param_ids


class RandomWalkConstraintSolver:
Expand Down Expand Up @@ -194,7 +193,7 @@ def solve(self, module, parameters, fix_vars=[]):
ct = 0
while ct < self.max_iterations:
# active_params = get_active_parameter(params)
active_params = list(get_active_parameters(mod).keys())
active_params = get_active_parameter(mod)

param_keys = [p for p in all_param_keys if p in active_params]
current = con.lhs.evaluate()
Expand Down
30 changes: 0 additions & 30 deletions hannah/nas/functional_operators/utils.py

This file was deleted.

13 changes: 4 additions & 9 deletions hannah/nas/search/sampler/aging_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
from hannah.nas.search.sampler.mutator import ParameterMutator
from hannah.nas.search.utils import np_to_primitive

from ...parametrization import SearchSpace
from ...utils import is_pareto
from .base_sampler import Sampler, SearchResult
from hannah.nas.functional_operators.utils import get_active_parameters


class FitnessFunction:
Expand All @@ -59,16 +59,14 @@ class AgingEvolutionSampler(Sampler):
def __init__(
self,
parent_config,
search_space,
parametrization: dict,
population_size: int = 50,
random_state = None,
sample_size: int = 10,
mutation_rate: float = 0.01,
eps: float = 0.1,
output_folder=".",
):
super().__init__(parent_config, search_space=search_space, output_folder=output_folder)
super().__init__(parent_config, output_folder=output_folder)
self.bounds = self.parent_config.nas.bounds
self.parametrization = parametrization

Expand All @@ -81,7 +79,7 @@ def __init__(
self.population_size = population_size
self.sample_size = sample_size
self.eps = eps
self.mutator = ParameterMutator(mutation_rate)
self.mutator = ParameterMutator(0.1)

self.history = []
self.population = []
Expand Down Expand Up @@ -120,11 +118,8 @@ def next_parameters(self):

parent = sample[np.argmin(fitness)]
parent_parametrization = set_parametrization(parent.parameters, self.parametrization)
parametrization = {key: param.current_value for key, param in parent_parametrization.items()}
active_parameters = get_active_parameters(self.search_space, parent_parametrization)

mutated_parameters, mutated_keys = self.mutator.mutate(active_parameters)
parametrization.update(mutated_parameters)
parametrization, mutated_keys = self.mutator.mutate(parent_parametrization)

return parametrization, mutated_keys

Expand Down
2 changes: 0 additions & 2 deletions hannah/nas/search/sampler/base_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,8 @@ def costs(self):
class Sampler(ABC):
def __init__(self,
parent_config,
search_space,
output_folder=".") -> None:
self.history = []
self.search_space = search_space
self.output_folder = Path(output_folder)
self.parent_config = parent_config

Expand Down
3 changes: 1 addition & 2 deletions hannah/nas/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ def before_search(self):
parametrization = self.search_space.parametrization(flatten=True)
self.sampler = instantiate(
self.config.nas.sampler,
search_space=self.search_space,
parametrization=parametrization,
parent_config=self.config,
_recursive_=False,
Expand Down Expand Up @@ -248,7 +247,7 @@ def sample_candidates(
num_candidates=None,
sort_key="val_error",
presample=False,
constrain=True,
constrain=False,
):
candidates = []
skip_ct = 0
Expand Down
19 changes: 0 additions & 19 deletions hannah/nas/test/test_active_parameters.py

This file was deleted.

0 comments on commit 62cb3a6

Please sign in to comment.