Skip to content

Commit

Permalink
Add new API logic for get or create task run info
Browse files Browse the repository at this point in the history
  • Loading branch information
cicdw committed Nov 29, 2020
1 parent 8f6e128 commit 87e5d96
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 52 deletions.
111 changes: 59 additions & 52 deletions src/prefect_server/api/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,67 +268,74 @@ async def get_or_create_task_run(
raise ValueError("Invalid ID")


@register_api("runs.get_or_create_mapped_task_run_children")
async def get_or_create_mapped_task_run_children(
    flow_run_id: str, task_id: str, max_map_index: int
) -> List[str]:
    """
    Creates and/or retrieves mapped child task runs for a given flow run and task.

    Args:
        - flow_run_id (str): the flow run associated with the parent task run
        - task_id (str): the task ID to create and/or retrieve
        - max_map_index (int): the number of mapped children e.g., a value of 2 yields 3 mapped children

    Returns:
        - List[str]: the IDs of the upserted task runs, ordered by map index
    """
    # grab task info
    task = await models.Task.where(id=task_id).first({"cache_key", "tenant_id"})
    # generate task runs to upsert: one per map index in [0, max_map_index]
    task_runs = [
        models.TaskRun(
            tenant_id=task.tenant_id,
            flow_run_id=flow_run_id,
            task_id=task_id,
            map_index=i,
            cache_key=task.cache_key,
        )
        for i in range(max_map_index + 1)
    ]
    # upsert the mapped children
    task_runs = (
        await models.TaskRun().insert_many(
            objects=task_runs,
            on_conflict=dict(
                constraint="task_run_unique_identifier_key",
                update_columns=["cache_key"],
            ),
            selection_set={"returning": {"id", "map_index"}},
        )
    )["returning"]
    task_runs.sort(key=lambda task_run: task_run.map_index)
    # get task runs without states
    stateless_runs = await models.TaskRun.where(
        {
            "flow_run_id": {"_eq": flow_run_id},
            "task_id": {"_eq": task_id},
            # this syntax indicates "where there are no states"
            "_not": {"states": {}},
        }
    ).get({"id", "map_index", "version"})
    # create and insert states for stateless task runs
    task_run_states = [
        models.TaskRunState(
            tenant_id=task.tenant_id,
            task_run_id=task_run.id,
            **models.TaskRunState.fields_from_state(
                Pending(message="Task run created")
            ),
        )
        for task_run in stateless_runs
    ]
    await models.TaskRunState().insert_many(task_run_states)

    # return the task run ids
    return [task_run.id for task_run in task_runs]


@register_api("runs.get_or_create_task_run_info")
async def get_or_create_task_run_info(
    flow_run_id: str, task_id: str, map_index: int = None
) -> dict:
    """
    Given a flow_run_id, task_id, and map_index, return details about the corresponding task run.
    If the task run doesn't exist, it will be created.

    Args:
        - flow_run_id (str): the flow run the task run belongs to
        - task_id (str): the task ID to create and/or retrieve
        - map_index (int, optional): the map index of the task run; `None` is
            treated as -1

    Returns:
        - dict: a dict of details about the task run, with the keys `id`,
            `version`, `state`, and `serialized_state`

    Raises:
        - ValueError: if the task run does not exist and no task matches `task_id`
    """

    if map_index is None:
        map_index = -1

    task_run = await models.TaskRun.where(
        {
            "flow_run_id": {"_eq": flow_run_id},
            "task_id": {"_eq": task_id},
            "map_index": {"_eq": map_index},
        }
    ).first({"id", "version", "state", "serialized_state"})

    if task_run:
        return dict(
            id=task_run.id,
            version=task_run.version,
            state=task_run.state,
            serialized_state=task_run.serialized_state,
        )

    # if it isn't found, add it to the DB
    task = await models.Task.where(id=task_id).first({"cache_key", "tenant_id"})
    if not task:
        raise ValueError("Invalid task ID")

    # serialize the initial state once so that the stored state row and the
    # returned payload are guaranteed to agree
    serialized_pending = Pending(message="Task run created").serialize()

    db_task_run = models.TaskRun(
        tenant_id=task.tenant_id,
        flow_run_id=flow_run_id,
        task_id=task_id,
        map_index=map_index,
        cache_key=task.cache_key,
        version=0,
    )

    db_task_run_state = models.TaskRunState(
        tenant_id=task.tenant_id,
        state="Pending",
        timestamp=pendulum.now(),
        message="Task run created",
        serialized_state=serialized_pending,
    )

    # the initial state is attached to the run so both are inserted together
    db_task_run.states = [db_task_run_state]
    run = await db_task_run.insert(
        on_conflict=dict(
            constraint="task_run_unique_identifier_key",
            update_columns=["cache_key"],
        ),
        selection_set={"returning": {"id"}},
    )

    # NOTE(review): if the insert hits the unique constraint (e.g. the run was
    # created concurrently), `version` and `state` below describe the run this
    # call tried to create, not necessarily the stored row -- confirm whether
    # callers rely on that.
    return dict(
        id=run.returning.id,
        version=db_task_run.version,
        state="Pending",
        # the serialized state belongs to the state object; the TaskRun built
        # above was never given one, so reading it from there returned nothing
        serialized_state=serialized_pending,
    )


@register_api("runs.update_flow_run_heartbeat")
Expand Down
79 changes: 79 additions & 0 deletions tests/api/test_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,85 @@ async def test_idempotency_key_is_scoped_to_version_group_id(
assert flow_run_id_1 != flow_run_id_3


class TestGetOrCreateTaskRunInfo:
    """Tests for `api.runs.get_or_create_task_run_info`."""

    @staticmethod
    def _run_filter(flow_run_id, task_id, map_index):
        """Build a fresh where-clause matching a single task run."""
        return {
            "flow_run_id": {"_eq": flow_run_id},
            "task_id": {"_eq": task_id},
            "map_index": {"_eq": map_index},
        }

    async def test_get_or_create_task_run_info_hits_db(
        self, tenant_id, flow_run_id, task_id
    ):
        """A task run that already exists is returned with its stored details."""
        existing = models.TaskRun(
            id=str(uuid.uuid4()),
            tenant_id=tenant_id,
            flow_run_id=flow_run_id,
            task_id=task_id,
            map_index=12,
            version=17,
            state="Success",
            serialized_state=dict(message="hi"),
        )
        await existing.insert()

        info = await api.runs.get_or_create_task_run_info(
            flow_run_id=flow_run_id, task_id=task_id, map_index=existing.map_index
        )

        # every returned field must mirror the inserted row
        for field in ("id", "version", "state", "serialized_state"):
            assert info[field] == getattr(existing, field)

    async def test_get_or_create_task_run_info_inserts_into_db(
        self, flow_run_id, task_id
    ):
        """A missing task run is created, together with a matching state row."""
        # precondition: nothing exists for this map index yet
        preexisting = await models.TaskRun.where(
            self._run_filter(flow_run_id, task_id, 12)
        ).first({"id"})
        assert not preexisting

        info = await api.runs.get_or_create_task_run_info(
            flow_run_id=flow_run_id, task_id=task_id, map_index=12
        )

        stored_run = await models.TaskRun.where(
            self._run_filter(flow_run_id, task_id, 12)
        ).first({"id"})

        assert info["id"] == stored_run.id

        stored_state = await models.TaskRunState.where(
            {
                "task_run_id": {"_eq": stored_run.id},
            }
        ).first({"task_run_id", "state"})

        assert info["id"] == stored_state.task_run_id
        assert info["state"] == stored_state.state

    async def test_properly_inserts_run_and_state(
        self, tenant_id, flow_run_id, task_id
    ):
        """The created run carries exactly one Pending state linked back to it."""
        info = await api.runs.get_or_create_task_run_info(
            flow_run_id=flow_run_id, task_id=task_id, map_index=12
        )

        stored_run = await models.TaskRun.where(
            self._run_filter(flow_run_id, task_id, 12)
        ).first({"id": True, "states": {"state", "task_run_id"}})

        assert stored_run.id == info["id"]
        assert len(stored_run.states) == 1

        first_state = stored_run.states[0]
        assert first_state.state == "Pending"
        assert first_state.task_run_id == info["id"]


class TestGetTaskRunInfo:
async def test_task_run(self, flow_run_id, task_id):
tr_id = await api.runs.get_or_create_task_run(
Expand Down

0 comments on commit 87e5d96

Please sign in to comment.