Skip to content

Commit

Permalink
Raise an exception explicitly when the timeout is exceeded (#390)
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 authored Dec 13, 2024
1 parent 575106e commit ea54c16
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 19 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ workgraph = "aiida_workgraph.cli.cmd_workgraph:workgraph"
"workgraph.while" = "aiida_workgraph.tasks.builtins:While"
"workgraph.if" = "aiida_workgraph.tasks.builtins:If"
"workgraph.select" = "aiida_workgraph.tasks.builtins:Select"
"workgraph.gather" = "aiida_workgraph.tasks.builtins:Gather"
"workgraph.set_context" = "aiida_workgraph.tasks.builtins:SetContext"
"workgraph.get_context" = "aiida_workgraph.tasks.builtins:GetContext"
"workgraph.time_monitor" = "aiida_workgraph.tasks.monitors:TimeMonitor"
Expand Down
1 change: 0 additions & 1 deletion src/aiida_workgraph/engine/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,6 @@ def setup_ctx_workgraph(self, wgdata: t.Dict[str, t.Any]) -> None:
self.ctx._tasks = wgdata["tasks"]
self.ctx._links = wgdata["links"]
self.ctx._connectivity = wgdata["connectivity"]
self.ctx._ctrl_links = wgdata["ctrl_links"]
self.ctx._workgraph = wgdata
self.ctx._error_handlers = wgdata["error_handlers"]

Expand Down
7 changes: 6 additions & 1 deletion src/aiida_workgraph/utils/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,13 @@ def __init__(
"""
self.process = process
self.restart_process = restart_process
wgdata.setdefault("uuid", "")
wgdata.setdefault("tasks", {})
wgdata.setdefault("links", [])
wgdata.setdefault("ctrl_links", [])
wgdata.setdefault("error_handlers", {})
wgdata.setdefault("context", {})
self.wgdata = wgdata
self.uuid = wgdata["uuid"]
self.name = wgdata["name"]
self.wait_to_link()
self.clean_hanging_links()
Expand Down
26 changes: 21 additions & 5 deletions src/aiida_workgraph/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,14 @@ def submit(
self,
inputs: Optional[Dict[str, Any]] = None,
wait: bool = False,
timeout: int = 60,
timeout: int = 600,
interval: int = 5,
metadata: Optional[Dict[str, Any]] = None,
) -> aiida.orm.ProcessNode:
"""Submit the AiiDA workgraph process and optionally wait for it to finish.
Args:
wait (bool): Wait for the process to finish.
timeout (int): The maximum time in seconds to wait for the process to finish. Defaults to 60.
timeout (int): The maximum time in seconds to wait for the process to finish. Defaults to 600.
restart (bool): Restart the process, and reset the modified tasks, then only re-run the modified tasks.
new (bool): Submit a new process.
"""
Expand Down Expand Up @@ -228,11 +228,17 @@ def get_error_handlers(self) -> Dict[str, Any]:
task["exit_codes"] = exit_codes
return error_handlers

def wait(self, timeout: int = 50, tasks: dict = None, interval: int = 5) -> None:
def wait(self, timeout: int = 600, tasks: dict = None, interval: int = 5) -> None:
"""
Periodically checks and waits for the AiiDA workgraph process to finish until a given timeout.
Args:
timeout (int): The maximum time in seconds to wait for the process to finish. Defaults to 50.
timeout (int): The maximum time in seconds to wait for the process to finish. Defaults to 600.
tasks (dict): Optional; specifies task states to wait for in the format {task_name: [acceptable_states]}.
interval (int): The time interval in seconds between checks. Defaults to 5.
Raises:
TimeoutError: If the process does not finish within the given timeout.
"""
terminating_states = (
"KILLED",
Expand All @@ -245,8 +251,10 @@ def wait(self, timeout: int = 50, tasks: dict = None, interval: int = 5) -> None
start = time.time()
self.update()
finished = False

while not finished:
self.update()

if tasks is not None:
states = []
for name, value in tasks.items():
Expand All @@ -255,9 +263,17 @@ def wait(self, timeout: int = 50, tasks: dict = None, interval: int = 5) -> None
finished = all(states)
else:
finished = self.state in terminating_states

if finished:
print(f"Process {self.process.pk} finished with state: {self.state}")
return

time.sleep(interval)

if time.time() - start > timeout:
break
raise TimeoutError(
f"Timeout reached after {timeout} seconds while waiting for the WorkGraph: {self.process.pk}. "
)

def update(self) -> None:
"""
Expand Down
49 changes: 40 additions & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,12 +242,43 @@ def wg_engine(decorated_add, add_code) -> WorkGraph:


@pytest.fixture
def finished_process_node():
    """Build a ``WorkflowNode`` marked as finished, then seal and store it.

    The node is given the process state ``"finished"`` and exit status ``0``.
    """
    workflow_node = WorkflowNode()
    workflow_node.set_process_state("finished")
    workflow_node.set_exit_status(0)
    # Seal first, then store -- same order as the original fixture.
    workflow_node.seal()
    workflow_node.store()
    return workflow_node
def create_process_node():
    """Provide a factory that builds sealed, stored ``WorkflowNode`` objects.

    The factory is returned (rather than a node) so that each caller can
    choose the process state and exit status of the node it needs.
    """

    def process_node(state="finished", exit_status=0):
        """Create a ``WorkflowNode`` with the given state and exit status.

        Args:
            state: Process state set on the node. Defaults to "finished".
            exit_status: Exit status set on the node. Defaults to 0.

        Returns:
            The node, after it has been sealed and stored.
        """
        workflow_node = WorkflowNode()
        workflow_node.set_process_state(state)
        workflow_node.set_exit_status(exit_status)
        workflow_node.seal()
        workflow_node.store()
        return workflow_node

    return process_node


@pytest.fixture
def create_workgraph_process_node():
    """Provide a factory for nodes that belong to a ``WorkGraphEngine`` process.

    The factory is returned (rather than a node) so that each caller can
    choose the process state and exit status of the node it needs.
    """

    def process_node(state="finished", exit_status=0):
        """Instantiate a ``WorkGraphEngine`` and return its configured node.

        Args:
            state: Process state set on the node. Defaults to "finished".
            exit_status: Exit status set on the node. Defaults to 0.

        Returns:
            The engine's node, after it has been sealed and stored.
        """
        # Imported lazily inside the factory, as in the original fixture.
        from aiida_workgraph.engine.workgraph import WorkGraphEngine

        # Minimal workgraph payload: only a name and an empty state.
        wg_inputs = {"wg": {"name": "test", "state": ""}}
        engine = WorkGraphEngine(inputs=wg_inputs)

        engine_node = engine.node
        engine_node.set_process_state(state)
        engine_node.set_exit_status(exit_status)
        engine_node.seal()
        engine_node.store()
        return engine_node

    return process_node
4 changes: 2 additions & 2 deletions tests/test_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ def test_pause_play_task(wg_calcjob):
assert wg.tasks.add2.outputs.sum.value == 9


def test_pause_play_error_handler(wg_calcjob, finished_process_node):
def test_pause_play_error_handler(wg_calcjob, create_process_node):
wg = wg_calcjob
wg.name = "test_pause_play_error_handler"
wg.process = finished_process_node
wg.process = create_process_node(state="finished", exit_status=0)
try:
wg.pause_tasks(["add1"])
except Exception as e:
Expand Down
11 changes: 11 additions & 0 deletions tests/test_workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,14 @@ def test_workgraph_group_outputs(decorated_add):
wg.run()
assert wg.process.outputs.sum.value == 5
# assert wg.process.outputs.add1.result.value == 5


@pytest.mark.usefixtures("started_daemon_client")
def test_wait_timeout(create_workgraph_process_node):
    """Check that ``wait`` raises ``TimeoutError`` for a node left in "running"."""
    expected_message = "Timeout reached after 1 seconds while waiting for the WorkGraph:"

    wg = WorkGraph()
    running_node = create_workgraph_process_node(state="running")
    wg.process = running_node

    # One-second timeout and interval keep the test short.
    with pytest.raises(TimeoutError, match=expected_message):
        wg.wait(timeout=1, interval=1)

0 comments on commit ea54c16

Please sign in to comment.