Skip to content

Commit

Permalink
Initialize osc args only once, either in init or execute
Browse files Browse the repository at this point in the history
  • Loading branch information
fred-labs committed Aug 28, 2024
1 parent a426b68 commit 7a090bc
Show file tree
Hide file tree
Showing 36 changed files with 210 additions and 211 deletions.
5 changes: 3 additions & 2 deletions docs/development.rst
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,9 @@ Implement an Action
- Make use of ``kwargs['logger']``, available in ``setup()``
- If you want to draw markers for RViz, use ``kwargs['marker_handler']``, available in ``setup()`` (with ROS backend)
- Use arguments from ``__init__()`` for a longer running initialization in ``setup()`` and the arguments from ``execute()`` to set values just before executing the action.
- ``__init__()`` does not need to contain all osc2-defined arguments. This can be convenient as variable argument resolving might not be available during ``__init__()``.
- ``execute()`` contains all osc2-arguments.
- ``__init__()`` and ``setup()`` are called once, ``execute()`` might be called multiple times.
- osc2 arguments can only be consumed once, either in ``__init__()`` or ``execute()``. Exception: If an ``associated_actor`` exists, it's an argument of both methods.
- Arguments, that need late resolving (e.g. refering to variables or extenral methods), specify in ``execute()``.
- ``setup()`` provides several arguments that might be useful:
- ``input_dir``: Directory containing the scenario file
- ``output_dir``: If given on command-line, contains the directory to save output to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

class CustomAction(BaseAction):

def __init__(self, data: str): # get action arguments, at the time of initialization
def __init__(self): # get action arguments, at the time of initialization
super().__init__()

def execute(self, data: str): # get action arguments, at the time of execution (may got updated during scenario execution)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

class GenerateGazeboWorld(BaseAction):

def __init__(self, associated_actor, sdf_template: str, arguments: list):
def __init__(self, associated_actor, sdf_template: str):
super().__init__()
self.sdf_template = sdf_template
self.spawn_utils = SpawnUtils(self.logger)
Expand All @@ -42,7 +42,7 @@ def setup(self, **kwargs):
raise ActionError(f"SDF Template {self.sdf_template} not found.", action=self)
self.tmp_file = tempfile.NamedTemporaryFile(suffix=".sdf") # for testing, do not delete temp file: delete=False

def execute(self, associated_actor, sdf_template: str, arguments: list):
def execute(self, associated_actor, arguments: list):
self.arguments_string = ""
for elem in arguments:
self.arguments_string += f'{elem["key"]}:={elem["value"]}'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ class GazeboActorExists(RunProcess):
"""

def __init__(self, entity_name: str, world_name: str):
def __init__(self):
super().__init__()
self.entity_name = None
self.current_state = ActorExistsActionState.IDLE

def execute(self, entity_name: str, world_name: str): # pylint: disable=arguments-differ
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ class GazeboDeleteActor(RunProcess):
"""

def __init__(self, associated_actor, entity_name: str, world_name: str):
def __init__(self, associated_actor):
super().__init__()
self.entity_name = None
self.current_state = DeleteActionState.IDLE

def execute(self, associated_actor, entity_name: str, world_name: str): # pylint: disable=arguments-differ
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,17 @@ class GazeboRelativeSpawnActor(GazeboSpawnActor):
"""

def __init__(self, associated_actor,
frame_id: str, parent_frame_id: str,
distance: float, world_name: str, xacro_arguments: list,
model: str):
super().__init__(associated_actor, None, world_name, xacro_arguments, model)
def __init__(self, associated_actor, xacro_arguments: list, model: str):
super().__init__(associated_actor, xacro_arguments, model)
self._pose = '{}'
self.model = model
self.world_name = None
self.xacro_arguments = xacro_arguments
self.tf_buffer = Buffer()
self.tf_listener = None

def execute(self, associated_actor, # pylint: disable=arguments-differ
frame_id: str, parent_frame_id: str,
distance: float, world_name: str, xacro_arguments: list,
model: str):
super().execute(associated_actor, None, world_name, xacro_arguments, model)
def execute(self, associated_actor, frame_id: str, parent_frame_id: str, distance: float, world_name: str): # pylint: disable=arguments-differ
super().execute(associated_actor, None, world_name)
self.frame_id = frame_id
self.parent_frame_id = parent_frame_id
self.distance = distance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class GazeboSpawnActor(RunProcess):
"""

