From b38809230b1655937c9586a775badeae6b73ab8d Mon Sep 17 00:00:00 2001 From: Micah Roberts Date: Fri, 5 Jan 2024 12:32:11 -0700 Subject: [PATCH] Tunneling network congestion and port print fix. --- keepercommander/commands/discoveryrotation.py | 4 +- .../commands/tunnel/port_forward/endpoint.py | 181 ++++++++++++------ unit-tests/pam-tunnel/test_private_tunnel.py | 50 ++--- 3 files changed, 145 insertions(+), 90 deletions(-) diff --git a/keepercommander/commands/discoveryrotation.py b/keepercommander/commands/discoveryrotation.py index 5062327d4..a7cb1866b 100644 --- a/keepercommander/commands/discoveryrotation.py +++ b/keepercommander/commands/discoveryrotation.py @@ -2028,7 +2028,7 @@ def print_fail(): host = host + ":" if host else '' # Total length of the dynamic parts (endpoint name, host, and port) dynamic_length = \ - (len("| Endpoint : Listening on port: ") + len(convo_id) + len(host) + len(str(port))) + (len("| Endpoint : Listening on port: ") + len(convo_id) + len(host) + len(str(entrance.port))) # Dashed line adjusted to the length of the middle line dashed_line = '+' + '-' * dynamic_length + '+' @@ -2038,7 +2038,7 @@ def print_fail(): print( f'{bcolors.OKGREEN}| Endpoint {bcolors.ENDC}{bcolors.OKBLUE}{convo_id}{bcolors.ENDC}' f'{bcolors.OKGREEN}: Listening on port: {bcolors.ENDC}' - f'{bcolors.BOLD}{bcolors.OKBLUE}{host}{port}{bcolors.ENDC}{bcolors.OKGREEN} |{bcolors.ENDC}') + f'{bcolors.BOLD}{bcolors.OKBLUE}{host}{entrance.port}{bcolors.ENDC}{bcolors.OKGREEN} |{bcolors.ENDC}') print(f'{bcolors.OKGREEN}{dashed_line}{bcolors.ENDC}') print( f'{bcolors.OKGREEN}View all open tunnels : {bcolors.ENDC}{bcolors.OKBLUE}pam tunnel list{bcolors.ENDC}') diff --git a/keepercommander/commands/tunnel/port_forward/endpoint.py b/keepercommander/commands/tunnel/port_forward/endpoint.py index 751982bd4..8d2448f99 100644 --- a/keepercommander/commands/tunnel/port_forward/endpoint.py +++ b/keepercommander/commands/tunnel/port_forward/endpoint.py @@ -16,7 +16,8 @@ from cryptography.hazmat.primitives.ciphers.aead import AESGCM from cryptography.hazmat.primitives.kdf.hkdf import HKDF from cryptography.utils import int_to_bytes -from keeper_secrets_manager_core.utils import bytes_to_base64, base64_to_bytes, bytes_to_string, string_to_bytes +from keeper_secrets_manager_core.utils import bytes_to_base64, base64_to_bytes, bytes_to_string, string_to_bytes, \ + bytes_to_int from keepercommander.commands.pam.pam_dto import GatewayActionWebRTCSession from keepercommander.commands.pam.router_helper import router_get_relay_access_creds, router_send_action_to_gateway @@ -26,15 +27,21 @@ logging.getLogger('aiortc').setLevel(logging.WARNING) logging.getLogger('aioice').setLevel(logging.WARNING) -#TODO add why 9 is the length of the protocol -BUFFER_TRUNCATION_THRESHOLD = 16000 - 9 # 16 Kbytes max https://viblast.com/blog/2015/2/5/webrtc-data-channel-message-size/ so we will use the max minus 9 bytes for the protocol + READ_TIMEOUT = 10 -CONTROL_MESSAGE_NO_LENGTH = 2 -CONNECTION_NO_LENGTH = DATA_LENGTH = 4 NONCE_LENGTH = 12 SYMMETRIC_KEY_LENGTH = RANDOM_LENGTH = 32 +MESSAGE_MAX = 5 + +# Protocol constants +CONTROL_MESSAGE_NO_LENGTH = 2 +CONNECTION_NO_LENGTH = DATA_LENGTH = 4 TERMINATOR = b';' +PROTOCOL_LENGTH = CONNECTION_NO_LENGTH + DATA_LENGTH + CONTROL_MESSAGE_NO_LENGTH + len(TERMINATOR) + +# WebRTC constants BUFFER_THRESHOLD = 134217728 * .90 # 16 MiB max https://viblast.com/blog/2015/2/25/webrtc-bufferedamount/ so we will use 14.4 MiB or 90% of the max, because in some cases if the max is reached the channel will close +BUFFER_TRUNCATION_THRESHOLD = 16000 - PROTOCOL_LENGTH # 16 Kbytes max https://viblast.com/blog/2015/2/5/webrtc-data-channel-message-size/ so we will use the max minus bytes for the protocol class ConnectionNotFoundException(Exception): @@ -172,7 +179,8 @@ async def signal_channel(self, kind: str): else: raise Exception(f'Invalid kind: {kind}') except socket.gaierror: - print(f"{bcolors.WARNING}Please upgrade Commander to the latest version to use this feature...{bcolors.ENDC}") + print( + f"{bcolors.WARNING}Please upgrade Commander to the latest version to use this feature...{bcolors.ENDC}") return except Exception as e: raise Exception(f'Error making WebRTC offer: {e}') @@ -404,6 +412,18 @@ async def close_webrtc_connection(self): """ +class ConnectionInfo: + def __init__(self, reader: Optional[asyncio.StreamReader], writer: Optional[asyncio.StreamWriter], + message_counter: int, ping_time: Optional[float], to_tunnel_task: Optional[asyncio.Task], + start_time: datetime): + self.reader = reader + self.writer = writer + self.message_counter = message_counter + self.ping_time = ping_time + self.to_tunnel_task = to_tunnel_task + self.start_time = start_time + + class TunnelEntrance: """ This class is used to forward data between a WebRTC connection and a connection to a target. @@ -412,37 +432,43 @@ class TunnelEntrance: Data is broken into three parts: connection number, [message number], and data message number is only used in control messages. (if the connection number is 0 then there is a message number) """ + def __init__(self, - host, # type: str - port, # type: int - endpoint_name, # type: str - pc, # type: WebRTCConnection - print_ready_event, # type: asyncio.Event - logger = None, # type: logging.Logger - connect_task = None, # type: asyncio.Task - kill_server_event = None # type: asyncio.Event - ): # type: (...) -> None + host, # type: str + port, # type: int + endpoint_name, # type: str + pc, # type: WebRTCConnection + print_ready_event, # type: asyncio.Event + logger=None, # type: logging.Logger + connect_task=None, # type: asyncio.Task + kill_server_event=None # type: asyncio.Event + ): # type: (...) -> None self.closing = False - self.ping_time = None self.to_local_task = None self._ping_attempt = 0 self.host = host self.server = None self.connection_no = 1 self.endpoint_name = endpoint_name - self.connections: Dict[int, Tuple[asyncio.StreamReader, asyncio.StreamWriter]] = {} + self.connections: Dict[int, ConnectionInfo] = { + 0: ConnectionInfo( + None, None, 0, None, None, datetime.now() + ) + } self._port = port self.logger = logger self.is_connected = True self.reader_task = asyncio.create_task(self.start_reader()) - self.to_tunnel_tasks = {} self.kill_server_event = kill_server_event self.pc = pc self.print_ready_event = print_ready_event self.connect_task = connect_task - self.connection_time = {} - async def send_to_web_rtc(self, data, connection_no=0): + @property + def port(self): + return self._port + + async def send_to_web_rtc(self, data): # TODO: figure out networking issue here if self.pc.is_data_channel_open(): try: @@ -455,8 +481,6 @@ async def send_to_web_rtc(self, data, connection_no=0): f"{self.pc.data_channel.bufferedAmount}{bcolors.ENDC}") await asyncio.sleep(sleep_count) sleep_count += .01 - self.logger.debug(f'Endpoint {self.endpoint_name}: buffer size: {self.pc.data_channel.bufferedAmount}' + - f', time since start: {datetime.now() - self.connection_time[connection_no]["start_time"]}' if connection_no > 0 else '') self.pc.send_message(data) # Yield control back to the event loop for other tasks to execute await asyncio.sleep(0) @@ -498,18 +522,44 @@ async def process_control_message(self, message_no, data): # type: (ControlMess f'{target_connection_no}') await self.close_connection(target_connection_no) elif message_no == ControlMessage.Pong: - self.logger.debug(f'Endpoint {self.endpoint_name}: Received pong request') self._ping_attempt = 0 self.is_connected = True - if self.ping_time is not None: - time_now = time.perf_counter() - # from the time the ping was sent to the time the pong was received - latency = time_now - self.ping_time - self.logger.debug(f'Endpoint {self.endpoint_name}: Round trip latency: {latency} ms') - self.ping_time = None + if len(data) >= 0: + con_no = bytes_to_int(data) + if con_no in self.connections: + self.connections[con_no].message_counter = 0 + self.logger.debug(f'Endpoint {self.endpoint_name}: Received pong request') + if con_no != 0: + self.logger.debug(f'Endpoint {self.endpoint_name}: Received ACK for {con_no}') + if self.connections[con_no].ping_time is not None: + time_now = time.perf_counter() + # from the time the ping was sent to the time the pong was received + latency = time_now - self.connections[con_no].ping_time + self.logger.debug(f'Endpoint {self.endpoint_name}: Round trip latency: {latency} ms') + self.connections[con_no].ping_time = None + else: + self.logger.debug(f'Endpoint {self.endpoint_name}: Received pong request') + if self.connections[0].ping_time is not None: + time_now = time.perf_counter() + # from the time the ping was sent to the time the pong was received + latency = time_now - self.connections[0].ping_time + self.logger.debug(f'Endpoint {self.endpoint_name}: Round trip latency: {latency} ms') + self.connections[0].ping_time = None + elif message_no == ControlMessage.Ping: - self.logger.debug(f'Endpoint {self.endpoint_name}: Received ping request') - await self.send_control_message(ControlMessage.Pong) + if len(data) >= 0: + con_no = bytes_to_int(data) + if con_no in self.connections: + + if con_no == 0: + self.logger.debug(f'Endpoint {self.endpoint_name}: Received ping request') + else: + self.logger.debug(f'Endpoint {self.endpoint_name}: Received Ping for {con_no}') + await self.send_control_message(ControlMessage.Pong, int_to_bytes(con_no)) + else: + self.logger.error(f'Endpoint {self.endpoint_name}: Connection {con_no} not found') + else: + self.logger.error(f'Endpoint {self.endpoint_name}: Connection not found') elif message_no == ControlMessage.ConnectionOpened: if len(data) >= CONNECTION_NO_LENGTH: if len(data) > CONNECTION_NO_LENGTH: @@ -519,14 +569,14 @@ async def process_control_message(self, message_no, data): # type: (ControlMess self.logger.debug(f"Endpoint {self.endpoint_name}: Starting reader for connection " f"{connection_no}") try: - self.to_tunnel_tasks[connection_no] = asyncio.create_task( + self.connections[connection_no].to_tunnel_task = asyncio.create_task( self.forward_data_to_tunnel(connection_no)) # From current connection to WebRTC connection self.logger.debug( f"Endpoint {self.endpoint_name}: Started reader for connection {connection_no}") except ConnectionNotFoundException as e: self.logger.error(f"Endpoint {self.endpoint_name}: Connection {connection_no} not found: {e}") except Exception as e: - self.logger.error(f"Endpoint {self.endpoint_name}: Error while forwarding data: {e}") + self.logger.error(f"Endpoint {self.endpoint_name}: Error in forwarding data task: {e}") else: self.logger.error(f"Endpoint {self.endpoint_name}: Invalid open connection message") else: @@ -548,11 +598,11 @@ async def forward_data_to_local(self): break while len(buff) >= CONNECTION_NO_LENGTH + DATA_LENGTH: connection_no = int.from_bytes(buff[:CONNECTION_NO_LENGTH], byteorder='big') - length = int.from_bytes(buff[CONNECTION_NO_LENGTH:CONNECTION_NO_LENGTH+DATA_LENGTH], + length = int.from_bytes(buff[CONNECTION_NO_LENGTH:CONNECTION_NO_LENGTH + DATA_LENGTH], byteorder='big') if len(buff) >= CONNECTION_NO_LENGTH + DATA_LENGTH + length + len(TERMINATOR): if buff[CONNECTION_NO_LENGTH + DATA_LENGTH + length: - CONNECTION_NO_LENGTH + DATA_LENGTH + length + len(TERMINATOR)] != TERMINATOR: + CONNECTION_NO_LENGTH + DATA_LENGTH + length + len(TERMINATOR)] != TERMINATOR: self.logger.warning(f'Endpoint {self.endpoint_name}: Invalid terminator') # if we don't have a valid terminator then we don't know where the message ends or begins self.kill_server_event.set() @@ -574,12 +624,11 @@ async def forward_data_to_local(self): f"{connection_no}") continue - _, con_writer = self.connections[connection_no] try: self.logger.debug(f"Endpoint {self.endpoint_name}: Forwarding data to " f"local for connection {connection_no} ({len(send_data)})") - con_writer.write(send_data) - await con_writer.drain() + self.connections[connection_no].writer.write(send_data) + await self.connections[connection_no].writer.drain() # Yield control back to the event loop for other tasks to execute await asyncio.sleep(0) except Exception as ex: @@ -606,8 +655,7 @@ async def forward_data_to_local(self): raise et self.logger.debug(f'Endpoint {self.endpoint_name}: Tunnel reader timed out') self.logger.debug(f'Endpoint {self.endpoint_name}: Send ping request') - self.ping_time = time.perf_counter() - await self.send_control_message(ControlMessage.Ping) + await self.send_control_message(ControlMessage.Ping, ) self._ping_attempt += 1 continue self.pc.web_rtc_queue.task_done() @@ -637,7 +685,7 @@ async def forward_data_to_local(self): self.logger.debug(f"Endpoint {self.endpoint_name}: Closing tunnel") await self.stop_server() - async def start_reader(self): # type: () -> None + async def start_reader(self): # type: () -> None """ Transfer data from WebRTC connection to local connections. """ @@ -647,8 +695,7 @@ async def start_reader(self): # type: () -> None self.to_local_task = asyncio.create_task(self.forward_data_to_local()) # Send hello world open connection message - self.ping_time = time.perf_counter() - await self.send_control_message(ControlMessage.Ping) + await self.send_control_message(ControlMessage.Ping, ) self.logger.debug(f"Endpoint {self.endpoint_name}: Sent ping message to WebRTC connection") except asyncio.CancelledError: pass @@ -670,23 +717,39 @@ async def forward_data_to_tunnel(self, con_no): c = self.connections.get(con_no) if c is None or not self.is_connected: break - reader, _ = c try: - data = await reader.read(BUFFER_TRUNCATION_THRESHOLD) + data = await c.reader.read(BUFFER_TRUNCATION_THRESHOLD) self.logger.debug(f"Endpoint {self.endpoint_name}: Forwarding {len(data)} " f"bytes to tunnel for connection {con_no}") if not data: self.logger.debug(f"Endpoint {self.endpoint_name}: Connection {con_no} no data") break if isinstance(data, bytes): - if reader.at_eof() and len(data) == 0: + if c.reader.at_eof() and len(data) == 0: # Yield control back to the event loop for other tasks to execute await asyncio.sleep(0) continue else: buffer = int.to_bytes(con_no, CONNECTION_NO_LENGTH, byteorder='big') buffer += int.to_bytes(len(data), DATA_LENGTH, byteorder='big') + data + TERMINATOR - await self.send_to_web_rtc(buffer, con_no) + await self.send_to_web_rtc(buffer) + + self.logger.debug( + f'Endpoint {self.endpoint_name}: buffer size: {self.pc.data_channel.bufferedAmount}' + + f', time since start: {datetime.now() - c.start_time}') + + c.message_counter += 1 + if c.message_counter >= MESSAGE_MAX and self.pc.data_channel.bufferedAmount > BUFFER_TRUNCATION_THRESHOLD: + c.ping_time = time.perf_counter() + await self.send_control_message(ControlMessage.Ping, int_to_bytes(con_no)) + self._ping_attempt += 1 + wait_count = 0 + while c.message_counter >= MESSAGE_MAX: + await asyncio.sleep(wait_count) + wait_count += .1 + elif c.message_counter >= MESSAGE_MAX and self.pc.data_channel.bufferedAmount <= BUFFER_TRUNCATION_THRESHOLD: + c.message_counter = 0 + else: # Yield control back to the event loop for other tasks to execute await asyncio.sleep(0) @@ -711,9 +774,7 @@ async def handle_connection(self, reader, writer): # type: (asyncio.StreamReade """ connection_no = self.connection_no self.connection_no += 1 - self.connections[connection_no] = (reader, writer) - self.connection_time[connection_no] = {"start_time": datetime.now()} - + self.connections[connection_no] = ConnectionInfo(reader, writer, 0, None, None, datetime.now()) self.logger.debug(f"Endpoint {self.endpoint_name}: Created local connection {connection_no}") # Send open connection message with con_no. this is required to be sent to start the connection @@ -800,28 +861,22 @@ async def close_connection(self, connection_no): except Exception as ex: self.logger.warning(f'Endpoint {self.endpoint_name}: hit exception sending Close connection {ex}') - if connection_no in self.connections: - reader, writer = self.connections[connection_no] - writer.close() + if connection_no in self.connections and connection_no != 0: + self.connections[connection_no].writer.close() # Wait for it to actually close. try: - await asyncio.wait_for(writer.wait_closed(), timeout=5.0) + await asyncio.wait_for(self.connections[connection_no].writer.wait_closed(), timeout=5.0) except asyncio.TimeoutError: self.logger.warning( f"Endpoint {self.endpoint_name}: Timed out while trying to close connection " f"{connection_no}") if connection_no in self.connections: + try: + self.connections[connection_no].to_tunnel_task.cancel() + except Exception as ex: + self.logger.warning(f'Endpoint {self.endpoint_name}: hit exception canceling tasks {ex}') del self.connections[connection_no] self.logger.info(f"Endpoint {self.endpoint_name}: Closed connection {connection_no}") else: self.logger.info(f"Endpoint {self.endpoint_name}: Connection {connection_no} not found") - if connection_no in self.to_tunnel_tasks: - try: - self.to_tunnel_tasks[connection_no].cancel() - except Exception as ex: - self.logger.warning(f'Endpoint {self.endpoint_name}: hit exception canceling tasks {ex}') - del self.to_tunnel_tasks[connection_no] - self.logger.info(f"Endpoint {self.endpoint_name}: Tasks closed for connection {connection_no}") - else: - self.logger.info(f"Endpoint {self.endpoint_name}: Tasks for {connection_no} not found") diff --git a/unit-tests/pam-tunnel/test_private_tunnel.py b/unit-tests/pam-tunnel/test_private_tunnel.py index 4bfba7bbc..477b78726 100644 --- a/unit-tests/pam-tunnel/test_private_tunnel.py +++ b/unit-tests/pam-tunnel/test_private_tunnel.py @@ -1,6 +1,7 @@ import sys import unittest from unittest import mock +from unittest.mock import call if sys.version_info >= (3, 8): import asyncio @@ -13,7 +14,8 @@ from keepercommander.commands.tunnel.port_forward.endpoint import (TunnelEntrance, ControlMessage, CONTROL_MESSAGE_NO_LENGTH, CONNECTION_NO_LENGTH, ConnectionNotFoundException, - TERMINATOR, DATA_LENGTH, WebRTCConnection) + TERMINATOR, DATA_LENGTH, WebRTCConnection, + ConnectionInfo) from test_pam_tunnel import new_private_key # Only define the class if Python version is 3.8 or higher @@ -74,7 +76,8 @@ async def test_send_control_message(self): expected_data += optional_data + TERMINATOR # Assertions - mock_send_message.assert_called_once_with(expected_data) + + self.assertTrue(call(expected_data) in mock_send_message.call_args_list) async def test_send_control_message_with_error(self): # Initialize self.pte.tls_writer with a mock object @@ -95,12 +98,12 @@ async def test_send_control_message_with_error(self): expected_error_message = (f"Endpoint {self.pte.endpoint_name}: Error sending message: Mocked Exception") # Assertions - self.pte.logger.error.assert_called_once_with(expected_error_message) + self.pte.logger.error.assert_called_with(expected_error_message) async def test_forward_data_to_local_normal(self): await self.set_queue_side_effect() connection = (None, mock.MagicMock(spec=asyncio.StreamWriter)) - self.pte.connections = {1: connection} + self.pte.connections[1] = ConnectionInfo(mock.AsyncMock(), mock.AsyncMock(), 0, None, None, datetime.now()) self.pte.logger = mock.MagicMock() self.pte.kill_server_event = mock.MagicMock(spec=asyncio.Event) self.pte.stop_server = mock.AsyncMock() @@ -110,19 +113,19 @@ async def test_forward_data_to_local_normal(self): mock_close.side_effect = mock.MagicMock(spec=asyncio.Task) await self.pte.forward_data_to_local() - self.assertTrue(len(self.pte.connections) == 1) - self.pte.connections[1][1].write.assert_called_with(b'some_data') - self.pte.connections[1][1].drain.assert_called_once() + self.assertTrue(len(self.pte.connections) == 2) + self.pte.connections[1].writer.write.assert_called_with(b'some_data') + self.pte.connections[1].writer.drain.assert_called_once() async def test_forward_data_to_local_error(self): await self.set_queue_side_effect() connection = (None, mock.MagicMock(spec=asyncio.StreamWriter)) - self.pte.connections = {1: connection} + self.pte.connections[1] = ConnectionInfo(mock.AsyncMock(), mock.AsyncMock(), 0, None, None, datetime.now()) self.pte.logger = mock.MagicMock() self.pte.kill_server_event = mock.MagicMock(spec=asyncio.Event) self.pte.kill_server_event.is_set.side_effect = [False, False, False, False, True, True] self.pte.pc.closed = False - self.pte.connections[1][1].write.side_effect = Exception("Some error") + self.pte.connections[1].writer.write.side_effect = Exception("Some error") await self.pte.forward_data_to_local() @@ -149,7 +152,7 @@ async def test_process_ping_message(self): self.pte.logger = mock.MagicMock() await self.pte.process_control_message(ControlMessage.Ping, b'') self.pte.logger.debug.assert_called_with('Endpoint TestEndpoint: Received ping request') - mock_send.assert_called_with(ControlMessage.Pong) + mock_send.assert_called_with(ControlMessage.Pong, b'\x00') async def test_start_server(self): with mock.patch('asyncio.start_server', new_callable=mock.AsyncMock) as mock_open_connection, \ @@ -233,22 +236,19 @@ async def read_side_effect(*args, **kwargs): mock_reader = mock.AsyncMock(spec=asyncio.StreamReader) mock_reader.read.side_effect = read_side_effect - self.pte.connections[1] = (mock_reader, mock.AsyncMock(spec=asyncio.StreamWriter)) - + self.pte.connections[1] = ConnectionInfo(mock_reader, mock.AsyncMock(spec=asyncio.StreamWriter), 0, None, None, datetime.now()) self.pte.kill_server_event = mock.MagicMock(spec=asyncio.Event) self.pte.kill_server_event.is_set.side_effect = [False, False, True] self.pte.pc = mock.MagicMock(spec=WebRTCConnection) self.pte.pc.data_channel = mock.MagicMock(spec=RTCDataChannel) self.pte.pc.data_channel.readyState = 'open' self.pte.pc.data_channel.bufferedAmount = 0 - self.pte.connection_time[1] = {"start_time": datetime.now()} # Run the task and wait for it to complete task = asyncio.create_task(self.pte.forward_data_to_tunnel(1)) await asyncio.sleep(.01) # Give some time for the task to run task.cancel() # Cancel the task to stop it from running indefinitely - - self.pte.pc.send_message.assert_called_with(b'\x00\x00\x00\x01\x00\x00\x00\x0bhello world;') + self.assertTrue(call(b'\x00\x00\x00\x01\x00\x00\x00\x0bhello world;') in self.pte.pc.send_message.call_args_list) # Test Connection Not Found async def test_forward_data_to_tunnel_no_connection(self): @@ -267,7 +267,7 @@ async def read_side_effect(*args, **kwargs): mock_reader = mock.AsyncMock(spec=asyncio.StreamReader) mock_reader.read.side_effect = read_side_effect mock_writer = mock.AsyncMock(spec=asyncio.StreamWriter) - self.pte.connections[1] = (mock_reader, mock_writer) + self.pte.connections[1] = ConnectionInfo(mock.AsyncMock(), mock.AsyncMock(), 0, None, None, datetime.now()) # Mock send_control_message method with mock.patch.object(self.pte, 'send_control_message', @@ -289,7 +289,7 @@ async def read_side_effect(*args, **kwargs): mock_reader = mock.AsyncMock(spec=asyncio.StreamReader) mock_reader.read.side_effect = read_side_effect mock_writer = mock.AsyncMock(spec=asyncio.StreamWriter) - self.pte.connections[1] = (mock_reader, mock_writer) + self.pte.connections[1] = ConnectionInfo(mock.AsyncMock(), mock.AsyncMock(), 0, None, None, datetime.now()) # Mock send_control_message method with mock.patch.object(self.pte, 'send_control_message', new_callable=mock.AsyncMock) as mock_send_control_message: @@ -326,10 +326,10 @@ async def test_handle_connection_exception(self): # Test stop_server async def test_stop_server(self): - self.pte.connections = {1: (mock.AsyncMock(), mock.AsyncMock())} + self.pte.connections = {1: ConnectionInfo(mock.AsyncMock(), mock.AsyncMock(), 0, None, None, datetime.now())} self.pte.kill_server_event = mock.MagicMock(spec=asyncio.Event) - with mock.patch.object(self.pte.connections[1][1], 'close', new_callable=mock.AsyncMock) as mock_close, \ - mock.patch.object(self.pte.connections[1][1], 'wait_closed', new_callable=mock.AsyncMock) as mock_wait_closed: + with mock.patch.object(self.pte.connections[1].writer, 'close', new_callable=mock.AsyncMock) as mock_close, \ + mock.patch.object(self.pte.connections[1].writer, 'wait_closed', new_callable=mock.AsyncMock) as mock_wait_closed: await self.pte.stop_server() mock_close.assert_called() mock_wait_closed.assert_called() @@ -339,9 +339,9 @@ async def test_stop_server(self): # Test stop_server with Exception async def test_stop_server_exception(self): - self.pte.connections = {1: (mock.AsyncMock(), mock.AsyncMock())} + self.pte.connections = {1: ConnectionInfo(mock.AsyncMock(), mock.AsyncMock(), 0, None, None, datetime.now())} self.pte.kill_server_event = mock.MagicMock(spec=asyncio.Event) - with mock.patch.object(self.pte.connections[1][1], 'close', side_effect=Exception("Test Exception")): + with mock.patch.object(self.pte.connections[1].writer, 'close', side_effect=Exception("Test Exception")): await self.pte.close_connection(1) self.assertTrue(self.pte.connections == {}) self.assertTrue(self.pte.server is None) @@ -349,13 +349,13 @@ async def test_stop_server_exception(self): # Test close_connection async def test_close_connection(self): - self.pte.connections[1] = (mock.AsyncMock(), mock.AsyncMock()) + self.pte.connections[1] = ConnectionInfo(mock.AsyncMock(), mock.AsyncMock(), 0, None, None, datetime.now()) await self.pte.close_connection(1) - self.assertNotIn(1, self.pte.connections) + self.assertNotIn(1, self.pte.connections.keys()) # Test close_connection with Connection Not Found async def test_close_connection_not_found(self): await self.pte.close_connection(9999) # 9999 is not in self.connections # Check if logger.info was called - self.pte.logger.info.assert_called_with("Endpoint TestEndpoint: Tasks for 9999 not found") + self.pte.logger.info.assert_called_with("Endpoint TestEndpoint: Connection 9999 not found")