Skip to content

Commit

Permalink
bugfix: make streaming spans last for the entire duration of the stre…
Browse files Browse the repository at this point in the history
…am (#120)

Why
===

The streaming spans were immediately finishing because we were only
tracing the construction of the `AsyncIterator` and not the full
iteration.

What changed
============

- loop and yield each message of the streaming procedures so the span
doesn't finish until the async iterator is disposed.

Test plan
=========

- Streaming procedure spans should now last the full duration of the
stream procedure
  • Loading branch information
cbrewster authored Nov 22, 2024
1 parent dea76d6 commit 4ccbd27
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 17 deletions.
10 changes: 6 additions & 4 deletions replit_river/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,15 @@ async def send_subscription(
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
with _trace_procedure("subscription", service_name, procedure_name):
session = await self._transport.get_or_create_session()
return session.send_subscription(
async for msg in session.send_subscription(
service_name,
procedure_name,
request,
request_serializer,
response_deserializer,
error_deserializer,
)
):
yield msg

async def send_stream(
self,
Expand All @@ -128,7 +129,7 @@ async def send_stream(
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
with _trace_procedure("stream", service_name, procedure_name):
session = await self._transport.get_or_create_session()
return session.send_stream(
async for msg in session.send_stream(
service_name,
procedure_name,
init,
Expand All @@ -137,7 +138,8 @@ async def send_stream(
request_serializer,
response_deserializer,
error_deserializer,
)
):
yield msg


@contextmanager
Expand Down
6 changes: 3 additions & 3 deletions replit_river/codegen/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,7 +953,7 @@ async def {name}(
self,
input: {render_type_expr(input_type)},
) -> AsyncIterator[{render_type_expr(output_or_error_type)}]:
return await self.client.send_subscription(
return self.client.send_subscription(
{repr(schema_name)},
{repr(name)},
input,
Expand Down Expand Up @@ -1029,7 +1029,7 @@ async def {name}(
init: {render_type_expr(init_type)},
inputStream: AsyncIterable[{render_type_expr(input_type)}],
) -> AsyncIterator[{render_type_expr(output_or_error_type)}]:
return await self.client.send_stream(
return self.client.send_stream(
{repr(schema_name)},
{repr(name)},
init,
Expand All @@ -1053,7 +1053,7 @@ async def {name}(
self,
inputStream: AsyncIterable[{render_type_expr(input_type)}],
) -> AsyncIterator[{render_type_expr(output_or_error_type)}]:
return await self.client.send_stream(
return self.client.send_stream(
{repr(schema_name)},
{repr(name)},
None,
Expand Down
21 changes: 11 additions & 10 deletions tests/test_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ async def upload_data() -> AsyncGenerator[str, None]:
serialize_request,
serialize_request,
deserialize_response,
deserialize_response,
) # type: ignore
deserialize_error,
)
assert response == "Uploaded: Initial Data, Data 1, Data 2, Data 3"


Expand All @@ -58,8 +58,8 @@ async def upload_data() -> AsyncGenerator[str, None]:
serialize_request,
serialize_request,
deserialize_response,
deserialize_response,
) # type: ignore
deserialize_error,
)
assert response == "Uploaded: Initial Data" + (", Data" * iterations)


Expand All @@ -77,21 +77,22 @@ async def upload_data(enabled: bool = False) -> AsyncGenerator[str, None]:
None,
serialize_request,
deserialize_response,
deserialize_response,
) # type: ignore
deserialize_error,
)
assert response == "Uploaded: "


@pytest.mark.asyncio
async def test_subscription_method(client: Client) -> None:
async for response in await client.send_subscription(
async for response in client.send_subscription(
"test_service",
"subscription_method",
"Bob",
serialize_request,
deserialize_response,
deserialize_error,
):
assert isinstance(response, str)
assert "Subscription message" in response


Expand All @@ -103,7 +104,7 @@ async def stream_data() -> AsyncGenerator[str, None]:
yield "Stream 3"

responses = []
async for response in await client.send_stream(
async for response in client.send_stream(
"test_service",
"stream_method",
"Initial Stream Data",
Expand All @@ -130,7 +131,7 @@ async def stream_data(enabled: bool = False) -> AsyncGenerator[str, None]:
yield "unreachable"

responses = []
async for response in await client.send_stream(
async for response in client.send_stream(
"test_service",
"stream_method",
None,
Expand Down Expand Up @@ -167,7 +168,7 @@ async def stream_data() -> AsyncGenerator[str, None]:
deserialize_error,
)
)
stream_task = await client.send_stream(
stream_task = client.send_stream(
"test_service",
"stream_method",
"Initial Stream Data",
Expand Down

0 comments on commit 4ccbd27

Please sign in to comment.