def __init__(self, associated_actor, spawn_pose: list, world_name: str, xacro_arguments: list, model: str):
def __init__(self, associated_actor, xacro_arguments: list, model: str):
"""
init
"""
Expand Down Expand Up @@ -92,9 +92,7 @@ def setup(self, **kwargs):
raise ActionError(f'Invalid model specified ({self.entity_model})', action=self)
self.current_state = SpawnActionState.MODEL_AVAILABLE

def execute(self, associated_actor, spawn_pose: list, world_name: str, xacro_arguments: list, model: str): # pylint: disable=arguments-differ
if self.entity_model != model or set(self.xacro_arguments) != set(xacro_arguments):
raise ActionError("Runtime change of model not supported.", action=self)
def execute(self, associated_actor, spawn_pose: list, world_name: str): # pylint: disable=arguments-differ
self.spawn_pose = spawn_pose
self.world_name = world_name

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class GazeboWaitForSim(RunProcess):
Class to wait for the simulation to become active
"""

def __init__(self, world_name: str, timeout: int):
def __init__(self):
super().__init__()
self.current_state = WaitForSimulationActionState.IDLE

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ class KubernetesBaseActionState(Enum):

class KubernetesBaseAction(BaseAction):

def __init__(self, namespace: str, within_cluster: bool):
def __init__(self, within_cluster: bool):
super().__init__()
self.namespace = namespace
self.namespace = None
self.within_cluster = within_cluster
self.client = None
self.current_state = KubernetesBaseActionState.IDLE
Expand All @@ -44,10 +44,8 @@ def setup(self, **kwargs):
config.load_kube_config()
self.client = client.CoreV1Api()

def execute(self, namespace: str, within_cluster: bool):
def execute(self, namespace: str):
self.namespace = namespace
if within_cluster != self.within_cluster:
raise ValueError("parameter 'within_cluster' is not allowed to change since initialization.")

def update(self) -> py_trees.common.Status: # pylint: disable=too-many-return-statements
if self.current_state == KubernetesBaseActionState.IDLE:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@

class KubernetesPatchPod(KubernetesBaseAction):

def __init__(self, namespace: str, target: str, body: str, within_cluster: bool):
super().__init__(namespace, within_cluster)
self.target = target
def __init__(self, within_cluster: bool):
super().__init__(within_cluster)
self.namespace = None
self.target = None
self.body = None

def execute(self, namespace: str, target: str, body: str, within_cluster: bool): # pylint: disable=arguments-differ
super().execute(namespace, within_cluster)
def execute(self, namespace: str, target: str, body: str): # pylint: disable=arguments-differ
super().execute(namespace)
self.target = target
trimmed_data = body.encode('utf-8').decode('unicode_escape')
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ class KubernetesPodExecState(Enum):

class KubernetesPodExec(BaseAction):

def __init__(self, target: str, command: list, regex: bool, namespace: str, within_cluster: bool):
def __init__(self, within_cluster: bool):
super().__init__()
self.target = target
self.namespace = namespace
self.regex = regex
self.command = command
self.target = None
self.namespace = None
self.regex = None
self.command = None
self.within_cluster = within_cluster
self.client = None
self.reponse_queue = queue.Queue()
Expand All @@ -56,9 +56,7 @@ def setup(self, **kwargs):

self.exec_thread = threading.Thread(target=self.pod_exec, daemon=True)

def execute(self, target: str, command: list, regex: bool, namespace: str, within_cluster: bool):
if within_cluster != self.within_cluster:
raise ValueError("parameter 'within_cluster' is not allowed to change since initialization.")
def execute(self, target: str, command: list, regex: bool, namespace: str):
self.target = target
self.namespace = namespace
self.command = command
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def setup(self, **kwargs):
self.k8s_client = client.api_client.ApiClient()
self.network_client = client.NetworkingV1Api(self.k8s_client)

def execute(self, target: str, status: tuple, namespace: str, within_cluster: bool):
def execute(self):
self.monitoring_thread = threading.Thread(target=self.watch_network, daemon=True)
self.monitoring_thread.start()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,13 @@ class KubernetesWaitForPodStatusState(Enum):

class KubernetesWaitForPodStatus(BaseAction):

def __init__(self, target: str, regex: bool, status: tuple, namespace: str, within_cluster: bool):
def __init__(self, within_cluster: bool):
super().__init__()
self.target = target
self.namespace = namespace
if not isinstance(status, tuple) or not isinstance(status[0], str):
raise ValueError("Status expected to be enum.")
self.expected_status = status[0]
self.target = None
self.namespace = None
self.expected_status = None
self.within_cluster = within_cluster
self.regex = regex
self.regex = None
self.client = None
self.update_queue = queue.Queue()
self.current_state = KubernetesWaitForPodStatusState.IDLE
Expand All @@ -55,11 +53,12 @@ def setup(self, **kwargs):
self.monitoring_thread = threading.Thread(target=self.watch_pods, daemon=True)
self.monitoring_thread.start()

def execute(self, target: str, regex: bool, status: tuple, namespace: str, within_cluster: bool):
def execute(self, target: str, regex: bool, status: tuple, namespace: str):
self.target = target
self.namespace = namespace
if not isinstance(status, tuple) or not isinstance(status[0], str):
raise ValueError("Status expected to be enum.")
self.expected_status = status[0]
self.within_cluster = within_cluster
self.regex = regex
self.current_state = KubernetesWaitForPodStatusState.MONITORING

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ class InitNav2(BaseAction):
"""

