diff --git a/hannah/nas/performance_prediction/simple.py b/hannah/nas/performance_prediction/simple.py index 4a27e01e..dbaebc67 100644 --- a/hannah/nas/performance_prediction/simple.py +++ b/hannah/nas/performance_prediction/simple.py @@ -140,7 +140,7 @@ def predict(self, model, input): print(result, std_dev) - metrics = {'val_error': result} + metrics = {'val_error': result.item()} logger.info("Predicted performance metrics") for k in metrics.keys(): diff --git a/hannah/nas/search/sampler/aging_evolution.py b/hannah/nas/search/sampler/aging_evolution.py index 1096e515..874fe7d0 100644 --- a/hannah/nas/search/sampler/aging_evolution.py +++ b/hannah/nas/search/sampler/aging_evolution.py @@ -38,6 +38,8 @@ class FitnessFunction: def __init__(self, bounds, random_state): self.bounds = bounds + if random_state is None: + random_state = np.random.RandomState() self.lambdas = random_state.uniform(low=0.0, high=1.0, size=len(self.bounds)) def __call__(self, values): diff --git a/hannah/nas/search/search.py b/hannah/nas/search/search.py index 34bf4f11..8551a893 100644 --- a/hannah/nas/search/search.py +++ b/hannah/nas/search/search.py @@ -21,6 +21,7 @@ import os import traceback from abc import ABC, abstractmethod +import numpy as np import torch from hydra.utils import get_class, instantiate @@ -34,6 +35,8 @@ from hannah.nas.search.utils import WorklistItem, save_config_to_file from hannah.utils.utils import common_callbacks from hannah.nas.graph_conversion import model_to_graph + +from hannah.nas.search.sampler.aging_evolution import FitnessFunction import traceback import copy @@ -51,6 +54,7 @@ def __init__( predictor=None, constraint_model=None, parent_config=None, + random_state=None, ) -> None: self.budget = budget self.n_jobs = n_jobs @@ -60,6 +64,10 @@ def __init__( self.model_trainer = model_trainer self.predictor = predictor self.constraint_model = constraint_model + if random_state is None: + self.random_state = np.random.RandomState() + else: + self.random_state = random_state def run(self): self.before_search() @@ -84,6 +92,14 @@ def add_model_trainer(self, trainer): def add_sampler(self, sampler): self.sampler = sampler + def get_fitness_function(self): + # FIXME: make better configurable + if hasattr(self, 'bounds') and self.bounds is not None: + bounds = self.bounds + return FitnessFunction(bounds, self.random_state) + else: + return lambda x: x['val_error'] + class DirectNAS(NASBase): def __init__( @@ -192,7 +208,7 @@ def after_search(self): pass # self.extract_best_model() - def sample_candidates(self, num_total, num_candidates=None, sort_key="val_error"): + def sample_candidates(self, num_total, num_candidates=None, sort_key="ff"): candidates = [] for n in range(num_total): models = [] @@ -200,6 +216,8 @@ def sample_candidates(self, num_total, num_candidates=None, sort_key="val_error" model = self.build_model(parameters) models.append(model) estimated_metrics, satisfied_bounds = self.estimate_metrics(copy.deepcopy(model)) + ff = self.get_fitness_function()(estimated_metrics) + estimated_metrics['ff'] = ff candidates.append((model, parameters, estimated_metrics, satisfied_bounds)) if self.predictor: