Skip to content

Commit

Permalink
Merge branch 'f/leq/geq-directions-to-random_walk-constraint-solver' …
Browse files Browse the repository at this point in the history
…into 'main'

Add leq and geq to right_direction in random_walk constrainer

See merge request es/ai/hannah/hannah!410
  • Loading branch information
moreib committed Oct 9, 2024
2 parents 39ce0ac + 6a04c28 commit 1338383
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
10 changes: 10 additions & 0 deletions hannah/nas/constraints/random_walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,16 @@ def right_direction(self, current, new, direction):
return True
else:
return False
elif direction == ">=":
if new > current:
return True
else:
return False
elif direction == "<=":
if new < current:
return True
else:
return False

def solve(self, module, parameters, fix_vars=[]):
print("Start constraint solving")
Expand Down
22 changes: 21 additions & 1 deletion hannah/nas/test/test_random_walk_constrainer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from hannah.nas.functional_operators.op import Tensor, scope, search_space
from hannah.nas.constraints.random_walk import get_active_parameter
from hannah.nas.constraints.random_walk import get_active_parameter, RandomWalkConstraintSolver
from hannah.models.embedded_vision_net.operators import adaptive_avg_pooling, add, conv_relu, dynamic_depth, linear
from hannah.nas.parameters.parameters import CategoricalParameter, IntScalarParameter

Expand Down Expand Up @@ -58,6 +58,26 @@ def test_get_active_params():
assert len(active_params) == 13


def test_right_direction():
rw_solver = RandomWalkConstraintSolver()
dir = "<"
assert not rw_solver.right_direction(10, 15, dir)
assert rw_solver.right_direction(26, 17, dir)

dir = ">"
assert rw_solver.right_direction(10, 15, dir)
assert not rw_solver.right_direction(26, 17, dir)

dir = "<="
assert not rw_solver.right_direction(10, 15, dir)
assert rw_solver.right_direction(26, 17, dir)

dir = ">="
assert rw_solver.right_direction(10, 15, dir)
assert not rw_solver.right_direction(26, 17, dir)


if __name__ == '__main__':
test_get_active_params()
test_right_direction()

0 comments on commit 1338383

Please sign in to comment.