Skip to content

Commit

Permalink
Only use active parameters in AE Search
Browse files Browse the repository at this point in the history
  • Loading branch information
moreib committed Dec 10, 2024
1 parent 62cb3a6 commit 78fb5b4
Show file tree
Hide file tree
Showing 13 changed files with 106 additions and 46 deletions.
1 change: 1 addition & 0 deletions hannah/conf/nas/aging_evolution_nas.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ n_jobs: 10
presample: False
total_candidates: 50
num_selected_candidates: 20
constrained_sampling_on_search: True
bounds:
val_error: 0.1
# total_macs: 128000000
Expand Down
63 changes: 32 additions & 31 deletions hannah/nas/constraints/random_walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from hannah.nas.parameters.parameters import Parameter
from hannah.nas.parameters.parametrize import set_parametrization
from hannah.nas.search.utils import np_to_primitive
from hannah.nas.functional_operators.utils.visit import get_active_parameters

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -72,36 +73,36 @@ def hierarchical_parameter_dict(parameter, include_empty=False, flatten=False):
}


def get_active_parameter(net):
active_param_ids = []
queue = [net]
visited = [net.id]

def extract_parameters(node):
ids = []
for k, p in node._PARAMETERS.items():
if isinstance(p, Parameter):
ids.append(p.id)
return ids

while queue:
current = queue.pop()
if isinstance(current, ChoiceOp):
# handle choices
active_param_ids.append(current.switch.id)
chosen_path = current.options[lazy(current.switch)]
if chosen_path.id not in visited:
queue.append(chosen_path)
visited.append(chosen_path.id)
else:
# handle all other operators & tensors
active_param_ids.extend(extract_parameters(current))
for operand in current.operands:
if operand.id not in visited:
queue.append(operand)
visited.append(operand.id)

return active_param_ids
# def get_active_parameter(net):
# active_param_ids = []
# queue = [net]
# visited = [net.id]

# def extract_parameters(node):
# ids = []
# for k, p in node._PARAMETERS.items():
# if isinstance(p, Parameter):
# ids.append(p.id)
# return ids

# while queue:
# current = queue.pop()
# if isinstance(current, ChoiceOp):
# # handle choices
# active_param_ids.append(current.switch.id)
# chosen_path = current.options[lazy(current.switch)]
# if chosen_path.id not in visited:
# queue.append(chosen_path)
# visited.append(chosen_path.id)
# else:
# # handle all other operators & tensors
# active_param_ids.extend(extract_parameters(current))
# for operand in current.operands:
# if operand.id not in visited:
# queue.append(operand)
# visited.append(operand.id)

# return active_param_ids


class RandomWalkConstraintSolver:
Expand Down Expand Up @@ -193,7 +194,7 @@ def solve(self, module, parameters, fix_vars=[]):
ct = 0
while ct < self.max_iterations:
# active_params = get_active_parameter(params)
active_params = get_active_parameter(mod)
active_params = list(get_active_parameters(mod).keys())

param_keys = [p for p in all_param_keys if p in active_params]
current = con.lhs.evaluate()
Expand Down
29 changes: 29 additions & 0 deletions hannah/nas/functional_operators/utils/visit.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
# limitations under the License.
#
from ..op import BaseNode
from hannah.nas.functional_operators.op import ChoiceOp
from hannah.nas.parameters.parameters import Parameter


def post_order(op: BaseNode):
Expand All @@ -40,3 +42,30 @@ def post_order(op: BaseNode):
def reverse_post_order(op: BaseNode):
    """Visits the operator graph in reverse post order.

    The traversal produced by ``post_order(op)`` is materialised into a list
    so it can be walked back to front; a reverse iterator over that list is
    returned.
    """
    ordered_nodes = list(post_order(op))
    return reversed(ordered_nodes)


