From 3cdb085afec606a48f22d95ac38fbed94c5b04c9 Mon Sep 17 00:00:00 2001 From: fred-labs Date: Wed, 14 Aug 2024 10:41:39 +0200 Subject: [PATCH] Add support for expressions (#157) --- docs/openscenario2.rst | 2 +- .../scenario_coverage/scenario_variation.py | 4 +- .../model/model_to_py_tree.py | 40 ++++- .../scenario_execution/model/types.py | 167 +++++++++++++++--- scenario_execution/test/test_expression.py | 87 ++++----- .../actions/ros_topic_check_data.py | 2 +- .../actions/ros_topic_monitor.py | 26 ++- .../scenario_execution_ros/lib_osc/ros.osc | 1 + .../test/test_topic_monitor.py | 81 +++++++++ .../actions/set_blackboard_var.py | 31 ++++ test/scenario_execution_test/setup.py | 1 + .../test/test_expression_with_var.py | 98 ++++++++++ 12 files changed, 462 insertions(+), 78 deletions(-) create mode 100644 test/scenario_execution_test/scenario_execution_test/actions/set_blackboard_var.py create mode 100644 test/scenario_execution_test/test/test_expression_with_var.py diff --git a/docs/openscenario2.rst b/docs/openscenario2.rst index cecf9e7c..310b4862 100644 --- a/docs/openscenario2.rst +++ b/docs/openscenario2.rst @@ -66,7 +66,7 @@ Element Tag Support Notes ``enum`` :raw-html:`✅` ``event`` :raw-html:`✅` ``every`` :raw-html:`❌` -``expression`` :raw-html:`❌` +``expression`` :raw-html:`✅` ``extend`` :raw-html:`❌` ``external`` :raw-html:`❌` ``fall`` :raw-html:`❌` diff --git a/scenario_coverage/scenario_coverage/scenario_variation.py b/scenario_coverage/scenario_coverage/scenario_variation.py index 7a03dd91..83c205a3 100644 --- a/scenario_coverage/scenario_coverage/scenario_variation.py +++ b/scenario_coverage/scenario_coverage/scenario_variation.py @@ -25,7 +25,7 @@ import py_trees from scenario_execution.model.osc2_parser import OpenScenario2Parser from scenario_execution.model.model_resolver import resolve_internal_model -from scenario_execution.model.types import RelationExpression, ListExpression, FieldAccessExpression, Expression, print_tree, serialize, to_string +from scenario_execution.model.types import RelationExpression, ListExpression, FieldAccessExpression, ModelExpression, print_tree, serialize, to_string from scenario_execution.utils.logging import Logger @@ -138,7 +138,7 @@ def save_resulting_scenarios(self, models): # create description variation_descriptions = [] for descr, entry in model[1]: - if isinstance(entry, Expression): + if isinstance(entry, ModelExpression): val = None for child in entry.get_children(): if not isinstance(child, FieldAccessExpression): diff --git a/scenario_execution/scenario_execution/model/model_to_py_tree.py b/scenario_execution/scenario_execution/model/model_to_py_tree.py index cbe394e8..f5a0074a 100644 --- a/scenario_execution/scenario_execution/model/model_to_py_tree.py +++ b/scenario_execution/scenario_execution/model/model_to_py_tree.py @@ -18,10 +18,9 @@ import py_trees from py_trees.common import Access, Status from pkg_resources import iter_entry_points - import inspect -from scenario_execution.model.types import ActionDeclaration, EventReference, FunctionApplicationExpression, ModifierInvocation, ScenarioDeclaration, DoMember, WaitDirective, EmitDirective, BehaviorInvocation, EventCondition, EventDeclaration, RelationExpression, LogicalExpression, ElapsedExpression, PhysicalLiteral, ModifierDeclaration +from scenario_execution.model.types import KeepConstraintDeclaration, visit_expression, ActionDeclaration, BinaryExpression, EventReference, Expression, FunctionApplicationExpression, ModifierInvocation, ScenarioDeclaration, DoMember, WaitDirective, EmitDirective, BehaviorInvocation, EventCondition, EventDeclaration, RelationExpression, LogicalExpression, ElapsedExpression, PhysicalLiteral, ModifierDeclaration from scenario_execution.model.model_base_visitor import ModelBaseVisitor from scenario_execution.model.error import OSC2ParsingError from scenario_execution.actions.base_action import BaseAction @@ -103,6 +102,20 @@ def update(self): return Status.SUCCESS +class ExpressionBehavior(py_trees.behaviour.Behaviour): + + def __init__(self, name: "ExpressionBehavior", expression: Expression): + super().__init__(name) + + self.expression = expression + + def update(self): + if self.expression.eval(): + return Status.SUCCESS + else: + return Status.RUNNING + + class ModelToPyTree(object): def __init__(self, logger): @@ -122,6 +135,7 @@ class BehaviorInit(ModelBaseVisitor): def __init__(self, logger, tree) -> None: super().__init__() self.logger = logger + self.blackboard = None if not isinstance(tree, py_trees.composites.Sequence): raise ValueError("ModelToPyTree requires a py-tree sequence as input") self.tree = tree @@ -348,19 +362,25 @@ def visit_event_reference(self, node: EventReference): def visit_event_condition(self, node: EventCondition): expression = "" for child in node.get_children(): - if isinstance(child, RelationExpression): - raise NotImplementedError() - elif isinstance(child, LogicalExpression): - raise NotImplementedError() + if isinstance(child, (RelationExpression, LogicalExpression)): + expression = ExpressionBehavior(name=node.get_ctx()[2], expression=self.visit(child)) elif isinstance(child, ElapsedExpression): elapsed_condition = self.visit_elapsed_expression(child) - expression = py_trees.timers.Timer( - name=f"wait {elapsed_condition}s", duration=float(elapsed_condition)) + expression = py_trees.timers.Timer(name=f"wait {elapsed_condition}s", duration=float(elapsed_condition)) else: raise OSC2ParsingError( msg=f'Invalid event condition {child}', context=node.get_ctx()) return expression + def visit_relation_expression(self, node: RelationExpression): + return visit_expression(node, self.blackboard) + + def visit_logical_expression(self, node: LogicalExpression): + return visit_expression(node, self.blackboard) + + def visit_binary_expression(self, node: BinaryExpression): + return visit_expression(node, self.blackboard) + def visit_elapsed_expression(self, node: ElapsedExpression): elem = node.find_first_child_of_type(PhysicalLiteral) if not elem: @@ -389,3 +409,7 @@ def visit_modifier_invocation(self, node: ModifierInvocation): self.create_decorator(node.modifier, resolved_values) except ValueError as e: raise OSC2ParsingError(msg=f'ModifierDeclaration {e}.', context=node.get_ctx()) from e + + def visit_keep_constraint_declaration(self, node: KeepConstraintDeclaration): + # skip relation-expression + pass diff --git a/scenario_execution/scenario_execution/model/types.py b/scenario_execution/scenario_execution/model/types.py index e630b1b7..9903e826 100644 --- a/scenario_execution/scenario_execution/model/types.py +++ b/scenario_execution/scenario_execution/model/types.py @@ -18,6 +18,7 @@ from scenario_execution.model.error import OSC2ParsingError import sys import py_trees +import operator as op def print_tree(elem, logger, whitespace=""): @@ -338,16 +339,21 @@ def get_value_child(self): return None for child in self.get_children(): - if isinstance(child, (StringLiteral, FloatLiteral, BoolLiteral, IntegerLiteral, FunctionApplicationExpression, IdentifierReference, PhysicalLiteral, EnumValueReference, ListExpression)): + if isinstance(child, (StringLiteral, FloatLiteral, BoolLiteral, IntegerLiteral, FunctionApplicationExpression, IdentifierReference, PhysicalLiteral, EnumValueReference, ListExpression, BinaryExpression, RelationExpression, LogicalExpression)): return child + elif isinstance(child, KeepConstraintDeclaration): + pass + elif not isinstance(child, Type): + raise OSC2ParsingError(msg=f'Parameter has invalid value "{type(child).__name__}".', context=self.get_ctx()) return None def get_resolved_value(self, blackboard=None): param_type, is_list = self.get_type() vals = {} params = {} - if self.get_value_child(): - vals = self.get_value_child().get_resolved_value(blackboard) + val_child = self.get_value_child() + if val_child: + vals = val_child.get_resolved_value(blackboard) if isinstance(param_type, StructuredDeclaration) and not is_list: params = param_type.get_resolved_value(blackboard) @@ -447,7 +453,7 @@ def get_type_string(self): return self.name -class Expression(ModelElement): +class ModelExpression(ModelElement): pass @@ -1549,7 +1555,7 @@ def get_base_type(self): return self.modifier -class RiseExpression(Expression): +class RiseExpression(ModelExpression): def __init__(self): super().__init__() @@ -1569,7 +1575,7 @@ def accept(self, visitor): return visitor.visit_children(self) -class FallExpression(Expression): +class FallExpression(ModelExpression): def __init__(self): super().__init__() @@ -1589,7 +1595,7 @@ def accept(self, visitor): return visitor.visit_children(self) -class ElapsedExpression(Expression): +class ElapsedExpression(ModelExpression): def __init__(self): super().__init__() @@ -1609,7 +1615,7 @@ def accept(self, visitor): return visitor.visit_children(self) -class EveryExpression(Expression): +class EveryExpression(ModelExpression): def __init__(self): super().__init__() @@ -1629,7 +1635,7 @@ def accept(self, visitor): return visitor.visit_children(self) -class SampleExpression(Expression): +class SampleExpression(ModelExpression): def __init__(self): super().__init__() @@ -1649,7 +1655,7 @@ def accept(self, visitor): return visitor.visit_children(self) -class CastExpression(Expression): +class CastExpression(ModelExpression): def __init__(self, object_def, target_type): super().__init__() @@ -1671,7 +1677,7 @@ def accept(self, visitor): return visitor.visit_children(self) -class TypeTestExpression(Expression): +class TypeTestExpression(ModelExpression): def __init__(self, object_def, target_type): super().__init__() @@ -1693,7 +1699,7 @@ def accept(self, visitor): return visitor.visit_children(self) -class ElementAccessExpression(Expression): +class ElementAccessExpression(ModelExpression): def __init__(self, list_name, index): super().__init__() @@ -1715,7 +1721,7 @@ def accept(self, visitor): return visitor.visit_children(self) -class FunctionApplicationExpression(Expression): +class FunctionApplicationExpression(ModelExpression): def __init__(self, func_name): super().__init__() @@ -1775,7 +1781,7 @@ def get_type_string(self): return self.get_type()[0].name -class FieldAccessExpression(Expression): +class FieldAccessExpression(ModelExpression): def __init__(self, field_name): super().__init__() @@ -1796,7 +1802,7 @@ def accept(self, visitor): return visitor.visit_children(self) -class BinaryExpression(Expression): +class BinaryExpression(ModelExpression): def __init__(self, operator): super().__init__() @@ -1816,8 +1822,24 @@ def accept(self, visitor): else: return visitor.visit_children(self) + def get_type_string(self): + type_string = None + for child in self.get_children(): + current = child.get_type_string() + if self.operator in ("/", "%", "*"): # multiplied by factor + if type_string is None or type_string in ("float", "int"): + type_string = current + else: + if type_string not in (current, type_string): + raise OSC2ParsingError(f'Children have different types {current}, {type_string}', context=self.get_ctx()) + type_string = current + return type_string + + def get_resolved_value(self, blackboard=None): + return visit_expression(self, blackboard).eval() + -class UnaryExpression(Expression): +class UnaryExpression(ModelExpression): def __init__(self, operator): super().__init__() @@ -1838,7 +1860,7 @@ def accept(self, visitor): return visitor.visit_children(self) -class TernaryExpression(Expression): +class TernaryExpression(ModelExpression): def __init__(self): super().__init__() @@ -1858,7 +1880,7 @@ def accept(self, visitor): return visitor.visit_children(self) -class LogicalExpression(Expression): +class LogicalExpression(ModelExpression): def __init__(self, operator): super().__init__() @@ -1878,8 +1900,14 @@ def accept(self, visitor): else: return visitor.visit_children(self) + def get_type_string(self): + return "bool" + + def get_resolved_value(self, blackboard=None): + return visit_expression(self, blackboard).eval() -class RelationExpression(Expression): + +class RelationExpression(ModelExpression): def __init__(self, operator): super().__init__() @@ -1899,8 +1927,14 @@ def accept(self, visitor): else: return visitor.visit_children(self) + def get_type_string(self): + return "bool" + + def get_resolved_value(self, blackboard=None): + return visit_expression(self, blackboard).eval() + -class ListExpression(Expression): +class ListExpression(ModelExpression): def __init__(self): super().__init__() @@ -1935,7 +1969,7 @@ def get_resolved_value(self, blackboard=None): return value -class RangeExpression(Expression): +class RangeExpression(ModelExpression): def __init__(self): super().__init__() @@ -2198,6 +2232,8 @@ def get_type_string(self): def get_blackboard_reference(self, blackboard): if not isinstance(self.ref, list) or len(self.ref) == 0: raise ValueError("Variable Reference only supported if reference is list with at least one element") + if not isinstance(self.ref[0], ParameterDeclaration): + raise ValueError("Variable Reference only supported if reference is part of a parameter declaration") fqn = self.ref[0].get_fully_qualified_var_name(include_scenario=False) if blackboard is None: raise ValueError("Variable Reference found, but no blackboard client available.") @@ -2206,6 +2242,12 @@ def get_blackboard_reference(self, blackboard): blackboard.register_key(fqn, access=py_trees.common.Access.WRITE) return VariableReference(blackboard, fqn) + def get_variable_reference(self, blackboard): + if isinstance(self.ref, list) and any(isinstance(x, VariableDeclaration) for x in self.ref): + return self.get_blackboard_reference(blackboard) + else: + return None + def get_resolved_value(self, blackboard=None): if isinstance(self.ref, list): ref = self.ref[0] @@ -2222,3 +2264,86 @@ def get_resolved_value(self, blackboard=None): return val else: return self.ref.get_resolved_value(blackboard) + + +class Expression(object): + def __init__(self, left, right, operator) -> None: + self.left = left + self.right = right + self.operator = operator + + def resolve(self, param): + if isinstance(param, Expression): + return param.eval() + elif isinstance(param, VariableReference): + return param.get_value() + else: + return param + + def eval(self): + left = self.resolve(self.left) + if self.right is None: + return self.operator(left) + else: + right = self.resolve(self.right) + return self.operator(left, right) + + +def visit_expression(node, blackboard): + operator = None + single_child = False + if node.operator == "==": + operator = op.eq + elif node.operator == "!=": + operator = op.ne + elif node.operator == "<": + operator = op.lt + elif node.operator == "<=": + operator = op.le + elif node.operator == ">": + operator = op.gt + elif node.operator == ">=": + operator = op.ge + elif node.operator == "and": + operator = op.and_ + elif node.operator == "or": + operator = op.or_ + elif node.operator == "not": + single_child = True + operator = op.not_ + elif node.operator == "+": + operator = op.add + elif node.operator == "-": + operator = op.sub + elif node.operator == "*": + operator = op.mul + elif node.operator == "/": + operator = op.truediv + elif node.operator == "%": + operator = op.mod + else: + raise NotImplementedError(f"Unknown expression operator '{node.operator}'.") + + if not single_child and node.get_child_count() != 2: + raise ValueError("Expression is expected to have two children.") + + idx = 0 + args = [None, None] + for child in node.get_children(): + if isinstance(child, (RelationExpression, BinaryExpression, LogicalExpression)): + args[idx] = visit_expression(child, blackboard) + else: + if isinstance(child, IdentifierReference): + var_def = child.get_variable_reference(blackboard) + if var_def is not None: + args[idx] = var_def + else: + args[idx] = child.get_resolved_value(blackboard) + else: + args[idx] = child.get_resolved_value(blackboard) + idx += 1 + + if single_child: + return Expression(args[0], args[1], operator) + else: + return Expression(args[0], args[1], operator) diff --git a/scenario_execution/test/test_expression.py b/scenario_execution/test/test_expression.py index ab1a71b0..070273de 100644 --- a/scenario_execution/test/test_expression.py +++ b/scenario_execution/test/test_expression.py @@ -19,6 +19,7 @@ from scenario_execution.model.osc2_parser import OpenScenario2Parser from scenario_execution.utils.logging import Logger from antlr4.InputStream import InputStream +import py_trees class TestExpression(unittest.TestCase): @@ -29,8 +30,12 @@ class TestExpression(unittest.TestCase): def setUp(self) -> None: self.parser = OpenScenario2Parser(Logger('test', False)) + self.tree = py_trees.composites.Sequence(name="", memory=True) + + def parse(self, scenario_content): + parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content)) + return self.parser.create_internal_model(parsed_tree, self.tree, "test.osc", False) - @unittest.skip(reason="requires porting") def test_add(self): scenario_content = """ type time is SI(s: 1) @@ -41,14 +46,20 @@ def test_add(self): global test2: time = 2.0s + 1.1s global test3: time = 2.0s + 1ms """ - parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content)) - model = self.parser.create_internal_model(parsed_tree, "test.osc", False) + model = self.parse(scenario_content) + self.assertAlmostEqual(model._ModelElement__children[3].get_resolved_value(), 3.1) + self.assertAlmostEqual(model._ModelElement__children[4].get_resolved_value(), 3.1) + self.assertAlmostEqual(model._ModelElement__children[5].get_resolved_value(), 2.001) + + def test_add_different_types(self): + scenario_content = """ +type time is SI(s: 1) +unit s of time is SI(s: 1, factor: 1) - self.assertEqual(model._ModelElement__children[3].get_resolved_value(), 3.1) - self.assertEqual(model._ModelElement__children[4].get_resolved_value(), 3.1) - self.assertEqual(model._ModelElement__children[5].get_resolved_value(), 2.001) +global test2: time = 2.0s + 1.1 +""" + self.assertRaises(ValueError, self.parse, scenario_content) - @unittest.skip(reason="requires porting") def test_substract(self): scenario_content = """ type time is SI(s: 1) @@ -59,14 +70,20 @@ def test_substract(self): global test2: time = 2.0s - 1.1s global test3: time = 2.0s - 1ms """ - parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content)) - model = self.parser.create_internal_model(parsed_tree, "test.osc", False) + model = self.parse(scenario_content) + self.assertAlmostEqual(model._ModelElement__children[3].get_resolved_value(), 0.9) + self.assertAlmostEqual(model._ModelElement__children[4].get_resolved_value(), 0.9) + self.assertAlmostEqual(model._ModelElement__children[5].get_resolved_value(), 1.999) + + def test_substract_different_types(self): + scenario_content = """ +type time is SI(s: 1) +unit s of time is SI(s: 1, factor: 1) - self.assertEqual(model._ModelElement__children[3].get_resolved_value(), 0.9) - self.assertEqual(model._ModelElement__children[4].get_resolved_value(), 0.9) - self.assertEqual(model._ModelElement__children[5].get_resolved_value(), 1.999) +global test2: time = 2.0s - 1.1 +""" + self.assertRaises(ValueError, self.parse, scenario_content) - @unittest.skip(reason="requires porting") def test_multiply(self): scenario_content = """ type time is SI(s: 1) @@ -75,13 +92,10 @@ def test_multiply(self): global test1: float = 2.0 * 1.1 global test2: time = 2.0ms * 1.1 """ - parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content)) - model = self.parser.create_internal_model(parsed_tree, "test.osc", False) + model = self.parse(scenario_content) + self.assertAlmostEqual(model._ModelElement__children[2].get_resolved_value(), 2.2) + self.assertAlmostEqual(model._ModelElement__children[3].get_resolved_value(), 0.0022) - self.assertEqual(model._ModelElement__children[3].get_resolved_value(), 2.2) - self.assertEqual(model._ModelElement__children[4].get_resolved_value(), 0.0022) - - @unittest.skip(reason="requires porting") def test_divide(self): scenario_content = """ type time is SI(s: 1) @@ -90,13 +104,10 @@ def test_divide(self): global test1: float = 5.0 / 2.0 global test2: time = 5.0ms / 2.0 """ - parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content)) - model = self.parser.create_internal_model(parsed_tree, "test.osc", False) - - self.assertEqual(model._ModelElement__children[3].get_resolved_value(), 2.5) - self.assertEqual(model._ModelElement__children[4].get_resolved_value(), 0.0025) + model = self.parse(scenario_content) + self.assertAlmostEqual(model._ModelElement__children[2].get_resolved_value(), 2.5) + self.assertAlmostEqual(model._ModelElement__children[3].get_resolved_value(), 0.0025) - @unittest.skip(reason="requires porting") def test_relation(self): scenario_content = """ type time is SI(s: 1) @@ -110,37 +121,31 @@ def test_relation(self): global test6: bool = 5 >= 2 global test7: bool = 5 <= 2 """ - parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content)) - model = self.parser.create_internal_model(parsed_tree, "test.osc", False) - + model = self.parse(scenario_content) + self.assertEqual(model._ModelElement__children[2].get_resolved_value(), True) self.assertEqual(model._ModelElement__children[3].get_resolved_value(), True) - self.assertEqual(model._ModelElement__children[4].get_resolved_value(), True) + self.assertEqual(model._ModelElement__children[4].get_resolved_value(), False) self.assertEqual(model._ModelElement__children[5].get_resolved_value(), False) - self.assertEqual(model._ModelElement__children[6].get_resolved_value(), False) + self.assertEqual(model._ModelElement__children[6].get_resolved_value(), True) self.assertEqual(model._ModelElement__children[7].get_resolved_value(), True) - self.assertEqual(model._ModelElement__children[8].get_resolved_value(), True) - self.assertEqual(model._ModelElement__children[9].get_resolved_value(), False) + self.assertEqual(model._ModelElement__children[8].get_resolved_value(), False) - @unittest.skip(reason="requires porting") def test_negation(self): scenario_content = """ -global test1: bool = not True +global test1: bool = not true global test1: bool = not 5 > 2 +global test1: bool = not false """ - parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content)) - model = self.parser.create_internal_model(parsed_tree, "test.osc", False) - + model = self.parse(scenario_content) self.assertEqual(model._ModelElement__children[0].get_resolved_value(), False) self.assertEqual(model._ModelElement__children[1].get_resolved_value(), False) + self.assertEqual(model._ModelElement__children[2].get_resolved_value(), True) - @unittest.skip(reason="requires porting") def test_compound_expression(self): scenario_content = """ global test1: bool = 2 > 1 and 3 >= 2 global test1: bool = 2 > 1 or 3 < 2 """ - parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content)) - model = self.parser.create_internal_model(parsed_tree, "test.osc", False) - + model = self.parse(scenario_content) self.assertEqual(model._ModelElement__children[0].get_resolved_value(), True) self.assertEqual(model._ModelElement__children[1].get_resolved_value(), True) diff --git a/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_check_data.py b/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_check_data.py index 5c07db2f..d1c1255c 100644 --- a/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_check_data.py +++ b/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_check_data.py @@ -125,7 +125,7 @@ def check_data(self, msg): try: value = check_attr(msg) except AttributeError: - self.feedback_message = "Member name not found {self.member_name}]" + self.feedback_message = f"Member name not found {self.member_name}" self.found = self.comparison_operator(value, self.expected_value) def set_expected_value(self, expected_value_string): diff --git a/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_monitor.py b/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_monitor.py index a07423a3..1279c0d5 100644 --- a/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_monitor.py +++ b/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_monitor.py @@ -19,13 +19,15 @@ from scenario_execution.model.types import VariableReference import rclpy import py_trees +import operator class RosTopicMonitor(BaseAction): - def __init__(self, topic_name: str, topic_type: str, target_variable: object, qos_profile: tuple): + def __init__(self, topic_name: str, topic_type: str, member_name: str, target_variable: object, qos_profile: tuple): super().__init__(resolve_variable_reference_arguments_in_execute=False) self.target_variable = None + self.member_name = member_name self.topic_type = topic_type self.qos_profile = qos_profile self.topic_name = topic_name @@ -43,8 +45,13 @@ def setup(self, **kwargs): self.name, self.__class__.__name__) raise KeyError(error_message) from e + msg_type = get_ros_message_type(self.topic_type) + + # check if member-name exists + self.get_value(msg_type()) + self.subscriber = self.node.create_subscription( - msg_type=get_ros_message_type(self.topic_type), + msg_type=msg_type, topic=self.topic_name, callback=self._callback, qos_profile=get_qos_preset_profile(self.qos_profile), @@ -52,16 +59,27 @@ def setup(self, **kwargs): ) self.feedback_message = f"Monitoring data on {self.topic_name}" # pylint: disable= attribute-defined-outside-init - def execute(self, topic_name, topic_type, target_variable, qos_profile): + def execute(self, topic_name: str, topic_type: str, member_name: str, target_variable: object, qos_profile: tuple): if self.topic_name != topic_name or self.topic_type != topic_type or self.qos_profile != qos_profile: raise ValueError("Updating topic parameters not supported.") if not isinstance(target_variable, VariableReference): raise ValueError(f"'target_variable' is expected to be a variable reference.") self.target_variable = target_variable + self.member_name = member_name def update(self) -> py_trees.common.Status: return py_trees.common.Status.SUCCESS def _callback(self, msg): if self.target_variable is not None: - self.target_variable.set_value(msg) + self.target_variable.set_value(self.get_value(msg)) + + def get_value(self, msg): + if self.member_name != "": + check_attr = operator.attrgetter(self.member_name) + try: + return check_attr(msg) + except AttributeError as e: + raise ValueError(f"invalid member_name '{self.member_name}'") from e + else: + return msg diff --git a/scenario_execution_ros/scenario_execution_ros/lib_osc/ros.osc b/scenario_execution_ros/scenario_execution_ros/lib_osc/ros.osc index 776c2463..548c6a79 100644 --- a/scenario_execution_ros/scenario_execution_ros/lib_osc/ros.osc +++ b/scenario_execution_ros/scenario_execution_ros/lib_osc/ros.osc @@ -128,6 +128,7 @@ action topic_monitor: topic_type: string # class of the message type (e.g. std_msgs.msg.String) target_variable: string # name of the variable (e.g. a 'var' within an actor instance) qos_profile: qos_preset_profiles = qos_preset_profiles!system_default # qos profile for the subscriber + member_name: string = "" # if not empty, only the value of the member is stored within the variable action topic_publish: # publish a message on a topic diff --git a/test/scenario_execution_ros_test/test/test_topic_monitor.py b/test/scenario_execution_ros_test/test/test_topic_monitor.py index 1d62138a..96ecd372 100644 --- a/test/scenario_execution_ros_test/test/test_topic_monitor.py +++ b/test/scenario_execution_ros_test/test/test_topic_monitor.py @@ -90,3 +90,84 @@ def test_success(self): with open(self.tmp_file.name) as f: result = f.read() self.assertEqual(result, "std_msgs.msg.String(data='TEST')") + + def test_member_success(self): + scenario_content = """ +import osc.ros + +action store_action: + file_path: string + value: string + +actor test_actor: + var test: string = "one" + +scenario test_scenario: + foo: test_actor + + do parallel: + serial: + wait elapsed(1s) + topic_publish("/bla", "std_msgs.msg.String", '{\\\"data\\\": \\\"TEST\\\"}') + serial: + topic_monitor("/bla", "std_msgs.msg.String", foo.test, member_name: "data") + wait elapsed(2s) + store_action('""" + self.tmp_file.name + """', foo.test) +""" + self.execute(scenario_content) + self.assertTrue(self.scenario_execution_ros.process_results()) + with open(self.tmp_file.name) as f: + result = f.read() + self.assertEqual(result, "TEST") + + def test_member_unknown(self): + scenario_content = """ +import osc.ros + +action store_action: + file_path: string + value: string + +actor test_actor: + var test: string = "one" + +scenario test_scenario: + foo: test_actor + + do parallel: + serial: + wait elapsed(1s) + topic_publish("/bla", "std_msgs.msg.String", '{\\\"data\\\": \\\"TEST\\\"}') + serial: + topic_monitor("/bla", "std_msgs.msg.String", foo.test, member_name: "UNKNOWN") + wait elapsed(2s) + store_action('""" + self.tmp_file.name + """', foo.test) +""" + self.execute(scenario_content) + self.assertFalse(self.scenario_execution_ros.process_results()) + + def test_member_relation_expr_success(self): + scenario_content = """ +import osc.ros +import osc.helpers + +struct current_state: + var message_count: int = 1 + +scenario test_scenario: + timeout(10s) + current: current_state + do serial: + parallel: + serial: + repeat() + wait elapsed(1s) + topic_publish("/bla", "std_msgs.msg.Int64", '{\\\"data\\\": 2}') + topic_monitor("/bla", "std_msgs.msg.Int64", member_name: "data", target_variable: current.message_count) + + serial: + wait current.message_count == 2 + emit end +""" + self.execute(scenario_content) + self.assertTrue(self.scenario_execution_ros.process_results()) diff --git a/test/scenario_execution_test/scenario_execution_test/actions/set_blackboard_var.py b/test/scenario_execution_test/scenario_execution_test/actions/set_blackboard_var.py new file mode 100644 index 00000000..e8b003a6 --- /dev/null +++ b/test/scenario_execution_test/scenario_execution_test/actions/set_blackboard_var.py @@ -0,0 +1,31 @@ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import py_trees +from py_trees.common import Status +from scenario_execution.actions.base_action import BaseAction + + +class SetBlackboardVariable(BaseAction): + + def execute(self, variable_name: str, variable_value): + self.variable_name = variable_name + self.variable_value = variable_value + self.get_blackboard_client().register_key(self.variable_name, access=py_trees.common.Access.WRITE) + + def update(self) -> py_trees.common.Status: + self.get_blackboard_client().set(self.variable_name, self.variable_value) + return Status.SUCCESS diff --git a/test/scenario_execution_test/setup.py b/test/scenario_execution_test/setup.py index 9acd74d4..24524562 100644 --- a/test/scenario_execution_test/setup.py +++ b/test/scenario_execution_test/setup.py @@ -42,5 +42,6 @@ 'scenario_execution.actions': [ 'test_actor.set_value = scenario_execution_test.actions.actor_set_value:ActorSetValue', 'store_action = scenario_execution_test.actions.store_action:StoreAction', + 'set_blackboard_var = scenario_execution_test.actions.set_blackboard_var:SetBlackboardVariable', ]} ) diff --git a/test/scenario_execution_test/test/test_expression_with_var.py b/test/scenario_execution_test/test/test_expression_with_var.py new file mode 100644 index 00000000..4e358aba --- /dev/null +++ b/test/scenario_execution_test/test/test_expression_with_var.py @@ -0,0 +1,98 @@ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import unittest +import tempfile +import py_trees +from scenario_execution import ScenarioExecution +from scenario_execution.model.osc2_parser import OpenScenario2Parser +from scenario_execution.model.model_to_py_tree import create_py_tree +from scenario_execution.model.model_blackboard import create_py_tree_blackboard +from scenario_execution.utils.logging import Logger + +from antlr4.InputStream import InputStream + + +class TestCheckData(unittest.TestCase): + # pylint: disable=missing-function-docstring,missing-class-docstring + + def setUp(self) -> None: + self.parser = OpenScenario2Parser(Logger('test', False)) + self.scenario_execution = ScenarioExecution(debug=False, log_model=False, live_tree=False, + scenario_file="test.osc", output_dir=None) + self.tree = py_trees.composites.Sequence(name="", memory=True) + self.tmp_file = tempfile.NamedTemporaryFile() + + def execute(self, scenario_content): + parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content)) + model = self.parser.create_internal_model(parsed_tree, self.tree, "test.osc", False) + create_py_tree_blackboard(model, self.tree, self.parser.logger, False) + self.tree = create_py_tree(model, self.tree, self.parser.logger, False) + self.scenario_execution.tree = self.tree + self.scenario_execution.run() + + def test_success(self): + scenario_content = """ +import osc.helpers + +struct current_state: + var val: int = 1 + +action set_blackboard_var: + variable_name: string + variable_value: string + +scenario test_scenario: + timeout(5s) + current: current_state + do parallel: + serial: + wait elapsed(0.2s) + set_blackboard_var("current/val", 2) + wait elapsed(10s) + serial: + wait current.val * 2 + 4 - 4 / 2 == 6 + emit end +""" + self.execute(scenario_content) + self.assertTrue(self.scenario_execution.process_results()) + + def test_success_not(self): + scenario_content = """ +import osc.helpers + +struct current_state: + var val: bool = false + var val2: bool = false + +action set_blackboard_var: + variable_name: string + variable_value: string + +scenario test_scenario: + timeout(5s) + current: current_state + do parallel: + serial: + wait elapsed(0.2s) + set_blackboard_var("current/val", true) + wait elapsed(0.2s) + serial: + wait current.val and not current.val2 + emit end +""" + self.execute(scenario_content) + self.assertTrue(self.scenario_execution.process_results())