diff --git a/replit_river/client.py b/replit_river/client.py index 8111c63..f151ced 100644 --- a/replit_river/client.py +++ b/replit_river/client.py @@ -106,14 +106,15 @@ async def send_subscription( ) -> AsyncIterator[Union[ResponseType, ErrorType]]: with _trace_procedure("subscription", service_name, procedure_name): session = await self._transport.get_or_create_session() - return session.send_subscription( + async for msg in session.send_subscription( service_name, procedure_name, request, request_serializer, response_deserializer, error_deserializer, - ) + ): + yield msg async def send_stream( self, @@ -128,7 +129,7 @@ async def send_stream( ) -> AsyncIterator[Union[ResponseType, ErrorType]]: with _trace_procedure("stream", service_name, procedure_name): session = await self._transport.get_or_create_session() - return session.send_stream( + async for msg in session.send_stream( service_name, procedure_name, init, @@ -137,7 +138,8 @@ async def send_stream( request_serializer, response_deserializer, error_deserializer, - ) + ): + yield msg @contextmanager diff --git a/replit_river/codegen/client.py b/replit_river/codegen/client.py index e718e80..c876611 100644 --- a/replit_river/codegen/client.py +++ b/replit_river/codegen/client.py @@ -953,7 +953,7 @@ async def {name}( self, input: {render_type_expr(input_type)}, ) -> AsyncIterator[{render_type_expr(output_or_error_type)}]: - return await self.client.send_subscription( + return self.client.send_subscription( {repr(schema_name)}, {repr(name)}, input, @@ -1029,7 +1029,7 @@ async def {name}( init: {render_type_expr(init_type)}, inputStream: AsyncIterable[{render_type_expr(input_type)}], ) -> AsyncIterator[{render_type_expr(output_or_error_type)}]: - return await self.client.send_stream( + return self.client.send_stream( {repr(schema_name)}, {repr(name)}, init, @@ -1053,7 +1053,7 @@ async def {name}( self, inputStream: AsyncIterable[{render_type_expr(input_type)}], ) -> AsyncIterator[{render_type_expr(output_or_error_type)}]: - return await self.client.send_stream( + return self.client.send_stream( {repr(schema_name)}, {repr(name)}, None, diff --git a/tests/test_communication.py b/tests/test_communication.py index c18d357..b9db72f 100644 --- a/tests/test_communication.py +++ b/tests/test_communication.py @@ -37,8 +37,8 @@ async def upload_data() -> AsyncGenerator[str, None]: serialize_request, serialize_request, deserialize_response, - deserialize_response, - ) # type: ignore + deserialize_error, + ) assert response == "Uploaded: Initial Data, Data 1, Data 2, Data 3" @@ -58,8 +58,8 @@ async def upload_data() -> AsyncGenerator[str, None]: serialize_request, serialize_request, deserialize_response, - deserialize_response, - ) # type: ignore + deserialize_error, + ) assert response == "Uploaded: Initial Data" + (", Data" * iterations) @@ -77,14 +77,14 @@ async def upload_data(enabled: bool = False) -> AsyncGenerator[str, None]: None, serialize_request, deserialize_response, - deserialize_response, - ) # type: ignore + deserialize_error, + ) assert response == "Uploaded: " @pytest.mark.asyncio async def test_subscription_method(client: Client) -> None: - async for response in await client.send_subscription( + async for response in client.send_subscription( "test_service", "subscription_method", "Bob", @@ -92,6 +92,7 @@ async def test_subscription_method(client: Client) -> None: deserialize_response, deserialize_error, ): + assert isinstance(response, str) assert "Subscription message" in response @@ -103,7 +104,7 @@ async def stream_data() -> AsyncGenerator[str, None]: yield "Stream 3" responses = [] - async for response in await client.send_stream( + async for response in client.send_stream( "test_service", "stream_method", "Initial Stream Data", @@ -130,7 +131,7 @@ async def stream_data(enabled: bool = False) -> AsyncGenerator[str, None]: yield "unreachable" responses = [] - async for response in await client.send_stream( + async for response in client.send_stream( "test_service", "stream_method", None, @@ -167,7 +168,7 @@ async def stream_data() -> AsyncGenerator[str, None]: deserialize_error, ) ) - stream_task = await client.send_stream( + stream_task = client.send_stream( "test_service", "stream_method", "Initial Stream Data",