def get_active_parameters(space, parametrization=None):
    """Collect the parameters that are reachable through the selected choices.

    The operator graph is walked breadth-first starting at ``space``. Whenever
    a ``ChoiceOp`` is encountered, its own parameters are recorded and only
    the operand selected by ``switch.evaluate()`` is followed, so parameters
    that live exclusively in unselected options are left out.

    Args:
        space: Root node of the graph. Every visited node must provide
            ``id``, ``operands`` and ``_PARAMETERS``.
        parametrization: Mapping from parameter id to the entry that should
            be returned for it. Defaults to ``space.parametrization()``.

    Returns:
        A dict mapping the id of every active ``Parameter`` to
        ``parametrization[<id>]``, in discovery order.

    Raises:
        KeyError: If an active parameter id is missing from
            ``parametrization``.
    """
    # Imported locally so this function does not depend on a module-level
    # import being added elsewhere in the file.
    from collections import deque

    if parametrization is None:
        parametrization = space.parametrization()

    active_params = {}

    def collect(node):
        # Record every Parameter that is directly attached to ``node``.
        for param in node._PARAMETERS.values():
            if isinstance(param, Parameter):
                active_params[param.id] = parametrization[param.id]

    def resolve(node):
        # Follow (possibly nested) choices down to the selected operator.
        # NOTE(review): assumes a ChoiceOp's operands are its options, indexed
        # by the evaluated switch -- same assumption as the original code.
        while isinstance(node, ChoiceOp):
            collect(node)
            node = node.operands[node.switch.evaluate()]
        return node

    # Resolve the root as well; otherwise a ChoiceOp passed in as ``space``
    # would have all of its options traversed instead of just the active one.
    root = resolve(space)
    queue = deque([root])
    visited = {root.id}  # set: O(1) membership instead of scanning a list

    while queue:
        node = queue.popleft()
        collect(node)
        for operand in node.operands:
            operand = resolve(operand)
            if operand.id not in visited:
                visited.add(operand.id)
                queue.append(operand)
    return active_params

13 changes: 9 additions & 4 deletions hannah/nas/search/sampler/aging_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
from hannah.nas.search.sampler.mutator import ParameterMutator
from hannah.nas.search.utils import np_to_primitive

from ...parametrization import SearchSpace
from ...utils import is_pareto
from .base_sampler import Sampler, SearchResult
from hannah.nas.functional_operators.utils.visit import get_active_parameters


class FitnessFunction:
Expand All @@ -59,14 +59,16 @@ class AgingEvolutionSampler(Sampler):
def __init__(
self,
parent_config,
search_space,
parametrization: dict,
population_size: int = 50,
random_state = None,
sample_size: int = 10,
mutation_rate: float = 0.01,
eps: float = 0.1,
output_folder=".",
):
super().__init__(parent_config, output_folder=output_folder)
super().__init__(parent_config, search_space=search_space, output_folder=output_folder)
self.bounds = self.parent_config.nas.bounds
self.parametrization = parametrization

