diff --git a/docs/development.rst b/docs/development.rst index 2bbf9eeb..8bbc65fb 100644 --- a/docs/development.rst +++ b/docs/development.rst @@ -95,8 +95,9 @@ Implement an Action - Make use of ``kwargs['logger']``, available in ``setup()`` - If you want to draw markers for RViz, use ``kwargs['marker_handler']``, available in ``setup()`` (with ROS backend) - Use arguments from ``__init__()`` for a longer running initialization in ``setup()`` and the arguments from ``execute()`` to set values just before executing the action. - - ``__init__()`` does not need to contain all osc2-defined arguments. This can be convenient as variable argument resolving might not be available during ``__init__()``. - - ``execute()`` contains all osc2-arguments. +- ``__init__()`` and ``setup()`` are called once, ``execute()`` might be called multiple times. +- osc2 arguments can only be consumed once, either in ``__init__()`` or ``execute()``. Exception: If an ``associated_actor`` exists, it's an argument of both methods. +- Arguments, that need late resolving (e.g. refering to variables or extenral methods), specify in ``execute()``. - ``setup()`` provides several arguments that might be useful: - ``input_dir``: Directory containing the scenario file - ``output_dir``: If given on command-line, contains the directory to save output to diff --git a/examples/example_library/example_library/actions/custom_action.py b/examples/example_library/example_library/actions/custom_action.py index 639a73f2..c7a2d6f3 100644 --- a/examples/example_library/example_library/actions/custom_action.py +++ b/examples/example_library/example_library/actions/custom_action.py @@ -21,7 +21,7 @@ class CustomAction(BaseAction): - def __init__(self, data: str): # get action arguments, at the time of initialization + def __init__(self): # get action arguments, at the time of initialization super().__init__() def execute(self, data: str): # get action arguments, at the time of execution (may got updated during scenario execution) diff --git a/libs/scenario_execution_floorplan_dsl/scenario_execution_floorplan_dsl/actions/generate_gazebo_world.py b/libs/scenario_execution_floorplan_dsl/scenario_execution_floorplan_dsl/actions/generate_gazebo_world.py index c8dd7d24..1cee1c84 100644 --- a/libs/scenario_execution_floorplan_dsl/scenario_execution_floorplan_dsl/actions/generate_gazebo_world.py +++ b/libs/scenario_execution_floorplan_dsl/scenario_execution_floorplan_dsl/actions/generate_gazebo_world.py @@ -24,7 +24,7 @@ class GenerateGazeboWorld(BaseAction): - def __init__(self, associated_actor, sdf_template: str, arguments: list): + def __init__(self, associated_actor, sdf_template: str): super().__init__() self.sdf_template = sdf_template self.spawn_utils = SpawnUtils(self.logger) @@ -42,7 +42,7 @@ def setup(self, **kwargs): raise ActionError(f"SDF Template {self.sdf_template} not found.", action=self) self.tmp_file = tempfile.NamedTemporaryFile(suffix=".sdf") # for testing, do not delete temp file: delete=False - def execute(self, associated_actor, sdf_template: str, arguments: list): + def execute(self, associated_actor, arguments: list): self.arguments_string = "" for elem in arguments: self.arguments_string += f'{elem["key"]}:={elem["value"]}' diff --git a/libs/scenario_execution_gazebo/scenario_execution_gazebo/actions/gazebo_actor_exists.py b/libs/scenario_execution_gazebo/scenario_execution_gazebo/actions/gazebo_actor_exists.py index e905eda4..1897ba56 100644 --- a/libs/scenario_execution_gazebo/scenario_execution_gazebo/actions/gazebo_actor_exists.py +++ b/libs/scenario_execution_gazebo/scenario_execution_gazebo/actions/gazebo_actor_exists.py @@ -37,8 +37,9 @@ class GazeboActorExists(RunProcess): """ - def __init__(self, entity_name: str, world_name: str): + def __init__(self): super().__init__() + self.entity_name = None self.current_state = ActorExistsActionState.IDLE def execute(self, entity_name: str, world_name: str): # pylint: disable=arguments-differ diff --git a/libs/scenario_execution_gazebo/scenario_execution_gazebo/actions/gazebo_delete_actor.py b/libs/scenario_execution_gazebo/scenario_execution_gazebo/actions/gazebo_delete_actor.py index c86b1105..7237849f 100644 --- a/libs/scenario_execution_gazebo/scenario_execution_gazebo/actions/gazebo_delete_actor.py +++ b/libs/scenario_execution_gazebo/scenario_execution_gazebo/actions/gazebo_delete_actor.py @@ -37,8 +37,9 @@ class GazeboDeleteActor(RunProcess): """ - def __init__(self, associated_actor, entity_name: str, world_name: str): + def __init__(self, associated_actor): super().__init__() + self.entity_name = None self.current_state = DeleteActionState.IDLE def execute(self, associated_actor, entity_name: str, world_name: str): # pylint: disable=arguments-differ diff --git a/libs/scenario_execution_gazebo/scenario_execution_gazebo/actions/gazebo_relative_spawn_actor.py b/libs/scenario_execution_gazebo/scenario_execution_gazebo/actions/gazebo_relative_spawn_actor.py index 22ab2e8d..c31419a7 100644 --- a/libs/scenario_execution_gazebo/scenario_execution_gazebo/actions/gazebo_relative_spawn_actor.py +++ b/libs/scenario_execution_gazebo/scenario_execution_gazebo/actions/gazebo_relative_spawn_actor.py @@ -30,20 +30,17 @@ class GazeboRelativeSpawnActor(GazeboSpawnActor): """ - def __init__(self, associated_actor, - frame_id: str, parent_frame_id: str, - distance: float, world_name: str, xacro_arguments: list, - model: str): - super().__init__(associated_actor, None, world_name, xacro_arguments, model) + def __init__(self, associated_actor, xacro_arguments: list, model: str): + super().__init__(associated_actor, xacro_arguments, model) self._pose = '{}' + self.model = model + self.world_name = None + self.xacro_arguments = xacro_arguments self.tf_buffer = Buffer() self.tf_listener = None - def execute(self, associated_actor, # pylint: disable=arguments-differ - frame_id: str, parent_frame_id: str, - distance: float, world_name: str, xacro_arguments: list, - model: str): - super().execute(associated_actor, None, world_name, xacro_arguments, model) + def execute(self, associated_actor, frame_id: str, parent_frame_id: str, distance: float, world_name: str): # pylint: disable=arguments-differ + super().execute(associated_actor, None, world_name) self.frame_id = frame_id self.parent_frame_id = parent_frame_id self.distance = distance diff --git a/libs/scenario_execution_gazebo/scenario_execution_gazebo/actions/gazebo_spawn_actor.py b/libs/scenario_execution_gazebo/scenario_execution_gazebo/actions/gazebo_spawn_actor.py index 7ea6b56d..a98f4f9a 100644 --- a/libs/scenario_execution_gazebo/scenario_execution_gazebo/actions/gazebo_spawn_actor.py +++ b/libs/scenario_execution_gazebo/scenario_execution_gazebo/actions/gazebo_spawn_actor.py @@ -45,7 +45,7 @@ class GazeboSpawnActor(RunProcess): """ - def __init__(self, associated_actor, spawn_pose: list, world_name: str, xacro_arguments: list, model: str): + def __init__(self, associated_actor, xacro_arguments: list, model: str): """ init """ @@ -92,9 +92,7 @@ def setup(self, **kwargs): raise ActionError(f'Invalid model specified ({self.entity_model})', action=self) self.current_state = SpawnActionState.MODEL_AVAILABLE - def execute(self, associated_actor, spawn_pose: list, world_name: str, xacro_arguments: list, model: str): # pylint: disable=arguments-differ - if self.entity_model != model or set(self.xacro_arguments) != set(xacro_arguments): - raise ActionError("Runtime change of model not supported.", action=self) + def execute(self, associated_actor, spawn_pose: list, world_name: str): # pylint: disable=arguments-differ self.spawn_pose = spawn_pose self.world_name = world_name diff --git a/libs/scenario_execution_gazebo/scenario_execution_gazebo/actions/gazebo_wait_for_sim.py b/libs/scenario_execution_gazebo/scenario_execution_gazebo/actions/gazebo_wait_for_sim.py index 82e2d0a8..c9d1a584 100644 --- a/libs/scenario_execution_gazebo/scenario_execution_gazebo/actions/gazebo_wait_for_sim.py +++ b/libs/scenario_execution_gazebo/scenario_execution_gazebo/actions/gazebo_wait_for_sim.py @@ -35,7 +35,7 @@ class GazeboWaitForSim(RunProcess): Class to wait for the simulation to become active """ - def __init__(self, world_name: str, timeout: int): + def __init__(self): super().__init__() self.current_state = WaitForSimulationActionState.IDLE diff --git a/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_base_action.py b/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_base_action.py index 6c45509d..622c07fa 100644 --- a/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_base_action.py +++ b/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_base_action.py @@ -29,9 +29,9 @@ class KubernetesBaseActionState(Enum): class KubernetesBaseAction(BaseAction): - def __init__(self, namespace: str, within_cluster: bool): + def __init__(self, within_cluster: bool): super().__init__() - self.namespace = namespace + self.namespace = None self.within_cluster = within_cluster self.client = None self.current_state = KubernetesBaseActionState.IDLE @@ -44,10 +44,8 @@ def setup(self, **kwargs): config.load_kube_config() self.client = client.CoreV1Api() - def execute(self, namespace: str, within_cluster: bool): + def execute(self, namespace: str): self.namespace = namespace - if within_cluster != self.within_cluster: - raise ValueError("parameter 'within_cluster' is not allowed to change since initialization.") def update(self) -> py_trees.common.Status: # pylint: disable=too-many-return-statements if self.current_state == KubernetesBaseActionState.IDLE: diff --git a/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_patch_pod.py b/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_patch_pod.py index 33ee2a84..c96db161 100644 --- a/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_patch_pod.py +++ b/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_patch_pod.py @@ -20,13 +20,14 @@ class KubernetesPatchPod(KubernetesBaseAction): - def __init__(self, namespace: str, target: str, body: str, within_cluster: bool): - super().__init__(namespace, within_cluster) - self.target = target + def __init__(self, within_cluster: bool): + super().__init__(within_cluster) + self.namespace = None + self.target = None self.body = None - def execute(self, namespace: str, target: str, body: str, within_cluster: bool): # pylint: disable=arguments-differ - super().execute(namespace, within_cluster) + def execute(self, namespace: str, target: str, body: str): # pylint: disable=arguments-differ + super().execute(namespace) self.target = target trimmed_data = body.encode('utf-8').decode('unicode_escape') try: diff --git a/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_pod_exec.py b/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_pod_exec.py index 8df6f3ae..a834590e 100644 --- a/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_pod_exec.py +++ b/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_pod_exec.py @@ -33,12 +33,12 @@ class KubernetesPodExecState(Enum): class KubernetesPodExec(BaseAction): - def __init__(self, target: str, command: list, regex: bool, namespace: str, within_cluster: bool): + def __init__(self, within_cluster: bool): super().__init__() - self.target = target - self.namespace = namespace - self.regex = regex - self.command = command + self.target = None + self.namespace = None + self.regex = None + self.command = None self.within_cluster = within_cluster self.client = None self.reponse_queue = queue.Queue() @@ -56,9 +56,7 @@ def setup(self, **kwargs): self.exec_thread = threading.Thread(target=self.pod_exec, daemon=True) - def execute(self, target: str, command: list, regex: bool, namespace: str, within_cluster: bool): - if within_cluster != self.within_cluster: - raise ValueError("parameter 'within_cluster' is not allowed to change since initialization.") + def execute(self, target: str, command: list, regex: bool, namespace: str): self.target = target self.namespace = namespace self.command = command diff --git a/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_wait_for_network_policy_status.py b/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_wait_for_network_policy_status.py index c4d26ab8..0b9c3486 100644 --- a/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_wait_for_network_policy_status.py +++ b/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_wait_for_network_policy_status.py @@ -43,7 +43,7 @@ def setup(self, **kwargs): self.k8s_client = client.api_client.ApiClient() self.network_client = client.NetworkingV1Api(self.k8s_client) - def execute(self, target: str, status: tuple, namespace: str, within_cluster: bool): + def execute(self): self.monitoring_thread = threading.Thread(target=self.watch_network, daemon=True) self.monitoring_thread.start() diff --git a/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_wait_for_pod_status.py b/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_wait_for_pod_status.py index 5551ea73..9a51e51c 100644 --- a/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_wait_for_pod_status.py +++ b/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_wait_for_pod_status.py @@ -32,15 +32,13 @@ class KubernetesWaitForPodStatusState(Enum): class KubernetesWaitForPodStatus(BaseAction): - def __init__(self, target: str, regex: bool, status: tuple, namespace: str, within_cluster: bool): + def __init__(self, within_cluster: bool): super().__init__() - self.target = target - self.namespace = namespace - if not isinstance(status, tuple) or not isinstance(status[0], str): - raise ValueError("Status expected to be enum.") - self.expected_status = status[0] + self.target = None + self.namespace = None + self.expected_status = None self.within_cluster = within_cluster - self.regex = regex + self.regex = None self.client = None self.update_queue = queue.Queue() self.current_state = KubernetesWaitForPodStatusState.IDLE @@ -55,11 +53,12 @@ def setup(self, **kwargs): self.monitoring_thread = threading.Thread(target=self.watch_pods, daemon=True) self.monitoring_thread.start() - def execute(self, target: str, regex: bool, status: tuple, namespace: str, within_cluster: bool): + def execute(self, target: str, regex: bool, status: tuple, namespace: str): self.target = target self.namespace = namespace + if not isinstance(status, tuple) or not isinstance(status[0], str): + raise ValueError("Status expected to be enum.") self.expected_status = status[0] - self.within_cluster = within_cluster self.regex = regex self.current_state = KubernetesWaitForPodStatusState.MONITORING diff --git a/libs/scenario_execution_nav2/scenario_execution_nav2/actions/init_nav2.py b/libs/scenario_execution_nav2/scenario_execution_nav2/actions/init_nav2.py index 1bbeac9d..26674f4d 100644 --- a/libs/scenario_execution_nav2/scenario_execution_nav2/actions/init_nav2.py +++ b/libs/scenario_execution_nav2/scenario_execution_nav2/actions/init_nav2.py @@ -58,12 +58,12 @@ class InitNav2(BaseAction): """ - def __init__(self, associated_actor, initial_pose: list, base_frame_id: str, wait_for_initial_pose: bool, use_initial_pose: bool, namespace_override: str): + def __init__(self, associated_actor, namespace_override: str): super().__init__() - self.initial_pose = initial_pose - self.base_frame_id = base_frame_id - self.wait_for_initial_pose = wait_for_initial_pose - self.use_initial_pose = use_initial_pose + self.initial_pose = None + self.base_frame_id = None + self.wait_for_initial_pose = None + self.use_initial_pose = None self.namespace = associated_actor["namespace"] self.node = None self.future = None @@ -118,14 +118,12 @@ def setup(self, **kwargs): amcl_pose_qos, callback_group=ReentrantCallbackGroup()) - def execute(self, associated_actor, initial_pose: list, base_frame_id: str, wait_for_initial_pose: bool, use_initial_pose: bool, namespace_override: str): + def execute(self, associated_actor, initial_pose: list, base_frame_id: str, wait_for_initial_pose: bool, use_initial_pose: bool): self.initial_pose = initial_pose self.base_frame_id = base_frame_id self.wait_for_initial_pose = wait_for_initial_pose self.use_initial_pose = use_initial_pose self.namespace = associated_actor["namespace"] - if namespace_override: - self.namespace = namespace_override def update(self) -> py_trees.common.Status: """ diff --git a/libs/scenario_execution_nav2/scenario_execution_nav2/actions/nav_through_poses.py b/libs/scenario_execution_nav2/scenario_execution_nav2/actions/nav_through_poses.py index 8312006f..64989c30 100644 --- a/libs/scenario_execution_nav2/scenario_execution_nav2/actions/nav_through_poses.py +++ b/libs/scenario_execution_nav2/scenario_execution_nav2/actions/nav_through_poses.py @@ -25,19 +25,16 @@ class NavThroughPoses(RosActionCall): Class to navigate through poses """ - def __init__(self, associated_actor, goal_poses: list, action_topic: str, namespace_override: str): + def __init__(self, associated_actor, action_topic: str, namespace_override: str): self.namespace = associated_actor["namespace"] if namespace_override: self.namespace = namespace_override self.goal_poses = None - super().__init__(self.namespace + '/' + action_topic, "nav2_msgs.action.NavigateThroughPoses", "") + super().__init__(self.namespace + '/' + action_topic, "nav2_msgs.action.NavigateThroughPoses") - def execute(self, associated_actor, goal_poses: list, action_topic: str, namespace_override: str) -> None: # pylint: disable=arguments-differ,arguments-renamed - self.namespace = associated_actor["namespace"] - if namespace_override: - self.namespace = namespace_override + def execute(self, associated_actor, goal_poses: list) -> None: # pylint: disable=arguments-differ,arguments-renamed self.goal_poses = goal_poses - super().execute(self.namespace + '/' + action_topic, "nav2_msgs.action.NavigateThroughPoses", "") + super().execute("") def get_goal_msg(self): goal_msg = NavigateThroughPoses.Goal() diff --git a/libs/scenario_execution_nav2/scenario_execution_nav2/actions/nav_to_pose.py b/libs/scenario_execution_nav2/scenario_execution_nav2/actions/nav_to_pose.py index 755198f2..c79f9cf8 100644 --- a/libs/scenario_execution_nav2/scenario_execution_nav2/actions/nav_to_pose.py +++ b/libs/scenario_execution_nav2/scenario_execution_nav2/actions/nav_to_pose.py @@ -26,19 +26,16 @@ class NavToPose(RosActionCall): Class to navigate to a pose """ - def __init__(self, associated_actor, goal_pose: list, action_topic: str, namespace_override: str) -> None: + def __init__(self, associated_actor, action_topic: str, namespace_override: str) -> None: self.namespace = associated_actor["namespace"] if namespace_override: self.namespace = namespace_override self.goal_pose = None - super().__init__(self.namespace + '/' + action_topic, "nav2_msgs.action.NavigateToPose", "") + super().__init__(self.namespace + '/' + action_topic, "nav2_msgs.action.NavigateToPose") - def execute(self, associated_actor, goal_pose: list, action_topic: str, namespace_override: str) -> None: # pylint: disable=arguments-differ,arguments-renamed - self.namespace = associated_actor["namespace"] - if namespace_override: - self.namespace = namespace_override + def execute(self, associated_actor, goal_pose: list) -> None: # pylint: disable=arguments-differ,arguments-renamed self.goal_pose = goal_pose - super().execute(self.namespace + '/' + action_topic, "nav2_msgs.action.NavigateToPose", "") + super().execute("") def get_goal_msg(self): goal_msg = NavigateToPose.Goal() diff --git a/libs/scenario_execution_x11/scenario_execution_x11/actions/capture_screen.py b/libs/scenario_execution_x11/scenario_execution_x11/actions/capture_screen.py index c28ab8eb..92bf8994 100644 --- a/libs/scenario_execution_x11/scenario_execution_x11/actions/capture_screen.py +++ b/libs/scenario_execution_x11/scenario_execution_x11/actions/capture_screen.py @@ -30,8 +30,8 @@ class CaptureScreenState(Enum): class CaptureScreen(RunProcess): - def __init__(self, output_filename: str, frame_rate: float): - super().__init__("", wait_for_shutdown=True) + def __init__(self): + super().__init__() self.current_state = None self.output_dir = "." @@ -46,6 +46,7 @@ def setup(self, **kwargs): self.output_dir = kwargs['output_dir'] def execute(self, output_filename: str, frame_rate: float): # pylint: disable=arguments-differ + super().execute(None, wait_for_shutdown=True) self.current_state = CaptureScreenState.IDLE cmd = ["ffmpeg", "-f", "x11grab", diff --git a/scenario_execution/scenario_execution/actions/base_action.py b/scenario_execution/scenario_execution/actions/base_action.py index 3e5d8893..164b68e7 100644 --- a/scenario_execution/scenario_execution/actions/base_action.py +++ b/scenario_execution/scenario_execution/actions/base_action.py @@ -17,11 +17,12 @@ import py_trees from scenario_execution.model.types import ParameterDeclaration, ScenarioDeclaration from scenario_execution.model.error import OSC2Error +import inspect class BaseAction(py_trees.behaviour.Behaviour): - # subclasses might implement __init__() with the same arguments as defined in osc + # subclasses might implement __init__() with osc2 arguments as required # CAUTION: __init__() only gets the initial parameter values. In case variables get modified during scenario execution, # the latest values are available in execute() only. def __init__(self, resolve_variable_reference_arguments_in_execute=True): @@ -29,9 +30,16 @@ def __init__(self, resolve_variable_reference_arguments_in_execute=True): self.logger = None self.blackboard = None self.resolve_variable_reference_arguments_in_execute = resolve_variable_reference_arguments_in_execute + + execute_method = getattr(self, "execute", None) + if execute_method is not None and callable(execute_method): + self.execute_method = execute_method + self.execute_skip_args = inspect.getfullargspec(getattr(self, "__init__", None)).args + else: + self.execute_method = None super().__init__(self.__class__.__name__) - # Subclasses might implement execute() with the same arguments as defined in osc. + # Subclasses might implement execute() with the osc2 arguments that are not used within __init__(). # def execute(self): # Subclasses might override shutdown() in order to cleanup on scenario shutdown. @@ -43,13 +51,12 @@ def shutdown(self): ############# def initialise(self): - execute_method = getattr(self, "execute", None) - if execute_method is not None and callable(execute_method): - + if self.execute_method is not None: if self.resolve_variable_reference_arguments_in_execute: - final_args = self._model.get_resolved_value(self.get_blackboard_client()) + final_args = self._model.get_resolved_value(self.get_blackboard_client(), skip_keys=self.execute_skip_args) else: - final_args = self._model.get_resolved_value_with_variable_references(self.get_blackboard_client()) + final_args = self._model.get_resolved_value_with_variable_references( + self.get_blackboard_client(), skip_keys=self.execute_skip_args) if self._model.actor: final_args["associated_actor"] = self._model.actor.get_resolved_value(self.get_blackboard_client()) diff --git a/scenario_execution/scenario_execution/actions/run_process.py b/scenario_execution/scenario_execution/actions/run_process.py index e0f4d8ef..9e753471 100644 --- a/scenario_execution/scenario_execution/actions/run_process.py +++ b/scenario_execution/scenario_execution/actions/run_process.py @@ -27,12 +27,12 @@ class RunProcess(BaseAction): Class to execute an process. """ - def __init__(self, command=None, wait_for_shutdown=True, shutdown_timeout=10, shutdown_signal=("", signal.SIGTERM)): + def __init__(self): super().__init__() - self.command = command.split(" ") if isinstance(command, str) else command - self.wait_for_shutdown = wait_for_shutdown - self.shutdown_timeout = shutdown_timeout - self.shutdown_signal = shutdown_signal[1] + self.command = None + self.wait_for_shutdown = None + self.shutdown_timeout = None + self.shutdown_signal = None self.executed = False self.process = None self.log_stdout_thread = None diff --git a/scenario_execution/scenario_execution/model/model_to_py_tree.py b/scenario_execution/scenario_execution/model/model_to_py_tree.py index cecda3b7..f41a30e5 100644 --- a/scenario_execution/scenario_execution/model/model_to_py_tree.py +++ b/scenario_execution/scenario_execution/model/model_to_py_tree.py @@ -200,21 +200,14 @@ def compare_method_arguments(self, method, expected_args, behavior_name, node): raise OSC2ParsingError( msg=f'Plugin {behavior_name} {method.__name__} method is missing argument "self".', context=node.get_ctx()) - unknown_args = [] + unexpected_args = [] missing_args = copy.copy(expected_args) for element in method_args: if element not in expected_args: - unknown_args.append(element) + unexpected_args.append(element) else: missing_args.remove(element) - error_string = "" - if missing_args: - error_string += "missing: " + ", ".join(missing_args) - if unknown_args: - if error_string: - error_string += ", " - error_string += "unknown: " + ", ".join(unknown_args) - return method_args, error_string, missing_args + return method_args, unexpected_args, missing_args def create_decorator(self, node: ModifierDeclaration, resolved_values): available_modifiers = ["repeat", "inverter", "timeout", "retry", "failure_is_running", "failure_is_success", @@ -311,29 +304,42 @@ def visit_behavior_invocation(self, node: BehaviorInvocation): # - __init__(self) # - __init__(self, resolve_variable_reference_arguments_in_execute) # - __init__(self, ) - init_args, error_string, args_not_in_init = self.compare_method_arguments( + init_args, unexpected_args, args_not_in_init = self.compare_method_arguments( init_method, expected_args, behavior_name, node) if init_args != ["self"] and \ init_args != ["self", "resolve_variable_reference_arguments_in_execute"] and \ not all(x in expected_args for x in init_args): raise OSC2ParsingError( - msg=f'Plugin {behavior_name}: __init__() either only has "self" argument and osc-defined arguments. {error_string}\n' + msg=f'Plugin {behavior_name}: __init__() either only has "self" argument and osc-defined arguments. Unexpected args: {", ".join(unexpected_args)}\n' f'expected definition with all arguments: {expected_args}', context=node.get_ctx() ) execute_method = getattr(behavior_cls, "execute", None) - if execute_method is not None: - _, error_string, _ = self.compare_method_arguments(execute_method, expected_args, behavior_name, node) - if error_string: + if execute_method is None: + if args_not_in_init: raise OSC2ParsingError( - msg=f'Plugin {behavior_name}: execute() arguments differ from osc-definition: {error_string}.', context=node.get_ctx() - ) + msg=f'Plugin {behavior_name}: execute() required, but not defined. Required arguments (i.e. not defined in __init__()): {", ".join(args_not_in_init)}.', context=node.get_ctx()) + else: + expected_execute_args = copy.deepcopy(args_not_in_init) + expected_execute_args.append("self") + if node.actor: + expected_execute_args.append("associated_actor") + _, unexpected_execute_args, missing_execute_args = self.compare_method_arguments( + execute_method, expected_execute_args, behavior_name, node) + if missing_execute_args: + raise OSC2ParsingError( + msg=f'Plugin {behavior_name}: execute() is missing arguments: {", ".join(missing_execute_args)}. Either specify in __init__() or execute().', context=node.get_ctx()) + if unexpected_execute_args: + error = "" + if any(x in init_args for x in unexpected_execute_args): + error = " osc2 arguments, that are consumed in __init__() are not allowed to be used in execute() again. Please either remove argument(s) from __init__() or execute()." + raise OSC2ParsingError( + msg=f'Plugin {behavior_name}: execute() has unexpected arguments: {", ".join(unexpected_execute_args)}.{error}', context=node.get_ctx()) # initialize plugin instance action_name = node.name if not action_name: action_name = behavior_name - self.logger.debug( - f"Instantiate action '{action_name}', plugin '{behavior_name}'. with:\nExpected execute() arguments: {expected_args}") + self.logger.debug(f"Instantiate action '{action_name}', plugin '{behavior_name}'.") try: if init_args is not None and init_args != ['self'] and init_args != ['self', 'resolve_variable_reference_arguments_in_execute']: final_args = node.get_resolved_value(self.blackboard, skip_keys=args_not_in_init) @@ -342,8 +348,6 @@ def visit_behavior_invocation(self, node: BehaviorInvocation): final_args["associated_actor"] = node.actor.get_resolved_value(self.blackboard) final_args["associated_actor"]["name"] = node.actor.name - for k in args_not_in_init: - del final_args[k] instance = behavior_cls(**final_args) else: instance = behavior_cls() diff --git a/scenario_execution/scenario_execution/model/types.py b/scenario_execution/scenario_execution/model/types.py index 9148ef76..599e2d60 100644 --- a/scenario_execution/scenario_execution/model/types.py +++ b/scenario_execution/scenario_execution/model/types.py @@ -450,7 +450,9 @@ def get_resolved_value(self, blackboard=None, skip_keys=None): elif isinstance(child, MethodDeclaration): if child.name not in skip_keys: params[child.name] = child.get_resolved_value(blackboard) - + for k in skip_keys: + if k in params: + del params[k] return params def get_type(self): @@ -1520,7 +1522,9 @@ def get_base_type(self): def get_type(self): return self.behavior, False - def get_resolved_value_with_variable_references(self, blackboard): + def get_resolved_value_with_variable_references(self, blackboard, skip_keys=None): + if skip_keys is None: + skip_keys = [] params = self.get_resolved_value(blackboard) pos = 0 @@ -1528,12 +1532,16 @@ def get_resolved_value_with_variable_references(self, blackboard): for child in self.get_children(): if isinstance(child, PositionalArgument): if isinstance(child.get_child(0), IdentifierReference): - params[param_keys[pos]] = child.get_child(0).get_blackboard_reference(blackboard) + if param_keys[pos] not in skip_keys: + params[param_keys[pos]] = child.get_child(0).get_blackboard_reference(blackboard) pos += 1 elif isinstance(child, NamedArgument): if isinstance(child.get_child(0), IdentifierReference): - params[child.name] = child.get_child(0).get_blackboard_reference(blackboard) - + if child.name not in skip_keys: + params[child.name] = child.get_child(0).get_blackboard_reference(blackboard) + for k in skip_keys: + if k in params: + del params[k] return params diff --git a/scenario_execution_ros/scenario_execution_ros/actions/assert_lifecycle_state.py b/scenario_execution_ros/scenario_execution_ros/actions/assert_lifecycle_state.py index 6701e418..1f2b339d 100644 --- a/scenario_execution_ros/scenario_execution_ros/actions/assert_lifecycle_state.py +++ b/scenario_execution_ros/scenario_execution_ros/actions/assert_lifecycle_state.py @@ -36,14 +36,14 @@ class AssertLifecycleStateState(Enum): class AssertLifecycleState(BaseAction): - def __init__(self, node_name: str, state_sequence: list, allow_initial_skip: bool, fail_on_unexpected: bool, keep_running: bool): + def __init__(self, node_name: str, state_sequence: list): super().__init__() self.current_state = AssertLifecycleStateState.IDLE self.node_name = node_name self.state_sequence = state_sequence - self.allow_initial_skip = allow_initial_skip - self.fail_on_unexpected = fail_on_unexpected - self.keep_running = keep_running + self.allow_initial_skip = None + self.fail_on_unexpected = None + self.keep_running = None self.node = None self.subscription = None self.initial_states_skipped = False @@ -65,10 +65,7 @@ def setup(self, **kwargs): service_get_state_name = "/" + self.node_name + "/get_state" self.client = self.node.create_client(GetState, service_get_state_name) - def execute(self, node_name: str, state_sequence: list, allow_initial_skip: bool, fail_on_unexpected: bool, keep_running: bool): - if self.node_name != node_name or self.state_sequence != state_sequence: - raise ActionError("Runtime change of arguments 'name', 'state_sequence not supported.", action=self) - + def execute(self, allow_initial_skip: bool, fail_on_unexpected: bool, keep_running: bool): if all(isinstance(state, tuple) and len(state) == 2 for state in self.state_sequence): self.state_sequence = [state[0] for state in self.state_sequence] else: diff --git a/scenario_execution_ros/scenario_execution_ros/actions/assert_tf_moving.py b/scenario_execution_ros/scenario_execution_ros/actions/assert_tf_moving.py index 2b2164bc..9134f010 100644 --- a/scenario_execution_ros/scenario_execution_ros/actions/assert_tf_moving.py +++ b/scenario_execution_ros/scenario_execution_ros/actions/assert_tf_moving.py @@ -27,16 +27,16 @@ class AssertTfMoving(BaseAction): - def __init__(self, frame_id: str, parent_frame_id: str, timeout: int, threshold_translation: float, threshold_rotation: float, wait_for_first_transform: bool, tf_topic_namespace: str, use_sim_time: bool): + def __init__(self, tf_topic_namespace: str): super().__init__() - self.frame_id = frame_id - self.parent_frame_id = parent_frame_id - self.timeout = timeout - self.threshold_translation = threshold_translation - self.threshold_rotation = threshold_rotation - self.wait_for_first_transform = wait_for_first_transform + self.frame_id = None + self.parent_frame_id = None + self.timeout = None + self.threshold_translation = None + self.threshold_rotation = None + self.wait_for_first_transform = None self.tf_topic_namespace = tf_topic_namespace - self.use_sim_time = use_sim_time + self.use_sim_time = None self.start_timeout = False self.timer = 0 self.transforms_received = 0 @@ -53,7 +53,6 @@ def setup(self, **kwargs): self.name, self.__class__.__name__) raise ActionError(error_message, action=self) from e - self.feedback_message = f"Waiting for transform {self.parent_frame_id} --> {self.frame_id}" # pylint: disable= attribute-defined-outside-init self.tf_buffer = tf2_ros.Buffer() tf_prefix = self.tf_topic_namespace if not tf_prefix.startswith('/') and tf_prefix != '': @@ -65,9 +64,7 @@ def setup(self, **kwargs): tf_static_topic=(tf_prefix + "/tf_static"), ) - def execute(self, frame_id: str, parent_frame_id: str, timeout: int, threshold_translation: float, threshold_rotation: float, wait_for_first_transform: bool, tf_topic_namespace: str, use_sim_time: bool): - if self.tf_topic_namespace != tf_topic_namespace: - raise ActionError("Runtime change of argument 'tf_topic_namespace' not supported.", action=self) + def execute(self, frame_id: str, parent_frame_id: str, timeout: int, threshold_translation: float, threshold_rotation: float, wait_for_first_transform: bool, use_sim_time: bool): self.frame_id = frame_id self.parent_frame_id = parent_frame_id self.timeout = timeout @@ -75,6 +72,7 @@ def execute(self, frame_id: str, parent_frame_id: str, timeout: int, threshold_t self.threshold_rotation = threshold_rotation self.wait_for_first_transform = wait_for_first_transform self.use_sim_time = use_sim_time + self.feedback_message = f"Waiting for transform {self.parent_frame_id} --> {self.frame_id}" # pylint: disable= attribute-defined-outside-init def update(self) -> py_trees.common.Status: now = time.time() diff --git a/scenario_execution_ros/scenario_execution_ros/actions/assert_topic_latency.py b/scenario_execution_ros/scenario_execution_ros/actions/assert_topic_latency.py index 7ac038d7..8c34a136 100644 --- a/scenario_execution_ros/scenario_execution_ros/actions/assert_topic_latency.py +++ b/scenario_execution_ros/scenario_execution_ros/actions/assert_topic_latency.py @@ -59,7 +59,7 @@ def setup(self, **kwargs): elif not success and not self.wait_for_first_message: raise ActionError("Topic type must be specified. Please provide a valid topic type.", action=self) - def execute(self, topic_name: str, topic_type: str, latency: float, comparison_operator: bool, rolling_average_count: int, wait_for_first_message: bool): + def execute(self): if self.timer != 0: raise ActionError("Action does not yet support to get retriggered", action=self) self.timer = time.time() diff --git a/scenario_execution_ros/scenario_execution_ros/actions/odometry_distance_traveled.py b/scenario_execution_ros/scenario_execution_ros/actions/odometry_distance_traveled.py index e2fa6b52..5f455265 100644 --- a/scenario_execution_ros/scenario_execution_ros/actions/odometry_distance_traveled.py +++ b/scenario_execution_ros/scenario_execution_ros/actions/odometry_distance_traveled.py @@ -27,10 +27,10 @@ class OdometryDistanceTraveled(BaseAction): Class to wait for a certain covered distance, based on odometry """ - def __init__(self, associated_actor, distance: float, namespace_override: str): + def __init__(self, associated_actor, namespace_override: str): super().__init__() self.namespace = associated_actor["namespace"] - self.distance_expected = distance + self.distance_expected = None self.distance_traveled = 0.0 self.previous_x = 0 self.previous_y = 0 @@ -57,8 +57,8 @@ def setup(self, **kwargs): self.subscriber = self.node.create_subscription( Odometry, namespace + '/odom', self._callback, 1000, callback_group=self.callback_group) - def execute(self, associated_actor, distance: float, namespace_override: str): - if self.namespace != associated_actor["namespace"] or self.namespace_override != namespace_override: + def execute(self, associated_actor, distance: float): + if self.namespace != associated_actor["namespace"] and not self.namespace_override: raise ActionError("Runtime change of namespace not supported.", action=self) self.distance_expected = distance self.distance_traveled = 0.0 diff --git a/scenario_execution_ros/scenario_execution_ros/actions/ros_action_call.py b/scenario_execution_ros/scenario_execution_ros/actions/ros_action_call.py index d8dcd8e6..7e3965a6 100644 --- a/scenario_execution_ros/scenario_execution_ros/actions/ros_action_call.py +++ b/scenario_execution_ros/scenario_execution_ros/actions/ros_action_call.py @@ -43,7 +43,7 @@ class RosActionCall(BaseAction): ros service call behavior """ - def __init__(self, action_name: str, action_type: str, data: str, transient_local: bool = False): + def __init__(self, action_name: str, action_type: str, transient_local: bool = False): super().__init__() self.node = None self.client = None @@ -54,7 +54,6 @@ def __init__(self, action_name: str, action_type: str, data: str, transient_loca self.action_name = action_name self.received_feedback = None self.data = None - self.parse_data(data) self.current_state = ActionCallActionState.IDLE self.cb_group = ReentrantCallbackGroup() self.transient_local = transient_local @@ -90,10 +89,7 @@ def setup(self, **kwargs): self.client = ActionClient(self.node, self.action_type, self.action_name, **client_kwargs) - def execute(self, action_name: str, action_type: str, data: str, transient_local: bool = False): - if self.action_name != action_name or self.action_type_string != action_type or self.transient_local != transient_local: - raise ActionError(f"Updating action_name or action_type_string not supported.", action=self) - + def execute(self, data: str): self.parse_data(data) self.current_state = ActionCallActionState.IDLE diff --git a/scenario_execution_ros/scenario_execution_ros/actions/ros_bag_record.py b/scenario_execution_ros/scenario_execution_ros/actions/ros_bag_record.py index f10659db..e47faec1 100644 --- a/scenario_execution_ros/scenario_execution_ros/actions/ros_bag_record.py +++ b/scenario_execution_ros/scenario_execution_ros/actions/ros_bag_record.py @@ -39,12 +39,13 @@ class RosBagRecord(RunProcess): Class to execute ros bag recording """ - def __init__(self, topics: list, timestamp_suffix: bool, hidden_topics: bool, storage: str, use_sim_time: bool): + def __init__(self): super().__init__() self.bag_dir = None self.current_state = RosBagRecordActionState.WAITING_FOR_TOPICS self.command = None self.output_dir = None + self.topics = None def setup(self, **kwargs): """ diff --git a/scenario_execution_ros/scenario_execution_ros/actions/ros_launch.py b/scenario_execution_ros/scenario_execution_ros/actions/ros_launch.py index bfb72396..4e460df4 100644 --- a/scenario_execution_ros/scenario_execution_ros/actions/ros_launch.py +++ b/scenario_execution_ros/scenario_execution_ros/actions/ros_launch.py @@ -21,9 +21,6 @@ class RosLaunch(RunProcess): - def __init__(self, package_name: str, launch_file: str, arguments: list, wait_for_shutdown: bool, shutdown_timeout: float): - super().__init__(None, wait_for_shutdown, shutdown_timeout, shutdown_signal=("", signal.SIGINT)) - def execute(self, package_name: str, launch_file: str, arguments: list, wait_for_shutdown: bool, shutdown_timeout: float): # pylint: disable=arguments-differ super().execute(None, wait_for_shutdown, shutdown_timeout, shutdown_signal=("", signal.SIGINT)) self.command = ["ros2", "launch", package_name, launch_file] diff --git a/scenario_execution_ros/scenario_execution_ros/actions/ros_log_check.py b/scenario_execution_ros/scenario_execution_ros/actions/ros_log_check.py index dc557f92..49eb9288 100644 --- a/scenario_execution_ros/scenario_execution_ros/actions/ros_log_check.py +++ b/scenario_execution_ros/scenario_execution_ros/actions/ros_log_check.py @@ -27,17 +27,14 @@ class RosLogCheck(BaseAction): Class for scanning the ros log for specific content """ - def __init__(self, values: list, module_name: str): + def __init__(self): super().__init__() - if not isinstance(values, list): - raise TypeError(f'Value needs to be list of strings, got {type(values)}.') - else: - self.values = values + self.values = None self.subscriber = None self.node = None self.found = None - self.module_name = module_name + self.module_name = None def setup(self, **kwargs): """ diff --git a/scenario_execution_ros/scenario_execution_ros/actions/ros_service_call.py b/scenario_execution_ros/scenario_execution_ros/actions/ros_service_call.py index 0f84a27d..d10a6bb7 100644 --- a/scenario_execution_ros/scenario_execution_ros/actions/ros_service_call.py +++ b/scenario_execution_ros/scenario_execution_ros/actions/ros_service_call.py @@ -40,7 +40,7 @@ class RosServiceCall(BaseAction): ros service call behavior """ - def __init__(self, service_name: str, service_type: str, data: str, transient_local: bool = False): + def __init__(self, service_name: str, service_type: str, transient_local: bool = False): super().__init__() self.node = None self.client = None @@ -48,12 +48,7 @@ def __init__(self, service_name: str, service_type: str, data: str, transient_lo self.service_type_str = service_type self.service_type = None self.service_name = service_name - self.data_str = data - try: - trimmed_data = self.data_str.encode('utf-8').decode('unicode_escape') - self.data = literal_eval(trimmed_data) - except Exception as e: # pylint: disable=broad-except - raise ValueError(f"Error while parsing sevice call data:") from e + self.data = None self.current_state = ServiceCallActionState.IDLE self.cb_group = ReentrantCallbackGroup() self.transient_local = transient_local @@ -93,9 +88,12 @@ def setup(self, **kwargs): **client_kwargs ) - def execute(self, service_name: str, service_type: str, data: str, transient_local: bool): - if self.service_name != service_name or self.service_type_str != service_type or self.data_str != data or self.transient_local != transient_local: - raise ActionError("service_name, service_type and data arguments are not changeable during runtime.", action=self) + def execute(self, data: str): + try: + trimmed_data = data.encode('utf-8').decode('unicode_escape') + self.data = literal_eval(trimmed_data) + except Exception as e: # pylint: disable=broad-except + raise ActionError(f"Error while parsing sevice call data: {e}", action=self) from e self.current_state = ServiceCallActionState.IDLE def update(self) -> py_trees.common.Status: diff --git a/scenario_execution_ros/scenario_execution_ros/actions/ros_set_node_parameter.py b/scenario_execution_ros/scenario_execution_ros/actions/ros_set_node_parameter.py index dfda1c66..5c87e947 100644 --- a/scenario_execution_ros/scenario_execution_ros/actions/ros_set_node_parameter.py +++ b/scenario_execution_ros/scenario_execution_ros/actions/ros_set_node_parameter.py @@ -18,7 +18,6 @@ from .ros_service_call import RosServiceCall from rcl_interfaces.msg import ParameterType -from scenario_execution.actions.base_action import ActionError class RosSetNodeParameter(RosServiceCall): @@ -26,14 +25,16 @@ class RosSetNodeParameter(RosServiceCall): class for setting a node parameter """ - def __init__(self, node_name: str, parameter_name: str, parameter_value: str): + def __init__(self, node_name: str, parameter_name: str): self.node_name = node_name self.parameter_name = parameter_name - self.parameter_value = parameter_value + self.parameter_value = None service_name = node_name + '/set_parameters' if not service_name.startswith('/'): service_name = '/' + service_name + super().__init__(service_name=service_name, service_type='rcl_interfaces.srv.SetParameters') + def execute(self, parameter_value: str): # pylint: disable=arguments-differ,arguments-renamed parameter_type = ParameterType.PARAMETER_STRING parameter_assign_name = 'string_value' if parameter_value.lower() == 'true' or parameter_value.lower() == 'false': @@ -64,14 +65,8 @@ def __init__(self, node_name: str, parameter_name: str, parameter_value: str): else: parameter_type = ParameterType.PARAMETER_STRING_ARRAY parameter_assign_name = 'string_array_value' - - super().__init__(service_name=service_name, - service_type='rcl_interfaces.srv.SetParameters', - data='{ "parameters": [{ "name": "' + parameter_name + '", "value": { "type": ' + str(parameter_type) + ', "' + parameter_assign_name + '": ' + parameter_value + '}}]}') - - def execute(self, node_name: str, parameter_name: str, parameter_value: str): # pylint: disable=arguments-differ,arguments-renamed - if self.node_name != node_name or self.parameter_name != parameter_name or self.parameter_value != parameter_value: - raise ActionError("node_name, parameter_name and parameter_value are not changeable during runtime.", action=self) + super().execute(data='{ "parameters": [{ "name": "' + self.parameter_name + '", "value": { "type": ' + + str(parameter_type) + ', "' + parameter_assign_name + '": ' + parameter_value + '}}]}') @staticmethod def is_float(element: any) -> bool: diff --git a/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_check_data.py b/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_check_data.py index 74fa9e84..517c4fc7 100644 --- a/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_check_data.py +++ b/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_check_data.py @@ -32,23 +32,18 @@ class RosTopicCheckData(BaseAction): def __init__(self, topic_name: str, topic_type: str, - qos_profile: tuple, member_name: str, - expected_value: str, - comparison_operator: int, - fail_if_no_data: bool, - fail_if_bad_comparison: bool, - wait_for_first_message: bool): + qos_profile: tuple): super().__init__() self.topic_name = topic_name self.topic_type = topic_type - self.qos_profile = qos_profile self.member_name = member_name - self.set_expected_value(expected_value) - self.comparison_operator = get_comparison_operator(comparison_operator) - self.fail_if_no_data = fail_if_no_data - self.fail_if_bad_comparison = fail_if_bad_comparison - self.wait_for_first_message = wait_for_first_message + self.qos_profile = qos_profile + self.expected_value = None + self.comparison_operator = None + self.fail_if_no_data = None + self.fail_if_bad_comparison = None + self.wait_for_first_message = None self.last_msg = None self.found = None @@ -63,6 +58,14 @@ def setup(self, **kwargs): self.name, self.__class__.__name__) raise ActionError(error_message, action=self) from e + #check if msg type exists and has member + try: + msg = get_ros_message_type(self.topic_type)() + if self.member_name: + getattr(msg, self.member_name) + except (ValueError, AttributeError) as e: + raise ActionError(f"Member '{self.member_name}' not found in topic type '{self.topic_type}'.", action=self) from e + self.subscriber = self.node.create_subscription( msg_type=get_ros_message_type(self.topic_type), topic=self.topic_name, @@ -73,18 +76,11 @@ def setup(self, **kwargs): self.feedback_message = f"Waiting for data on {self.topic_name}" # pylint: disable= attribute-defined-outside-init def execute(self, - topic_name: str, - topic_type: str, - qos_profile: tuple, - member_name: str, expected_value: str, comparison_operator: int, fail_if_no_data: bool, fail_if_bad_comparison: bool, wait_for_first_message: bool): - if self.topic_name != topic_name or self.topic_type != topic_type or self.qos_profile != qos_profile: - raise ActionError("Updating topic parameters not supported.", action=self) - self.member_name = member_name self.set_expected_value(expected_value) self.comparison_operator = get_comparison_operator(comparison_operator) self.fail_if_no_data = fail_if_no_data @@ -115,7 +111,7 @@ def _callback(self, msg): self.feedback_message = f"Received message does not contain expected value." def check_data(self, msg): - if msg is None: + if msg is None or self.member_name is None or self.expected_value is None: return if self.member_name == "": @@ -145,5 +141,5 @@ def set_expected_value(self, expected_value_string): self.expected_value = parsed_value else: set_message_fields(self.expected_value, parsed_value) - except TypeError as e: + except (TypeError, AttributeError) as e: raise ActionError(f"Could not parse '{expected_value_string}'. {error_string}", action=self) from e diff --git a/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_monitor.py b/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_monitor.py index 714baa29..74d73958 100644 --- a/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_monitor.py +++ b/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_monitor.py @@ -24,10 +24,10 @@ class RosTopicMonitor(BaseAction): - def __init__(self, topic_name: str, topic_type: str, member_name: str, target_variable: object, qos_profile: tuple): + def __init__(self, topic_name: str, topic_type: str, qos_profile: tuple): super().__init__(resolve_variable_reference_arguments_in_execute=False) self.target_variable = None - self.member_name = member_name + self.member_name = None self.topic_type = topic_type self.qos_profile = qos_profile self.topic_name = topic_name @@ -59,9 +59,7 @@ def setup(self, **kwargs): ) self.feedback_message = f"Monitoring data on {self.topic_name}" # pylint: disable= attribute-defined-outside-init - def execute(self, topic_name: str, topic_type: str, member_name: str, target_variable: object, qos_profile: tuple): - if self.topic_name != topic_name or self.topic_type != topic_type or self.qos_profile != qos_profile: - raise ActionError("Updating topic parameters not supported.", action=self) + def execute(self, member_name: str, target_variable: object): if not isinstance(target_variable, VariableReference): raise ActionError(f"'target_variable' is expected to be a variable reference.", action=self) self.target_variable = target_variable @@ -75,7 +73,7 @@ def _callback(self, msg): self.target_variable.set_value(self.get_value(msg)) def get_value(self, msg): - if self.member_name != "": + if self.member_name is not None and self.member_name != "": check_attr = operator.attrgetter(self.member_name) try: return check_attr(msg) diff --git a/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_publish.py b/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_publish.py index def6cad2..05fb2b49 100644 --- a/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_publish.py +++ b/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_publish.py @@ -61,10 +61,7 @@ def setup(self, **kwargs): except ValueError as e: raise ActionError(f"{e}", action=self) from e - def execute(self, topic_type: str, topic_name: str, value: str, qos_profile: tuple): - if self.topic_name != topic_name or self.topic_type != topic_type or self.qos_profile != qos_profile: - raise ActionError("Updating topic parameters not supported.", action=self) - + def execute(self, value: str): if isinstance(value, str): parsed_value = literal_eval("".join(value.split('\\'))) if not isinstance(parsed_value, dict): diff --git a/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_wait_for_data.py b/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_wait_for_data.py index e0bec33f..9686cb1c 100644 --- a/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_wait_for_data.py +++ b/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_wait_for_data.py @@ -59,9 +59,7 @@ def setup(self, **kwargs): ) self.feedback_message = f"Waiting for data on {self.topic_name}" # pylint: disable= attribute-defined-outside-init - def execute(self, topic_name, topic_type, qos_profile): - if self.topic_name != topic_name or self.topic_type != topic_type or self.qos_profile != qos_profile: - raise ActionError("Updating topic parameters not supported.", action=self) + def execute(self): self.found = False def update(self) -> py_trees.common.Status: diff --git a/scenario_execution_ros/test/test_check_data.py b/scenario_execution_ros/test/test_check_data.py index c3e33066..ad6fde24 100644 --- a/scenario_execution_ros/test/test_check_data.py +++ b/scenario_execution_ros/test/test_check_data.py @@ -22,6 +22,7 @@ from scenario_execution.model.osc2_parser import OpenScenario2Parser from scenario_execution.model.model_to_py_tree import create_py_tree from scenario_execution.utils.logging import Logger +from scenario_execution.actions.base_action import ActionError from antlr4.InputStream import InputStream @@ -91,6 +92,31 @@ def test_success_member(self): self.execute(scenario_content) self.assertTrue(self.scenario_execution_ros.process_results()) + def test_fail_unknown_type(self): + scenario_content = """ +import osc.ros + +scenario test: + do parallel: + test: serial: + wait elapsed(1s) + topic_publish( + topic_name: '/bla', + topic_type: 'std_msgs.msg.Bool', + value: '{\\\"data\\\": True}') + receive: serial: + check_data( + topic_name: '/bla', + topic_type: 'std_msgs.msg.UNKNOWN', + expected_value: 'True') + emit end + time_out: serial: + wait elapsed(10s) + emit fail +""" + self.execute(scenario_content) + self.assertFalse(self.scenario_execution_ros.process_results()) + def test_fail_unknown_member(self): scenario_content = """ import osc.ros @@ -114,9 +140,8 @@ def test_fail_unknown_member(self): wait elapsed(10s) emit fail """ - parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content)) - model = self.parser.create_internal_model(parsed_tree, self.tree, "test.osc", False) - self.assertRaises(ValueError, create_py_tree, model, self.tree, self.parser.logger, False) + self.execute(scenario_content) + self.assertFalse(self.scenario_execution_ros.process_results()) def test_fail_member_value_differ(self): scenario_content = """