def __init__(self, associated_actor, initial_pose: list, base_frame_id: str, wait_for_initial_pose: bool, use_initial_pose: bool, namespace_override: str):
def __init__(self, associated_actor, namespace_override: str):
super().__init__()
self.initial_pose = initial_pose
self.base_frame_id = base_frame_id
self.wait_for_initial_pose = wait_for_initial_pose
self.use_initial_pose = use_initial_pose
self.initial_pose = None
self.base_frame_id = None
self.wait_for_initial_pose = None
self.use_initial_pose = None
self.namespace = associated_actor["namespace"]
self.node = None
self.future = None
Expand Down Expand Up @@ -118,14 +118,12 @@ def setup(self, **kwargs):
amcl_pose_qos,
callback_group=ReentrantCallbackGroup())

def execute(self, associated_actor, initial_pose: list, base_frame_id: str, wait_for_initial_pose: bool, use_initial_pose: bool, namespace_override: str):
def execute(self, associated_actor, initial_pose: list, base_frame_id: str, wait_for_initial_pose: bool, use_initial_pose: bool):
self.initial_pose = initial_pose
self.base_frame_id = base_frame_id
self.wait_for_initial_pose = wait_for_initial_pose
self.use_initial_pose = use_initial_pose
self.namespace = associated_actor["namespace"]
if namespace_override:
self.namespace = namespace_override

def update(self) -> py_trees.common.Status:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,16 @@ class NavThroughPoses(RosActionCall):
Class to navigate through poses
"""

def __init__(self, associated_actor, goal_poses: list, action_topic: str, namespace_override: str):
def __init__(self, associated_actor, action_topic: str, namespace_override: str):
self.namespace = associated_actor["namespace"]
if namespace_override:
self.namespace = namespace_override
self.goal_poses = None
super().__init__(self.namespace + '/' + action_topic, "nav2_msgs.action.NavigateThroughPoses", "")
super().__init__(self.namespace + '/' + action_topic, "nav2_msgs.action.NavigateThroughPoses")

def execute(self, associated_actor, goal_poses: list, action_topic: str, namespace_override: str) -> None: # pylint: disable=arguments-differ,arguments-renamed
self.namespace = associated_actor["namespace"]
if namespace_override:
self.namespace = namespace_override
def execute(self, associated_actor, goal_poses: list) -> None: # pylint: disable=arguments-differ,arguments-renamed
self.goal_poses = goal_poses
super().execute(self.namespace + '/' + action_topic, "nav2_msgs.action.NavigateThroughPoses", "")
super().execute("")

def get_goal_msg(self):
goal_msg = NavigateThroughPoses.Goal()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,16 @@ class NavToPose(RosActionCall):
Class to navigate to a pose
"""