Expand All @@ -79,7 +81,7 @@ def __init__(
self.population_size = population_size
self.sample_size = sample_size
self.eps = eps
self.mutator = ParameterMutator(0.1)
self.mutator = ParameterMutator(mutation_rate)

self.history = []
self.population = []
Expand Down Expand Up @@ -118,8 +120,11 @@ def next_parameters(self):

parent = sample[np.argmin(fitness)]
parent_parametrization = set_parametrization(parent.parameters, self.parametrization)
parametrization = {key: param.current_value for key, param in parent_parametrization.items()}
active_parameters = get_active_parameters(self.search_space, parent_parametrization)

parametrization, mutated_keys = self.mutator.mutate(parent_parametrization)
mutated_parameters, mutated_keys = self.mutator.mutate(active_parameters)
parametrization.update(mutated_parameters)

return parametrization, mutated_keys

Expand Down
2 changes: 2 additions & 0 deletions hannah/nas/search/sampler/base_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ def costs(self):
class Sampler(ABC):
def __init__(self,
parent_config,
search_space,
output_folder=".") -> None:
self.history = []
self.search_space = search_space
self.output_folder = Path(output_folder)
self.parent_config = parent_config

Expand Down
3 changes: 2 additions & 1 deletion hannah/nas/search/sampler/random_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ class RandomSampler(Sampler):
def __init__(
self,
parent_config,
search_space,
parametrization,
output_folder=".",
) -> None:
super().__init__(parent_config=parent_config, output_folder=output_folder)
super().__init__(parent_config=parent_config, search_space=search_space, output_folder=output_folder)
self.parametrization = parametrization

if (self.output_folder / "history.yml").exists():
Expand Down
3 changes: 2 additions & 1 deletion hannah/nas/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def before_search(self):
parametrization = self.search_space.parametrization(flatten=True)
self.sampler = instantiate(
self.config.nas.sampler,
search_space=self.search_space,
parametrization=parametrization,
parent_config=self.config,
_recursive_=False,
Expand Down Expand Up @@ -247,7 +248,7 @@ def sample_candidates(
num_candidates=None,
sort_key="val_error",
presample=False,
constrain=False,
constrain=True,
):
candidates = []
skip_ct = 0
Expand Down
19 changes: 19 additions & 0 deletions hannah/nas/test/test_active_parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from hannah.nas.functional_operators.op import Tensor
from hannah.models.embedded_vision_net.models import embedded_vision_net
from hannah.nas.functional_operators.utils.visit import get_active_parameters


def test_active_parameters():
    """Smoke-test ``get_active_parameters`` on the embedded vision net space.

    Three parameters are pinned to fixed values before the active parameters
    are collected. The checks are deliberately structural: the result must be
    a dict whose keys all exist in the space's parametrization.
    """
    # Named ``input_tensor`` to avoid shadowing the ``input`` builtin.
    input_tensor = Tensor(
        name="input", shape=(1, 3, 32, 32), axis=("N", "C", "H", "W")
    )
    space = embedded_vision_net("space", input_tensor, num_classes=10)

    parametrization = space.parametrization()
    parametrization["embedded_vision_net_0.ChoiceOp_0.num_blocks"].set_current(1)
    parametrization["embedded_vision_net_0.block_0.pattern_0.ChoiceOp_0.choice"].set_current(4)
    parametrization["embedded_vision_net_0.block_0.pattern_0.sandglass_block_0.expansion_0.ChoiceOp_0.choice"].set_current(1)

    active_params = get_active_parameters(space)

    assert isinstance(active_params, dict)
    # Active parameters can only ever be a subset of all parameters.
    assert set(active_params) <= set(parametrization)


if __name__ == "__main__":
test_active_parameters()
2 changes: 1 addition & 1 deletion hannah/nas/test/test_max78000_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def get_graph(seed):
warnings.warn("remove this when seedable randomsampling works")

print("Init sampler")
sampler = RandomSampler(None, graph.parametrization(flatten=True))
sampler = RandomSampler(None, graph, graph.parametrization(flatten=True))

print("Init solver")
solver = RandomWalkConstraintSolver()
Expand Down
2 changes: 1 addition & 1 deletion hannah/nas/test/test_nn_meter.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def test_nn_meter(hardware_name):
predictor = NNMeterPredictor(hardware_name)

print("Init sampler")
sampler = RandomSampler(None, net.parametrization(flatten=True))
sampler = RandomSampler(None, net, net.parametrization(flatten=True))

print("Init solver")
solver = RandomWalkConstraintSolver()
Expand Down
4 changes: 2 additions & 2 deletions hannah/nas/test/test_onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def test_export_embedded_vision_net():
print(graph)

print("Init sampler")
sampler = RandomSampler(None, graph.parametrization(flatten=True))
sampler = RandomSampler(None, graph, graph.parametrization(flatten=True))

print("Init solver")
solver = RandomWalkConstraintSolver()
Expand Down Expand Up @@ -141,7 +141,7 @@ def test_export_ai8x_net():
print(graph)

print("Init sampler")
sampler = RandomSampler(None, graph.parametrization(flatten=True))
sampler = RandomSampler(None, graph, graph.parametrization(flatten=True))

print("Init solver")
solver = RandomWalkConstraintSolver()
Expand Down
9 changes: 5 additions & 4 deletions hannah/nas/test/test_random_walk_constrainer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from hannah.nas.functional_operators.op import Tensor, scope, search_space
from hannah.nas.constraints.random_walk import get_active_parameter, RandomWalkConstraintSolver
from hannah.nas.functional_operators.utils.visit import get_active_parameters
from hannah.nas.constraints.random_walk import RandomWalkConstraintSolver
from hannah.models.embedded_vision_net.operators import adaptive_avg_pooling, add, conv_relu, dynamic_depth, linear
from hannah.nas.parameters.parameters import CategoricalParameter, IntScalarParameter

Expand Down Expand Up @@ -44,17 +45,17 @@ def space(input):
def test_get_active_params():
input = Tensor(name='input', shape=(1, 3, 32, 32), axis=('N', 'C', 'H', 'W'))
out = space(input)
active_params = get_active_parameter(out)
active_params = list(get_active_parameters(out).keys())
assert len(active_params) == 7
for p in active_params:
assert "parallel_blocks_1" not in p and "parallel_blocks_2" not in p
out.parametrization()['space_0.ChoiceOp_0.depth'].set_current(1)
active_params = get_active_parameter(out)
active_params = get_active_parameters(out)
assert len(active_params) == 10
for p in active_params:
assert "parallel_blocks_2" not in p
out.parametrization()['space_0.ChoiceOp_0.depth'].set_current(2)
active_params = get_active_parameter(out)
active_params = get_active_parameters(out)
assert len(active_params) == 13


Expand Down
2 changes: 1 addition & 1 deletion hannah/nas/test/test_searchspace_to_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_model_conversion(model):
model.sample()

print("Init sampler")
sampler = RandomSampler(None, model.parametrization(flatten=True))
sampler = RandomSampler(None, model, model.parametrization(flatten=True))

print("Init solver")
solver = RandomWalkConstraintSolver()
Expand Down

0 comments on commit 78fb5b4

Please sign in to comment.