Skip to content

Commit

Permalink
Add new get_task_run_info route for Cloud API compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
cicdw committed Nov 29, 2020
1 parent a3626c8 commit 8f6e128
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 3 deletions.
23 changes: 22 additions & 1 deletion src/prefect_server/graphql/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,34 @@

import prefect
from prefect import api
from prefect_server.database import postgres
from prefect_server.database import models, postgres
from prefect_server.utilities import context
from prefect_server.utilities.graphql import mutation, query

state_schema = prefect.serialization.state.StateSchema()


@query.field("get_task_run_info")
async def resolve_get_task_run_info(
obj: Any, info: GraphQLResolveInfo, task_run_id: str
) -> dict:
"""
Retrieve details about a task run.
"""
task_run = await models.TaskRun.where(id=task_run_id).first(
{"version", "serialized_state", "state"}
)
if not task_run:
raise ValueError("Invalid task run ID")

return {
"version": task_run.version,
"serialized_state": task_run.serialized_state,
"state": task_run.state,
"id": task_run_id,
}


@query.field("mapped_children")
async def resolve_mapped_children(
obj: Any, info: GraphQLResolveInfo, task_run_id: str
Expand Down
16 changes: 15 additions & 1 deletion src/prefect_server/graphql/schema/runs.graphql
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
extend type Query {
mapped_children(task_run_id: UUID!): mapped_children_payload
mapped_children(task_run_id: UUID!): mapped_children_payload

"""
Given a task run ID, retrieve current task run info
"""
get_task_run_info(
task_run_id: UUID!
): task_run_info_payload

}

Expand Down Expand Up @@ -127,6 +134,13 @@ type task_run_id_payload {
id: UUID
}

type task_run_info_payload {
id: UUID
version: Int
serialized_state: JSON
state: String
}

type get_or_create_mapped_task_run_children_payload {
ids: [UUID!]
}
Expand Down
55 changes: 54 additions & 1 deletion tests/graphql/test_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import prefect
from prefect import api, models
from prefect.engine.state import Pending, Scheduled
from prefect.engine.state import Pending, Scheduled, Success


class TestCreateFlowRun:
Expand Down Expand Up @@ -167,6 +167,59 @@ async def test_create_flow_run_without_parameters_raises_error(
assert "Required parameters" in result.errors[0].message


class TestGetTaskRunInfo:
query = """
query($task_run_id: UUID!) {
get_task_run_info(task_run_id: $task_run_id) {
id
serialized_state
state
version
}
}
"""

async def test_get_task_run_info(
self,
run_query,
task_run_id,
):
result = await run_query(
query=self.query,
variables=dict(task_run_id=task_run_id),
)

output = result.data.get_task_run_info
assert output.id == task_run_id
assert output.version == 1
assert output.serialized_state.type == "Pending"
assert output.state == "Pending"

await api.states.set_task_run_state(task_run_id, state=Success("hi"))

result = await run_query(
query=self.query,
variables=dict(task_run_id=task_run_id),
)

assert result.data.get_task_run_info.version == 2
assert result.data.get_task_run_info.serialized_state.type == "Success"
assert result.data.get_task_run_info.state == "Success"
assert result.data.get_task_run_info.serialized_state.message == "hi"

async def test_get_task_run_info_handles_bad_ids(
self,
run_query,
):
result = await run_query(
query=self.query,
variables=dict(task_run_id=str(uuid.uuid4())),
)

assert result.errors[0].message == "Invalid task run ID"
assert result.data.get_task_run_info is None


class TestGetOrCreateTaskRun:
mutation = """
mutation($input: get_or_create_task_run_input!) {
Expand Down

0 comments on commit 8f6e128

Please sign in to comment.