Skip to content

Commit

Permalink
Remove unused route and add new get or create task run info route for…
Browse files Browse the repository at this point in the history
… Cloud API parity
  • Loading branch information
cicdw committed Nov 29, 2020
1 parent 87e5d96 commit 19733ac
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 239 deletions.
29 changes: 17 additions & 12 deletions src/prefect_server/graphql/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,23 @@ async def resolve_get_task_run_info(
}


@mutation.field("get_or_create_task_run_info")
async def resolve_get_or_create_task_run_info(
obj: Any, info: GraphQLResolveInfo, input: dict
) -> dict:
info = await api.runs.get_or_create_task_run_info(
flow_run_id=input["flow_run_id"],
task_id=input["task_id"],
map_index=input.get("map_index"),
)
return {
"id": info["id"],
"version": info["version"],
"state": info["state"],
"serialized_state": info["serialized_state"],
}


@query.field("mapped_children")
async def resolve_mapped_children(
obj: Any, info: GraphQLResolveInfo, task_run_id: str
Expand Down Expand Up @@ -152,18 +169,6 @@ async def resolve_get_or_create_task_run(
}


@mutation.field("get_or_create_mapped_task_run_children")
async def resolve_get_or_create_mapped_task_run_children(
obj: Any, info: GraphQLResolveInfo, input: dict
) -> List[dict]:
task_runs = await api.runs.get_or_create_mapped_task_run_children(
flow_run_id=input["flow_run_id"],
task_id=input["task_id"],
max_map_index=input["max_map_index"],
)
return {"ids": task_runs}


@mutation.field("delete_flow_run")
async def resolve_delete_flow_run(
obj: Any, info: GraphQLResolveInfo, input: dict
Expand Down
23 changes: 17 additions & 6 deletions src/prefect_server/graphql/schema/runs.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@ extend type Mutation {
input: get_or_create_task_run_input!
): task_run_id_payload

"Gets or creates all mapped task run children for a parent task run."
get_or_create_mapped_task_run_children(
input: get_or_create_mapped_task_run_children_input!
): get_or_create_mapped_task_run_children_payload
"""
Given a flow run, task, and map index, retrieve the corresponding task run id
"""
get_or_create_task_run_info(
input: get_or_create_task_run_info_input!
): get_or_create_task_run_info_payload

"Update a flow run's heartbeat. This indicates the flow run is alive and is called automatically by Prefect Core."
update_flow_run_heartbeat(
Expand Down Expand Up @@ -102,6 +104,12 @@ input get_or_create_task_run_input {
map_index: Int
}

input get_or_create_task_run_info_input {
flow_run_id: UUID!
task_id: UUID!
map_index: Int
}

input get_or_create_mapped_task_run_children_input {
flow_run_id: UUID!
task_id: UUID!
Expand Down Expand Up @@ -141,8 +149,11 @@ type task_run_info_payload {
state: String
}

type get_or_create_mapped_task_run_children_payload {
ids: [UUID!]
type get_or_create_task_run_info_payload {
id: UUID
version: Int
state: String
serialized_state: JSON
}

type runs_in_queue_payload {
Expand Down
150 changes: 0 additions & 150 deletions tests/api/test_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,156 +663,6 @@ async def test_task_run_doesnt_insert_state_if_tr_already_exists(
assert new_task_run_state_count == task_run_state_count


class TestGetOrCreateMappedChildren:
async def test_get_or_create_mapped_children_creates_children(
self, flow_id, flow_run_id
):
# get a task from the flow
task = await models.Task.where({"flow_id": {"_eq": flow_id}}).first({"id"})
task_runs = await models.TaskRun.where({"task_id": {"_eq": task.id}}).get()

mapped_children = await api.runs.get_or_create_mapped_task_run_children(
flow_run_id=flow_run_id, task_id=task.id, max_map_index=10
)
# confirm 11 children were returned as a result (indices 0, through 10)
assert len(mapped_children) == 11
# confirm those 11 children are in the DB
assert len(task_runs) + 11 == len(
await models.TaskRun.where({"task_id": {"_eq": task.id}}).get()
)
# confirm that those 11 children have api.states and the map indices are ordered
map_indices = []
for child in mapped_children:
task_run = await models.TaskRun.where(id=child).first(
{
"map_index": True,
with_args(
"states",
{"order_by": {"version": EnumValue("desc")}, "limit": 1},
): {"id"},
}
)
map_indices.append(task_run.map_index)
assert task_run.states[0] is not None
assert map_indices == sorted(map_indices)

async def test_get_or_create_mapped_children_retrieves_children(
self, flow_id, flow_run_id
):
# get a task from the flow
task = await models.Task.where({"flow_id": {"_eq": flow_id}}).first(
{"id", "cache_key"}
)

# create some mapped children
task_run_ids = []
for i in range(11):
task_run_ids.append(
await models.TaskRun(
flow_run_id=flow_run_id,
task_id=task.id,
map_index=i,
cache_key=task.cache_key,
).insert()
)
# retrieve those mapped children
mapped_children = await api.runs.get_or_create_mapped_task_run_children(
flow_run_id=flow_run_id, task_id=task.id, max_map_index=10
)
# confirm we retrieved 11 mapped children (0 through 10)
assert len(mapped_children) == 11
# confirm those 11 children are the task api.runs we created earlier and that they're in order
map_indices = []
for child in mapped_children:
task_run = await models.TaskRun.where(id=child).first({"map_index"})
map_indices.append(task_run.map_index)
assert child in task_run_ids
assert map_indices == sorted(map_indices)

async def test_get_or_create_mapped_children_does_not_retrieve_parent(
self, flow_id, flow_run_id
):
# get a task from the flow
task = await models.Task.where({"flow_id": {"_eq": flow_id}}).first(
{"id", "cache_key"}
)
# create a parent and its mapped children
for i in range(3):
await models.TaskRun(
flow_run_id=flow_run_id,
task_id=task.id,
map_index=i,
cache_key=task.cache_key,
).insert()

# retrieve those mapped children
mapped_children = await api.runs.get_or_create_mapped_task_run_children(
flow_run_id=flow_run_id, task_id=task.id, max_map_index=2
)
# confirm we retrieved 3 mapped children (0, 1, and 2)
assert len(mapped_children) == 3
# but not the parent
for child in mapped_children:
task_run = await models.TaskRun.where(id=child).first({"map_index"})
assert task_run.map_index > -1

async def test_get_or_create_mapped_children_handles_partial_children(
self, flow_id, flow_run_id
):
# get a task from the flow
task = await models.Task.where({"flow_id": {"_eq": flow_id}}).first(
{"id", "cache_key"}
)

# create a few mapped children
await models.TaskRun(
flow_run_id=flow_run_id,
task_id=task.id,
map_index=3,
cache_key=task.cache_key,
).insert()
stateful_child = await models.TaskRun(
flow_run_id=flow_run_id,
task_id=task.id,
map_index=6,
cache_key=task.cache_key,
states=[
models.TaskRunState(
**models.TaskRunState.fields_from_state(
Pending(message="Task run created")
),
)
],
).insert()

# retrieve mapped children
mapped_children = await api.runs.get_or_create_mapped_task_run_children(
flow_run_id=flow_run_id, task_id=task.id, max_map_index=10
)
assert len(mapped_children) == 11
map_indices = []
# confirm each of the mapped children has a state and is ordered properly
for child in mapped_children:
task_run = await models.TaskRun.where(id=child).first(
{
"map_index": True,
with_args(
"states",
{"order_by": {"version": EnumValue("desc")}, "limit": 1},
): {"id"},
}
)
map_indices.append(task_run.map_index)
assert task_run.states[0] is not None
assert map_indices == sorted(map_indices)

# confirm the one child created with a state only has the one state
child_states = await models.TaskRunState.where(
{"task_run_id": {"_eq": stateful_child}}
).get()
assert len(child_states) == 1


class TestUpdateFlowRunHeartbeat:
async def test_update_heartbeat(self, flow_run_id):
dt = pendulum.now()
Expand Down
131 changes: 60 additions & 71 deletions tests/graphql/test_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,66 @@ async def test_create_flow_run_without_parameters_raises_error(
assert "Required parameters" in result.errors[0].message


class TestGetOrCreateTaskRunInfo:
mutation = """
mutation($input: get_or_create_task_run_info_input!) {
get_or_create_task_run_info(input: $input) {
id
version
state
serialized_state
}
}
"""

async def test_get_existing_task_run_id(
self, run_query, task_run_id, task_id, flow_run_id
):
result = await run_query(
query=self.mutation,
variables=dict(input=dict(flow_run_id=flow_run_id, task_id=task_id)),
)

task_run = await models.TaskRun.where(id=task_run_id).first(
{"id", "version", "state", "serialized_state"}
)

assert result.data.get_or_create_task_run_info.id == task_run.id
assert result.data.get_or_create_task_run_info.version == task_run.version
assert result.data.get_or_create_task_run_info.state == task_run.state
assert (
result.data.get_or_create_task_run_info.serialized_state
== task_run.serialized_state
)

async def test_get_new_task_run_id(
self, run_query, task_run_id, task_id, flow_run_id
):
assert not await models.TaskRun.where(
{
"flow_run_id": {"_eq": flow_run_id},
"task_id": {"_eq": task_id},
"map_index": {"_eq": 12},
}
).first({"id"})

result = await run_query(
query=self.mutation,
variables=dict(
input=dict(flow_run_id=flow_run_id, task_id=task_id, map_index=12)
),
)

task_run = await models.TaskRun.where(
{
"flow_run_id": {"_eq": flow_run_id},
"task_id": {"_eq": task_id},
"map_index": {"_eq": 12},
}
).first({"id"})
assert task_run.id == result.data.get_or_create_task_run_info.id


class TestGetTaskRunInfo:
query = """
query($task_run_id: UUID!) {
Expand Down Expand Up @@ -263,77 +323,6 @@ async def test_get_or_create_task_run_new_task_run(
)


class TestGetOrCreateMappedTaskRunChildren:
mutation = """
mutation($input: get_or_create_mapped_task_run_children_input!) {
get_or_create_mapped_task_run_children(input: $input) {
ids
}
}
"""

async def test_get_or_create_mapped_task_run_children(
self, run_query, flow_run_id, flow_id
):
# grab the task ID
task = await models.Task.where({"flow_id": {"_eq": flow_id}}).first({"id"})
result = await run_query(
query=self.mutation,
variables=dict(
input=dict(flow_run_id=flow_run_id, task_id=task.id, max_map_index=5)
),
)
# should have 6 children, indices 0-5
assert len(result.data.get_or_create_mapped_task_run_children.ids) == 6

async def test_get_or_create_mapped_task_run_children_with_partial_children(
self, run_query, flow_run_id, flow_id
):
task = await models.Task.where({"flow_id": {"_eq": flow_id}}).first({"id"})
# create a couple of children
preexisting_run_1 = await models.TaskRun(
flow_run_id=flow_run_id,
task_id=task.id,
map_index=3,
cache_key=task.cache_key,
).insert()
preexisting_run_2 = await models.TaskRun(
flow_run_id=flow_run_id,
task_id=task.id,
map_index=6,
cache_key=task.cache_key,
states=[
models.TaskRunState(
**models.TaskRunState.fields_from_state(
Pending(message="Task run created")
),
)
],
).insert()
# call the route
result = await run_query(
query=self.mutation,
variables=dict(
input=dict(flow_run_id=flow_run_id, task_id=task.id, max_map_index=10)
),
)
mapped_children = result.data.get_or_create_mapped_task_run_children.ids
# should have 11 children, indices 0-10
assert len(mapped_children) == 11

# confirm the preexisting task runs were included in the results
assert preexisting_run_1 in mapped_children
assert preexisting_run_2 in mapped_children

# confirm the results are ordered
map_indices = []
for child in mapped_children:
map_indices.append(
(await models.TaskRun.where(id=child).first({"map_index"})).map_index
)
assert map_indices == sorted(map_indices)


class TestUpdateFlowRunHeartbeat:
mutation = """
mutation($input: update_flow_run_heartbeat_input!) {
Expand Down

0 comments on commit 19733ac

Please sign in to comment.