diff --git a/anyio/_networking.py b/anyio/_networking.py index 55dea149..8a54fb88 100644 --- a/anyio/_networking.py +++ b/anyio/_networking.py @@ -143,11 +143,13 @@ async def sendto(self, data: bytes, addr, *, flags: int = 0) -> int: return self._raw_socket.sendto(data, flags, addr) async def sendall(self, data: bytes, *, flags: int = 0) -> None: - to_send = len(data) - while to_send > 0: + offset = 0 + total = len(data) + buffer = memoryview(data) + while offset < total: await self._check_cancelled() try: - sent = self._raw_socket.send(data, flags) + offset += self._raw_socket.send(buffer[offset:], flags) except (BlockingIOError, ssl.SSLWantWriteError): await self._wait_writable() except ssl.SSLWantReadError: @@ -155,8 +157,6 @@ async def sendall(self, data: bytes, *, flags: int = 0) -> None: except ssl.SSLEOFError: self._raw_socket.close() raise - else: - to_send -= sent async def start_tls(self, context: ssl.SSLContext, server_hostname: Optional[str] = None, diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index b77e9ca4..ff796a9a 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -3,6 +3,10 @@ Version history This library adheres to `Semantic Versioning `_. +**UNRELEASED** + +- Fixed mishandling of large buffers by ``BaseSocket.sendall()`` + **1.0.0b1** - Initial release diff --git a/tests/test_networking.py b/tests/test_networking.py index dad7bad5..8432a507 100644 --- a/tests/test_networking.py +++ b/tests/test_networking.py @@ -72,6 +72,23 @@ async def server(): assert response == b'blahbleh' + @pytest.mark.anyio + async def test_send_large_buffer(self): + async def server(): + async with await stream_server.accept() as stream: + await stream.send_all(buffer) + + buffer = b'\xff' * 1024 # should exceed the maximum kernel send buffer size + async with create_task_group() as tg: + async with await create_tcp_server(interface='localhost') as stream_server: + await tg.spawn(server) + async with await connect_tcp('localhost', stream_server.port) as client: + response = await client.receive_exactly(len(buffer)) + with pytest.raises(IncompleteRead): + await client.receive_exactly(1) + + assert response == buffer + @pytest.mark.parametrize('method_name, params', [ ('receive_until', [b'\n', 100]), ('receive_exactly', [5])