def __init__(self, associated_actor, goal_pose: list, action_topic: str, namespace_override: str) -> None:
def __init__(self, associated_actor, action_topic: str, namespace_override: str) -> None:
self.namespace = associated_actor["namespace"]
if namespace_override:
self.namespace = namespace_override
self.goal_pose = None
super().__init__(self.namespace + '/' + action_topic, "nav2_msgs.action.NavigateToPose", "")
super().__init__(self.namespace + '/' + action_topic, "nav2_msgs.action.NavigateToPose")

def execute(self, associated_actor, goal_pose: list, action_topic: str, namespace_override: str) -> None: # pylint: disable=arguments-differ,arguments-renamed
self.namespace = associated_actor["namespace"]
if namespace_override:
self.namespace = namespace_override
def execute(self, associated_actor, goal_pose: list) -> None: # pylint: disable=arguments-differ,arguments-renamed
self.goal_pose = goal_pose
super().execute(self.namespace + '/' + action_topic, "nav2_msgs.action.NavigateToPose", "")
super().execute("")

def get_goal_msg(self):
goal_msg = NavigateToPose.Goal()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class CaptureScreenState(Enum):

class CaptureScreen(RunProcess):

def __init__(self, output_filename: str, frame_rate: float):
super().__init__("", wait_for_shutdown=True)
def __init__(self):
super().__init__()
self.current_state = None
self.output_dir = "."

Expand All @@ -46,6 +46,7 @@ def setup(self, **kwargs):
self.output_dir = kwargs['output_dir']

def execute(self, output_filename: str, frame_rate: float): # pylint: disable=arguments-differ
super().execute(None, wait_for_shutdown=True)
self.current_state = CaptureScreenState.IDLE
cmd = ["ffmpeg",
"-f", "x11grab",
Expand Down
21 changes: 14 additions & 7 deletions scenario_execution/scenario_execution/actions/base_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,29 @@
import py_trees
from scenario_execution.model.types import ParameterDeclaration, ScenarioDeclaration
from scenario_execution.model.error import OSC2Error
import inspect


class BaseAction(py_trees.behaviour.Behaviour):

# subclasses might implement __init__() with the same arguments as defined in osc
# subclasses might implement __init__() with osc2 arguments as required
# CAUTION: __init__() only gets the initial parameter values. In case variables get modified during scenario execution,
# the latest values are available in execute() only.
def __init__(self, resolve_variable_reference_arguments_in_execute=True):
self._model = None
self.logger = None
self.blackboard = None
self.resolve_variable_reference_arguments_in_execute = resolve_variable_reference_arguments_in_execute

execute_method = getattr(self, "execute", None)
if execute_method is not None and callable(execute_method):
self.execute_method = execute_method
self.execute_skip_args = inspect.getfullargspec(getattr(self, "__init__", None)).args
else:
self.execute_method = None
super().__init__(self.__class__.__name__)

# Subclasses might implement execute() with the same arguments as defined in osc.
# Subclasses might implement execute() with the osc2 arguments that are not used within __init__().
# def execute(self):

# Subclasses might override shutdown() in order to cleanup on scenario shutdown.
Expand All @@ -43,13 +51,12 @@ def shutdown(self):
#############

def initialise(self):
execute_method = getattr(self, "execute", None)
if execute_method is not None and callable(execute_method):

if self.execute_method is not None:
if self.resolve_variable_reference_arguments_in_execute:
final_args = self._model.get_resolved_value(self.get_blackboard_client())
final_args = self._model.get_resolved_value(self.get_blackboard_client(), skip_keys=self.execute_skip_args)
else:
final_args = self._model.get_resolved_value_with_variable_references(self.get_blackboard_client())
final_args = self._model.get_resolved_value_with_variable_references(
self.get_blackboard_client(), skip_keys=self.execute_skip_args)

if self._model.actor:
final_args["associated_actor"] = self._model.actor.get_resolved_value(self.get_blackboard_client())
Expand Down
Loading

0 comments on commit 7a090bc

Please sign in to comment.