From ba64db8ddf0a7f67f2f21478d751475316951ee6 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 28 Nov 2024 11:18:46 -0800 Subject: [PATCH] Addressing pyright-raised type issues --- replit_river/rpc.py | 13 +++++-------- tests/test_message_buffer.py | 4 ++-- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/replit_river/rpc.py b/replit_river/rpc.py index 53152c4..7dba07e 100644 --- a/replit_river/rpc.py +++ b/replit_river/rpc.py @@ -127,7 +127,7 @@ def set(self, carrier: TransportMessage, key: str, value: str) -> None: logger.warning("unknown trace propagation key", extra={"key": key}) -class GrpcContext(grpc.aio.ServicerContext): +class GrpcContext(grpc.aio.ServicerContext, Generic[RequestType, ResponseType]): """Represents a gRPC-compatible ServicerContext for River interop.""" def __init__(self, peer: str) -> None: @@ -229,9 +229,8 @@ async def wrapped( input: Channel[Any], output: Channel[Any], ) -> None: - context = None + context: GrpcContext[RequestType, ResponseType] = GrpcContext(peer) try: - context = GrpcContext(peer) request = request_deserializer(await input.get()) response = method(request, context) if isinstance(response, Awaitable): @@ -287,9 +286,8 @@ async def wrapped( input: Channel[Any], output: Channel[Any], ) -> None: - context = None + context: GrpcContext[RequestType, ResponseType] = GrpcContext(peer) try: - context = GrpcContext(peer) request = request_deserializer(await input.get()) iterator = method(request, context) if isinstance(iterator, AsyncIterable): @@ -349,7 +347,7 @@ async def wrapped( ) -> None: task_manager = BackgroundTaskManager() try: - context = GrpcContext(peer) + context: GrpcContext[RequestType, ResponseType] = GrpcContext(peer) request: Channel[RequestType] = Channel(MAX_MESSAGE_BUFFER_SIZE) async def _convert_inputs() -> None: @@ -426,9 +424,8 @@ async def wrapped( output: Channel[Any], ) -> None: task_manager = BackgroundTaskManager() - context = None + context: GrpcContext[RequestType, ResponseType] = GrpcContext(peer) try: - context = GrpcContext(peer) request: Channel[RequestType] = Channel(MAX_MESSAGE_BUFFER_SIZE) async def _convert_inputs() -> None: diff --git a/tests/test_message_buffer.py b/tests/test_message_buffer.py index 3c7c6a9..7d33037 100644 --- a/tests/test_message_buffer.py +++ b/tests/test_message_buffer.py @@ -11,12 +11,12 @@ def mock_transport_message(seq: int) -> TransportMessage: seq=seq, id="test", ack=0, - from_="test", + from_="test", # type: ignore to="test", streamId="test", controlFlags=0, payload=0, - model_config={}, + model_config={}, # type: ignore )