diff --git a/replit_river/rpc.py b/replit_river/rpc.py index 53152c4..7dba07e 100644 --- a/replit_river/rpc.py +++ b/replit_river/rpc.py @@ -127,7 +127,7 @@ def set(self, carrier: TransportMessage, key: str, value: str) -> None: logger.warning("unknown trace propagation key", extra={"key": key}) -class GrpcContext(grpc.aio.ServicerContext): +class GrpcContext(grpc.aio.ServicerContext, Generic[RequestType, ResponseType]): """Represents a gRPC-compatible ServicerContext for River interop.""" def __init__(self, peer: str) -> None: @@ -229,9 +229,8 @@ async def wrapped( input: Channel[Any], output: Channel[Any], ) -> None: - context = None + context: GrpcContext[RequestType, ResponseType] = GrpcContext(peer) try: - context = GrpcContext(peer) request = request_deserializer(await input.get()) response = method(request, context) if isinstance(response, Awaitable): @@ -287,9 +286,8 @@ async def wrapped( input: Channel[Any], output: Channel[Any], ) -> None: - context = None + context: GrpcContext[RequestType, ResponseType] = GrpcContext(peer) try: - context = GrpcContext(peer) request = request_deserializer(await input.get()) iterator = method(request, context) if isinstance(iterator, AsyncIterable): @@ -349,7 +347,7 @@ async def wrapped( ) -> None: task_manager = BackgroundTaskManager() try: - context = GrpcContext(peer) + context: GrpcContext[RequestType, ResponseType] = GrpcContext(peer) request: Channel[RequestType] = Channel(MAX_MESSAGE_BUFFER_SIZE) async def _convert_inputs() -> None: @@ -426,9 +424,8 @@ async def wrapped( output: Channel[Any], ) -> None: task_manager = BackgroundTaskManager() - context = None + context: GrpcContext[RequestType, ResponseType] = GrpcContext(peer) try: - context = GrpcContext(peer) request: Channel[RequestType] = Channel(MAX_MESSAGE_BUFFER_SIZE) async def _convert_inputs() -> None: diff --git a/tests/test_message_buffer.py b/tests/test_message_buffer.py index 3c7c6a9..7d33037 100644 --- a/tests/test_message_buffer.py +++ b/tests/test_message_buffer.py @@ -11,12 +11,12 @@ def mock_transport_message(seq: int) -> TransportMessage: seq=seq, id="test", ack=0, - from_="test", + from_="test", # type: ignore to="test", streamId="test", controlFlags=0, payload=0, - model_config={}, + model_config={}, # type: ignore )