Skip to content

Commit

Permalink
Merge branch 'fix/mutator_steps' into 'main'
Browse files Browse the repository at this point in the history
Fix mutator steps

See merge request es/ai/hannah/hannah!409
  • Loading branch information
moreib committed Oct 2, 2024
2 parents b88cc74 + cbe08ea commit 39ce0ac
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 10 deletions.
26 changes: 16 additions & 10 deletions hannah/nas/search/sampler/mutator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def mutate_choice(self, parameter):
return chosen_mutation(parameter)

def mutate_int_scalar(self, parameter):
mutations = self.get_int_mutations()
mutations = self.get_int_mutations(parameter)
chosen_mutation = self.rng.choice(mutations)
return int(chosen_mutation(parameter))

Expand All @@ -62,8 +62,13 @@ def mutate_generic(self, parameter):
def get_choice_mutations(self):
return [self.random_choice, self.increase_choice, self.decrease_choice]

def get_int_mutations(self):
return [self.random_int_scalar, self.increase_int_scalar, self.decrease_int_scalar]
def get_int_mutations(self, parameter):
possible_mutations = [self.random_int_scalar]
if parameter.current_value + parameter.step_size <= parameter.max:
possible_mutations.append(self.increase_int_scalar)
if parameter.current_value - parameter.step_size >= parameter.min:
possible_mutations.append(self.decrease_int_scalar)
return possible_mutations

def get_float_mutations(self):
return [self.random_float_scalar]
Expand All @@ -87,23 +92,24 @@ def decrease_choice(self, parameter):
return parameter.choices[index - 1]

def random_int_scalar(self, parameter):
return parameter.rng.integers(parameter.min, parameter.max+1)
return int(parameter.rng.choice(range(parameter.min, parameter.max+1, parameter.step_size)))

def increase_int_scalar(self, parameter):
if parameter.current_value < parameter.max:
return parameter.current_value + 1
if parameter.current_value + parameter.step_size < parameter.max:
return parameter.current_value + parameter.step_size
else:
return parameter.rng.integers(parameter.min, parameter.max+1)
return parameter.current_value

def decrease_int_scalar(self, parameter):
if parameter.current_value > parameter.min:
return parameter.current_value - 1
if parameter.current_value - parameter.step_size > parameter.min:
return parameter.current_value - parameter.step_size
else:
return parameter.rng.integers(parameter.min, parameter.max+1)
return parameter.current_value

def random_float_scalar(self, parameter):
return parameter.rng.uniform(parameter.min, parameter.max)


if __name__ == '__main__':
mutator = ParameterMutator(0.1)
par = FloatScalarParameter(2, 5)
Expand Down
39 changes: 39 additions & 0 deletions hannah/nas/test/test_mutator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from hannah.nas.parameters.parameters import IntScalarParameter
from hannah.nas.search.sampler.mutator import ParameterMutator


def test_int_mutations():
param = IntScalarParameter(min=4, max=32, step_size=4)
mutator = ParameterMutator(0.1)

assert len(mutator.get_int_mutations(param)) == 2
param.set_current(16)
assert len(mutator.get_int_mutations(param)) == 3
param.set_current(32)
assert len(mutator.get_int_mutations(param)) == 2

val = mutator.decrease_int_scalar(param)
assert val == 28
param.set_current(8)
val = mutator.increase_int_scalar(param)
assert val == 12

val = mutator.random_int_scalar(param)
assert val >= param.min
assert val <= param.max
assert val % param.step_size == 0

param.set_current(32)
val = mutator.increase_int_scalar(param)
assert val == 32

param.set_current(16)
val = mutator.mutate_int_scalar(param)
assert isinstance(val, int)
assert val >= param.min
assert val <= param.max
assert val % param.step_size == 0


if __name__ == "__main__":
test_int_mutations()

0 comments on commit 39ce0ac

Please sign in to comment.