Skip to content

Commit

Permalink
Rank by fitness function
Browse files Browse the repository at this point in the history
  • Loading branch information
moreib committed Nov 6, 2023
1 parent 3d78e33 commit 7a67846
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 2 deletions.
2 changes: 1 addition & 1 deletion hannah/nas/performance_prediction/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def predict(self, model, input):

print(result, std_dev)

metrics = {'val_error': result}
metrics = {'val_error': result.item()}

logger.info("Predicted performance metrics")
for k in metrics.keys():
Expand Down
2 changes: 2 additions & 0 deletions hannah/nas/search/sampler/aging_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
class FitnessFunction:
def __init__(self, bounds, random_state):
self.bounds = bounds
if random_state is None:
random_state = np.random.RandomState()
self.lambdas = random_state.uniform(low=0.0, high=1.0, size=len(self.bounds))

def __call__(self, values):
Expand Down
20 changes: 19 additions & 1 deletion hannah/nas/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import os
import traceback
from abc import ABC, abstractmethod
import numpy as np

import torch
from hydra.utils import get_class, instantiate
Expand All @@ -34,6 +35,8 @@
from hannah.nas.search.utils import WorklistItem, save_config_to_file
from hannah.utils.utils import common_callbacks
from hannah.nas.graph_conversion import model_to_graph

from hannah.nas.search.sampler.aging_evolution import FitnessFunction
import traceback
import copy

Expand All @@ -51,6 +54,7 @@ def __init__(
predictor=None,
constraint_model=None,
parent_config=None,
random_state=None,
) -> None:
self.budget = budget
self.n_jobs = n_jobs
Expand All @@ -60,6 +64,10 @@ def __init__(
self.model_trainer = model_trainer
self.predictor = predictor
self.constraint_model = constraint_model
if random_state is None:
self.random_state = np.random.RandomState()
else:
self.random_state = random_state

def run(self):
self.before_search()
Expand All @@ -84,6 +92,14 @@ def add_model_trainer(self, trainer):
def add_sampler(self, sampler):
self.sampler = sampler

def get_fitness_function(self):
# FIXME: make better configurable
if hasattr(self, 'bounds') and self.bounds is not None:
bounds = self.bounds
return FitnessFunction(bounds, self.random_state)
else:
return lambda x: x['val_error']


class DirectNAS(NASBase):
def __init__(
Expand Down Expand Up @@ -192,14 +208,16 @@ def after_search(self):
pass
# self.extract_best_model()

def sample_candidates(self, num_total, num_candidates=None, sort_key="val_error"):
def sample_candidates(self, num_total, num_candidates=None, sort_key="ff"):
candidates = []
for n in range(num_total):
models = []
parameters = self.sample()
model = self.build_model(parameters)
models.append(model)
estimated_metrics, satisfied_bounds = self.estimate_metrics(copy.deepcopy(model))
ff = self.get_fitness_function()(estimated_metrics)
estimated_metrics['ff'] = ff
candidates.append((model, parameters, estimated_metrics, satisfied_bounds))

if self.predictor:
Expand Down

0 comments on commit 7a67846

Please sign in to comment.