Addressing pyright-raised type issues
blast-hardcheese committed Nov 29, 2024
1 parent 360f994 commit ba64db8
Showing 2 changed files with 7 additions and 10 deletions.
replit_river/rpc.py (13 changes: 5 additions & 8 deletions)
```diff
@@ -127,7 +127,7 @@ def set(self, carrier: TransportMessage, key: str, value: str) -> None:
         logger.warning("unknown trace propagation key", extra={"key": key})
 
 
-class GrpcContext(grpc.aio.ServicerContext):
+class GrpcContext(grpc.aio.ServicerContext, Generic[RequestType, ResponseType]):
     """Represents a gRPC-compatible ServicerContext for River interop."""
 
     def __init__(self, peer: str) -> None:
```
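For readers following along, here is a minimal sketch of what this first change presumably builds on: `RequestType` and `ResponseType` declared as TypeVars elsewhere in rpc.py (an assumption based on the names), so that subclassing with `Generic[RequestType, ResponseType]` lets pyright parameterize each `GrpcContext` instead of treating it as an unparameterized class.

```python
# Sketch only -- assumes RequestType/ResponseType are TypeVars already defined
# in replit_river/rpc.py; the names here mirror the diff, not the real module.
from typing import Generic, TypeVar

import grpc.aio

RequestType = TypeVar("RequestType")
ResponseType = TypeVar("ResponseType")


class GrpcContext(grpc.aio.ServicerContext, Generic[RequestType, ResponseType]):
    """A River-side context that pyright can now parameterize per handler."""

    def __init__(self, peer: str) -> None:
        self._peer = peer
```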
```diff
@@ -229,9 +229,8 @@ async def wrapped(
         input: Channel[Any],
         output: Channel[Any],
     ) -> None:
-        context = None
+        context: GrpcContext[RequestType, ResponseType] = GrpcContext(peer)
         try:
-            context = GrpcContext(peer)
             request = request_deserializer(await input.get())
             response = method(request, context)
             if isinstance(response, Awaitable):
```
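The remaining rpc.py hunks all follow this same pattern, so one illustration: initializing `context = None` and only assigning inside the `try` makes pyright infer an Optional (and unparameterized) type for the variable, which then demands a None guard at every later use; annotating and constructing it up front keeps it non-Optional and fully parameterized. A standalone sketch of that narrowing behaviour, using hypothetical names rather than the actual river handler code:

```python
# Hypothetical illustration of the pyright pattern this commit addresses;
# Ctx and handle_* are stand-ins, not part of replit_river.
from typing import Generic, TypeVar

Req = TypeVar("Req")
Res = TypeVar("Res")


class Ctx(Generic[Req, Res]):
    def __init__(self, peer: str) -> None:
        self.peer = peer


def handle_before(peer: str) -> None:
    ctx = None  # pyright infers roughly "Ctx[Unknown, Unknown] | None"
    try:
        ctx = Ctx(peer)
    finally:
        if ctx is not None:  # None guard required everywhere ctx is used
            print(ctx.peer)


def handle_after(peer: str) -> None:
    # Non-Optional and explicitly parameterized, so no guards are needed.
    ctx: Ctx[str, int] = Ctx(peer)
    try:
        print(ctx.peer)
    finally:
        print(ctx.peer)
```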
```diff
@@ -287,9 +286,8 @@ async def wrapped(
         input: Channel[Any],
         output: Channel[Any],
     ) -> None:
-        context = None
+        context: GrpcContext[RequestType, ResponseType] = GrpcContext(peer)
         try:
-            context = GrpcContext(peer)
             request = request_deserializer(await input.get())
             iterator = method(request, context)
             if isinstance(iterator, AsyncIterable):
```
```diff
@@ -349,7 +347,7 @@ async def wrapped(
     ) -> None:
         task_manager = BackgroundTaskManager()
         try:
-            context = GrpcContext(peer)
+            context: GrpcContext[RequestType, ResponseType] = GrpcContext(peer)
             request: Channel[RequestType] = Channel(MAX_MESSAGE_BUFFER_SIZE)
 
             async def _convert_inputs() -> None:
```
```diff
@@ -426,9 +424,8 @@ async def wrapped(
         output: Channel[Any],
     ) -> None:
         task_manager = BackgroundTaskManager()
-        context = None
+        context: GrpcContext[RequestType, ResponseType] = GrpcContext(peer)
         try:
-            context = GrpcContext(peer)
             request: Channel[RequestType] = Channel(MAX_MESSAGE_BUFFER_SIZE)
 
             async def _convert_inputs() -> None:
```
tests/test_message_buffer.py (4 changes: 2 additions & 2 deletions)
```diff
@@ -11,12 +11,12 @@ def mock_transport_message(seq: int) -> TransportMessage:
         seq=seq,
         id="test",
         ack=0,
-        from_="test",
+        from_="test",  # type: ignore
         to="test",
         streamId="test",
         controlFlags=0,
         payload=0,
-        model_config={},
+        model_config={},  # type: ignore
     )
```
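One plausible reading of the two test-side ignores (hedged, since the model definition is not part of this diff): `TransportMessage` appears to be a pydantic model whose `from_` field is aliased to the reserved word `from`, so pyright's synthesized constructor only accepts the alias, and `model_config` is class-level configuration rather than an ordinary field. A rough sketch of that situation with a stand-in model:

```python
# Stand-in model, assuming pydantic v2 and an alias for the keyword "from";
# the real TransportMessage definition lives elsewhere in the repository.
from pydantic import BaseModel, ConfigDict, Field


class TransportMessageSketch(BaseModel):
    model_config = ConfigDict(populate_by_name=True)

    from_: str = Field(alias="from")
    to: str


# Accepted at runtime thanks to populate_by_name=True, but a static checker may
# only see the alias "from" (unusable as a keyword argument) in the generated
# __init__ -- the kind of mismatch the test suppresses with `# type: ignore`.
msg = TransportMessageSketch(from_="test", to="test")  # type: ignore
print(msg.model_dump(by_alias=True))
```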


