diff --git a/.gitignore b/.gitignore index b21811d7b..6b3b79460 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,5 @@ keeper.txt *.csv Makefile *.db -dr-logs \ No newline at end of file +dr-logs +/.venv* diff --git a/keepercommander/commands/discoveryrotation.py b/keepercommander/commands/discoveryrotation.py index 2046cc044..a4c1b2cd3 100644 --- a/keepercommander/commands/discoveryrotation.py +++ b/keepercommander/commands/discoveryrotation.py @@ -14,7 +14,6 @@ import logging import os.path import queue -import socket import sys import threading import time @@ -38,13 +37,12 @@ from .pam.pam_dto import GatewayActionGatewayInfo, GatewayActionDiscoverInputs, GatewayActionDiscover, \ GatewayActionRotate, \ GatewayActionRotateInputs, GatewayAction, GatewayActionJobInfoInputs, \ - GatewayActionJobInfo, GatewayActionJobCancel, GatewayActionWebRTCSession + GatewayActionJobInfo, GatewayActionJobCancel from .pam.router_helper import router_send_action_to_gateway, print_router_response, \ router_get_connected_gateways, router_set_record_rotation_information, router_get_rotation_schedules, \ - get_router_url, router_get_relay_access_creds + get_router_url from .record_edit import RecordEditMixin -from .tunnel.port_forward.endpoint import establish_symmetric_key, tunnel_encrypt, WebRTCConnection, tunnel_decrypt, \ - TunnelEntrance +from .tunnel.port_forward.endpoint import establish_symmetric_key, WebRTCConnection, TunnelEntrance, READ_TIMEOUT from .. import api, utils, vault_extensions, vault, record_management, attachment, record_facades from ..display import bcolors from ..error import CommandError @@ -1577,30 +1575,37 @@ def gather_tabel_row_data(thread): minutes = 0 seconds = 0 - if thread.get('started'): - run_time = datetime.now() - thread.get('started') - hours, remainder = divmod(run_time.seconds, 3600) - minutes, seconds = divmod(remainder, 60) + entrance = thread.get('entrance') # # row.append(f"{thread.get('name', '')}") row.append(f"{bcolors.OKBLUE}{thread.get('convo_id', '')}{bcolors.ENDC}") - row.append(f"{thread.get('host')}" if thread.get('host') else '') + row.append(f"{thread.get('host', '')}") - if not thread.get('entrance'): + if entrance is not None and entrance.print_ready_event.is_set(): + if thread.get('started'): + run_time = datetime.now() - thread.get('started') + hours, remainder = divmod(run_time.seconds, 3600) + minutes, seconds = divmod(remainder, 60) + + row.append( + f"{bcolors.OKBLUE}{entrance._port}{bcolors.ENDC}" + ) + else: row.append(f"{bcolors.WARNING}Connecting...{bcolors.ENDC}") + row.append(f"{thread.get('record_uid', '')}") + if entrance is not None and entrance.print_ready_event.is_set(): + text_line = "" + if run_time: + if run_time.days == 1: + text_line += f"{run_time.days} day " + elif run_time.days > 1: + text_line += f"{run_time.days} days " + text_line += f"{hours} hr " if hours > 0 or run_time.days > 0 else '' + text_line += f"{minutes} min " + text_line += f"{seconds} sec" + row.append(text_line) else: - row.append(f"{bcolors.OKBLUE}{thread.get('entrance')._port}{bcolors.ENDC}" if thread.get('entrance') else '') - row.append(f"{thread.get('record_uid')}" if thread.get('record_uid') else '') - text_line = "" - if run_time: - if run_time.days == 1: - text_line += f"{run_time.days} day " - elif run_time.days > 1: - text_line += f"{run_time.days} days " - text_line += f"{hours} hr " if hours > 0 or run_time.days > 0 else '' - text_line += f"{minutes} min " - text_line += f"{seconds} sec" - row.append(text_line) + row.append(f"{bcolors.WARNING}Connecting...{bcolors.ENDC}") return row if not params.tunnel_threads: @@ -1618,6 +1623,26 @@ def gather_tabel_row_data(thread): dump_report_data(table, headers, fmt='table', filename="", row_number=False, column_width=None) +def clean_up_tunnel(params, convo_id): + tunnel_data = params.tunnel_threads.get(convo_id) + if tunnel_data: + kill_server_event = tunnel_data.get("kill_server_event") + if kill_server_event: + if not kill_server_event.is_set(): + kill_server_event.set() + # whatever the read timeout is, wait for 2 seconds more + time.sleep(READ_TIMEOUT + 2) + p = tunnel_data.get("process", None) + if p and p.is_alive(): + p.join() + if params.tunnel_threads.get(convo_id): + del params.tunnel_threads[convo_id] + if params.tunnel_threads_queue.get(convo_id): + del params.tunnel_threads_queue[convo_id] + else: + print(f"{bcolors.WARNING}No tunnel data found to remove for {convo_id}{bcolors.ENDC}") + + class PAMTunnelStopCommand(Command): pam_cmd_parser = argparse.ArgumentParser(prog='dr-tunnel-stop-command') pam_cmd_parser.add_argument('uid', type=str, action='store', help='The Tunnel UID') @@ -1632,16 +1657,8 @@ def execute(self, params, **kwargs): tunnel_data = params.tunnel_threads.get(convo_id, None) if not tunnel_data: - raise CommandError('tunnel stop', f"No tunnel data found for {convo_id}") - - connect_task = tunnel_data.get("connect_task", None) - if connect_task: - connect_task.cancel() - count = 0 - while params.tunnel_threads.get(convo_id) and count < 10: - count += .1 - time.sleep(.1) - + raise CommandError('tunnel stop', f"No tunnel data to remove found for {convo_id}") + clean_up_tunnel(params, convo_id) return @@ -1656,6 +1673,8 @@ def execute(self, params, **kwargs): convo_id = kwargs.get('uid') if not convo_id: raise CommandError('tunnel tail', '"uid" argument is required') + if convo_id not in params.tunnel_threads: + raise CommandError('tunnel tail', f"Tunnel UID {convo_id} not found") log_queue = params.tunnel_threads_queue.get(convo_id) @@ -1791,90 +1810,58 @@ async def connect(self, params, record_uid, convo_id, gateway_uid, host, port, # Get symmetric key symmetric_key = establish_symmetric_key(client_private_key_pem, gateway_public_key) - response = router_get_relay_access_creds(params=params) - # Set up the pc print_ready_event = asyncio.Event() - pc = WebRTCConnection(endpoint_name=convo_id, print_ready_event=print_ready_event, - username=response.username, password=response.password, logger=logger) + kill_server_event = asyncio.Event() + pc = WebRTCConnection(endpoint_name=convo_id, params=params, record_uid=record_uid, gateway_uid=gateway_uid, + symmetric_key=symmetric_key, print_ready_event=print_ready_event, + kill_server_event=kill_server_event, logger=logger) - # make webRTC sdp offer - try: - offer = await pc.make_offer() - except socket.gaierror: - print(f"{bcolors.WARNING}Please upgrade Commander to the latest version to use this feature...{bcolors.ENDC}") - return - except Exception as e: - raise CommandError('tunnel start', f'Error making WebRTC offer: {e}') - encrypted_offer = tunnel_encrypt(symmetric_key, offer) - logger.debug("-->. SEND START MESSAGE OVER REST TO GATEWAY") - - ''' - 'inputs': { - 'conversationType': ['tunnel', 'guacd'] - 'kind': ['start', 'disconnect'], - 'recordUid': record_uid, <-- this is the record UID of the PAM resource record - with Network information - 'listenerName': NAME OF LISTENER, <-- Used in logging (not required) - 'offer': encrypted_WebRTC_sdp_offer, <-- WebRTC SDP offer encrypted with symmetric key - 'allow_control': True, <-- only for guacd, False = readonly session (default True) - 'guacamole_client_id: guacamole_client_id, <-- only for guacd, Connect to an existing guacd session - 'userRecordUid': userRecordUid, <-- only for guacd, User record UID to connect for session - } - ''' - # TODO create objects for WebRTC inputs - router_response = router_send_action_to_gateway( - params=params, - gateway_action=GatewayActionWebRTCSession(inputs={'listenerName': convo_id, "recordUid": record_uid, - "offer": encrypted_offer, 'kind': 'start', - 'conversationType': 'tunnel'}), - message_type=pam_pb2.CMT_GENERAL, - is_streaming=False, - destination_gateway_uid_str=gateway_uid, - gateway_timeout=30000 - ) - if not router_response: - return - gateway_response = router_response.get('response', {}) - if not gateway_response: - raise Exception(f"Error getting response from the Gateway: {router_response}") try: - payload = json.loads(gateway_response.get('payload', None)) - if not payload: - raise Exception(f"Error getting payload from the Gateway response: {gateway_response}") + await pc.signal_channel('start') except Exception as e: - raise Exception(f"Error getting payload from the Gateway response: {e}") - - if payload.get('is_ok', False) is False or payload.get('progress_status') == 'Error': - raise Exception(f"Error getting payload from the Gateway response: {payload.get('data')}") - - encrypted_answer = payload.get('data', None) - if not encrypted_answer: - raise Exception(f"Error getting data from the Gateway response payload: {payload}") - - # decrypt the sdp answer - answer = tunnel_decrypt(symmetric_key, encrypted_answer) - await pc.accept_answer(answer) + CommandError('tunnel start', f"{e}") logger.debug("starting private tunnel") private_tunnel = TunnelEntrance(host=host, port=port, endpoint_name=convo_id, pc=pc, print_ready_event=print_ready_event, logger=logger, - connect_task = params.tunnel_threads[convo_id].get("connect_task", None)) + connect_task=params.tunnel_threads[convo_id].get("connect_task", None), + kill_server_event=kill_server_event) t1 = asyncio.create_task(private_tunnel.start_server()) - params.tunnel_threads[convo_id].update({"server": t1, "entrance": private_tunnel}) + params.tunnel_threads[convo_id].update({"server": t1, "entrance": private_tunnel, + "kill_server_event": kill_server_event}) logger.debug("--> START LISTENING FOR MESSAGES FROM GATEWAY --------") - await asyncio.gather(t1, private_tunnel.reader_task) + try: + await asyncio.gather(t1, private_tunnel.reader_task) + except asyncio.CancelledError: + pass + finally: + logger.debug("--> STOP LISTENING FOR MESSAGES FROM GATEWAY --------") def pre_connect(self, params, record_uid, convo_id, gateway_uid, host, port, gateway_public_key_bytes, client_private_key): + + def custom_exception_handler(loop, context): + # Check if the exception is present in the context + if "exception" in context: + exception = context["exception"] + if isinstance(exception, ConnectionError): + # Handle only ConnectionError + logging.debug(f"Caught ConnectionError in asyncio: {exception}") + else: + # Log the default message if no exception is found + logging.error(context["message"]) + loop = None try: + # Create a new asyncio event loop and set the custom exception handler loop = asyncio.new_event_loop() params.tunnel_threads[convo_id].update({"loop": loop}) asyncio.set_event_loop(loop) + loop.set_exception_handler(custom_exception_handler) output_queue = queue.Queue(maxsize=500) params.tunnel_threads_queue[convo_id] = output_queue # Create a Task from the coroutine @@ -1899,6 +1886,8 @@ def pre_connect(self, params, record_uid, convo_id, gateway_uid, host, port, pass except SocketNotConnectedException as es: print(f"{bcolors.FAIL}Socket not connected exception in connection {convo_id}: {es}{bcolors.ENDC}") + except KeyboardInterrupt: + print(f"{bcolors.OKBLUE}Exiting: {convo_id}{bcolors.ENDC}") except Exception as e: print(f"{bcolors.FAIL}An exception occurred in pre_connect for connection {convo_id}: {e}{bcolors.ENDC}") finally: @@ -1932,13 +1921,14 @@ def pre_connect(self, params, record_uid, convo_id, gateway_uid, host, port, print(f"{bcolors.OKBLUE}Tunnel {convo_id} closed.{bcolors.ENDC}") def execute(self, params, **kwargs): - version = [3, 11, 0] - # Check for Python 3.11+ + # https://pypi.org/project/aiortc/ + # aiortc Requires: Python >=3.8 + version = [3, 8, 0] major_version = sys.version_info.major minor_version = sys.version_info.minor micro_version = sys.version_info.micro - if (major_version, minor_version, micro_version) <= (version[0], version[1], version[2]): + if (major_version, minor_version, micro_version) < (version[0], version[1], version[2]): print(f"{bcolors.FAIL}This code requires Python {version[0]}.{version[1]}.{version[2]} or higher. " f"You are using {major_version}.{minor_version}.{micro_version}.{bcolors.ENDC}") return @@ -1952,6 +1942,10 @@ def execute(self, params, **kwargs): gateway_public_key_bytes = retrieve_gateway_public_key(gateway_uid, params, api, utils) + if not gateway_public_key_bytes: + print(f"{bcolors.FAIL}Could not retrieve public key for gateway {gateway_uid}{bcolors.ENDC}") + return + api.sync_down(params) record = vault.KeeperRecord.load(params, record_uid) if not isinstance(record, vault.TypedRecord): @@ -1991,22 +1985,73 @@ def execute(self, params, **kwargs): t.daemon = True t.start() - params.tunnel_threads[convo_id].update({"convo_id": convo_id, "thread": t, "host": host, "port": port, - "started": datetime.now(), "record_uid": record_uid}) + if not params.tunnel_threads.get(convo_id): + params.tunnel_threads[convo_id] = {"convo_id": convo_id, "thread": t, "host": host, "port": port, + "started": datetime.now(), "record_uid": record_uid} + else: + params.tunnel_threads[convo_id].update({"convo_id": convo_id, "thread": t, "host": host, "port": port, + "started": datetime.now(), "record_uid": record_uid}) count = 0 wait_time = 120 + entrance = None + # run once + while count == 0: + while count < wait_time: + if params.tunnel_threads.get(convo_id): + entrance = params.tunnel_threads[convo_id].get("entrance", None) + if entrance: + break + else: + break + count += .1 + time.sleep(.1) + + def print_fail(): + fail_dynamic_length = len("| Endpoint : failed to start..") + len(convo_id) + + clean_up_tunnel(params, convo_id) + time.sleep(.5) + # Dashed line adjusted to the length of the middle line + fail_dashed_line = '+' + '-' * fail_dynamic_length + '+' + print(f'\n{bcolors.FAIL}{fail_dashed_line}{bcolors.ENDC}') + print(f'{bcolors.FAIL}| Endpoint {convo_id}{bcolors.ENDC} failed to start..') + print(f'{bcolors.FAIL}{fail_dashed_line}{bcolors.ENDC}\n') + + if entrance is not None: + while not entrance.print_ready_event.is_set() and count < wait_time * 2: + count += .1 + time.sleep(.1) + + if entrance.print_ready_event.is_set(): + # Sleep a little bit to print out last + time.sleep(.5) + host = host + ":" if host else '' + # Total length of the dynamic parts (endpoint name, host, and port) + dynamic_length = \ + (len("| Endpoint : Listening on port: ") + len(convo_id) + len(host) + len(str(port))) + + # Dashed line adjusted to the length of the middle line + dashed_line = '+' + '-' * dynamic_length + '+' + + # Print statements + print(f'\n{bcolors.OKGREEN}{dashed_line}{bcolors.ENDC}') + print( + f'{bcolors.OKGREEN}| Endpoint {bcolors.ENDC}{bcolors.OKBLUE}{convo_id}{bcolors.ENDC}' + f'{bcolors.OKGREEN}: Listening on port: {bcolors.ENDC}' + f'{bcolors.BOLD}{bcolors.OKBLUE}{host}{port}{bcolors.ENDC}{bcolors.OKGREEN} |{bcolors.ENDC}') + print(f'{bcolors.OKGREEN}{dashed_line}{bcolors.ENDC}') + print( + f'{bcolors.OKGREEN}View all open tunnels : {bcolors.ENDC}{bcolors.OKBLUE}pam tunnel list{bcolors.ENDC}') + print(f'{bcolors.OKGREEN}Tail logs on open tunnel: {bcolors.ENDC}' + f'{bcolors.OKBLUE}pam tunnel tail ' + + (f'-- ' if convo_id[0] == '-' else '') + + f'{convo_id}{bcolors.ENDC}') + print(f'{bcolors.OKGREEN}Stop a tunnel : {bcolors.ENDC}' + f'{bcolors.OKBLUE}pam tunnel stop ' + + (f'-- ' if convo_id[0] == '-' else '') + + f'{convo_id}{bcolors.ENDC}\n') + else: + print_fail() + else: + print_fail() - while count < wait_time and not params.tunnel_threads[convo_id].get("entrance"): - count += .1 - time.sleep(.1) - - if count >= wait_time: - # There could be an error that happened in the thread return - return - - entrance = params.tunnel_threads[convo_id].get("entrance") - while not entrance.print_ready_event.is_set() and count < wait_time * 2: - count += .1 - time.sleep(.1) - # After it is set to True the print waits for .5 seconds then prints so we need to wait a little longer - time.sleep(1) diff --git a/keepercommander/commands/tunnel/port_forward/endpoint.py b/keepercommander/commands/tunnel/port_forward/endpoint.py index f3e7c19fd..751982bd4 100644 --- a/keepercommander/commands/tunnel/port_forward/endpoint.py +++ b/keepercommander/commands/tunnel/port_forward/endpoint.py @@ -1,5 +1,6 @@ import asyncio import enum +import json import logging import os import secrets @@ -7,6 +8,7 @@ import string import time from typing import Optional, Dict, Tuple +from datetime import datetime from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration, RTCIceServer from cryptography.hazmat.primitives import hashes @@ -14,21 +16,25 @@ from cryptography.hazmat.primitives.ciphers.aead import AESGCM from cryptography.hazmat.primitives.kdf.hkdf import HKDF from cryptography.utils import int_to_bytes -from keeper_secrets_manager_core.utils import bytes_to_base64, base64_to_bytes, bytes_to_string +from keeper_secrets_manager_core.utils import bytes_to_base64, base64_to_bytes, bytes_to_string, string_to_bytes +from keepercommander.commands.pam.pam_dto import GatewayActionWebRTCSession +from keepercommander.commands.pam.router_helper import router_get_relay_access_creds, router_send_action_to_gateway from keepercommander.display import bcolors +from keepercommander.params import KeeperParams +from keepercommander.proto import pam_pb2 logging.getLogger('aiortc').setLevel(logging.WARNING) logging.getLogger('aioice').setLevel(logging.WARNING) - -BUFFER_TRUNCATION_THRESHOLD = 1400 +#TODO add why 9 is the length of the protocol +BUFFER_TRUNCATION_THRESHOLD = 16000 - 9 # 16 Kbytes max https://viblast.com/blog/2015/2/5/webrtc-data-channel-message-size/ so we will use the max minus 9 bytes for the protocol READ_TIMEOUT = 10 CONTROL_MESSAGE_NO_LENGTH = 2 CONNECTION_NO_LENGTH = DATA_LENGTH = 4 -LATENCY_COUNT = 5 NONCE_LENGTH = 12 SYMMETRIC_KEY_LENGTH = RANDOM_LENGTH = 32 TERMINATOR = b';' +BUFFER_THRESHOLD = 134217728 * .90 # 16 MiB max https://viblast.com/blog/2015/2/25/webrtc-bufferedamount/ so we will use 14.4 MiB or 90% of the max, because in some cases if the max is reached the channel will close class ConnectionNotFoundException(Exception): @@ -135,14 +141,99 @@ def tunnel_decrypt(symmetric_key: AESGCM, encrypted_data: str): class WebRTCConnection: - def __init__(self, endpoint_name: Optional[str] = "Keeper PAM Tunnel", - print_ready_event: Optional[asyncio.Event] = None, username: Optional[str] = None, - password: Optional[str] = None, logger: Optional[logging.Logger] = None): + def __init__(self, endpoint_name: str, params: KeeperParams, record_uid, gateway_uid, symmetric_key, + print_ready_event: asyncio.Event, kill_server_event: asyncio.Event, + logger: Optional[logging.Logger] = None): + self._pc = None self.web_rtc_queue = asyncio.Queue() self.closed = False self.data_channel = None self.print_ready_event = print_ready_event + self.logger = logger + self.endpoint_name = endpoint_name + self.params = params + self.record_uid = record_uid + self.gateway_uid = gateway_uid + self.symmetric_key = symmetric_key + self.kill_server_event = kill_server_event + try: + self.peer_ice_config() + self.setup_data_channel() + self.setup_event_handlers() + except Exception as e: + raise Exception(f'Error setting up WebRTC connection: {e}') + + async def signal_channel(self, kind: str): + + # make webRTC sdp offer + try: + if kind == 'start': + offer = await self.make_offer() + else: + raise Exception(f'Invalid kind: {kind}') + except socket.gaierror: + print(f"{bcolors.WARNING}Please upgrade Commander to the latest version to use this feature...{bcolors.ENDC}") + return + except Exception as e: + raise Exception(f'Error making WebRTC offer: {e}') + data = {"offer": bytes_to_base64(offer), 'kind': kind, 'conversationType': 'tunnel'} + string_data = json.dumps(data) + bytes_data = string_to_bytes(string_data) + encrypted_data = tunnel_encrypt(self.symmetric_key, bytes_data) + self.logger.debug("-->. SEND START MESSAGE OVER REST TO GATEWAY") + ''' + 'inputs': { + 'recordUid': record_uid, <-- this is the record UID of the PAM resource record + with Network information (REQUIRED) + + 'data': { <-- All data is encrypted with symmetric key (REQUIRED) + 'conversationType': ['tunnel', 'guacd'] <-- What type of conversation is this + 'kind': ['start', 'disconnect'], <-- What command to run (REQUIRED) + 'offer': encrypted_WebRTC_sdp_offer, <-- WebRTC SDP offer, base64 encoded + 'allow_control': True, <-- only for guacd, False = readonly session + 'guacamole_client_id: guacamole_client_id, <-- only for guacd, an existing guacd session + 'userRecordUid': userRecordUid, <-- only for guacd, User record UID to connect with + 'conversations': [] <-- only for disconnect, list of conversations to close + } + } + ''' + # TODO create objects for WebRTC inputs + router_response = router_send_action_to_gateway( + params=self.params, + gateway_action=GatewayActionWebRTCSession(inputs={"recordUid": self.record_uid, "data": encrypted_data}), + message_type=pam_pb2.CMT_GENERAL, + is_streaming=False, + destination_gateway_uid_str=self.gateway_uid, + gateway_timeout=30000 + ) + if not router_response: + self.kill_server_event.set() + return + gateway_response = router_response.get('response', {}) + if not gateway_response: + raise Exception(f"Error getting response from the Gateway: {router_response}") + try: + payload = json.loads(gateway_response.get('payload', None)) + if not payload: + raise Exception(f"Error getting payload from the Gateway response: {gateway_response}") + except Exception as e: + raise Exception(f"Error getting payload from the Gateway response: {e}") + + if payload.get('is_ok', False) is False or payload.get('progress_status') == 'Error': + raise Exception(f"Error getting payload from the Gateway response: {payload.get('data')}") + + encrypted_answer = payload.get('data', None) + if not encrypted_answer: + raise Exception(f"Error getting data from the Gateway response payload: {payload}") + # decrypt the sdp answer + answer = tunnel_decrypt(self.symmetric_key, encrypted_answer) + await self.accept_answer(answer) + + self.logger.debug("starting private tunnel") + + def peer_ice_config(self): + response = router_get_relay_access_creds(params=self.params) # Define the STUN server URL # To use Google's STUN server ''' @@ -159,17 +250,13 @@ def __init__(self, endpoint_name: Optional[str] = "Keeper PAM Tunnel", # Create an RTCIceServer instance for the STUN server stun_server = RTCIceServer(urls=stun_url) # Define the TURN server URL and credentials - turn_url = f"turn:{relay_url}?transport=udp" + turn_url = f"turn:{relay_url}" # Create an RTCIceServer instance for the TURN server with credentials - turn_server = RTCIceServer(urls=turn_url, username=username, credential=password) + turn_server = RTCIceServer(urls=turn_url, username=response.username, credential=response.password) # Create a new RTCConfiguration with both STUN and TURN servers config = RTCConfiguration(iceServers=[stun_server, turn_server]) self._pc = RTCPeerConnection(config) - self.setup_data_channel() - self.setup_event_handlers() - self.logger = logger - self.endpoint_name = endpoint_name async def make_offer(self): offer = await self._pc.createOffer() @@ -250,7 +337,7 @@ async def close_webrtc_connection(self): # Close the peer connection if self._pc: await self._pc.close() - self.logger.error(f'Endpoint {self.endpoint_name}: "Peer connection closed') + self.logger.error(f'Endpoint {self.endpoint_name}: Peer connection closed') # Clear the asyncio queue if self.web_rtc_queue: @@ -333,6 +420,7 @@ def __init__(self, print_ready_event, # type: asyncio.Event logger = None, # type: logging.Logger connect_task = None, # type: asyncio.Task + kill_server_event = None # type: asyncio.Event ): # type: (...) -> None self.closing = False self.ping_time = None @@ -348,17 +436,34 @@ def __init__(self, self.is_connected = True self.reader_task = asyncio.create_task(self.start_reader()) self.to_tunnel_tasks = {} - self.kill_server_event = asyncio.Event() + self.kill_server_event = kill_server_event self.pc = pc self.print_ready_event = print_ready_event - self.server_task = None self.connect_task = connect_task + self.connection_time = {} - async def send_to_web_rtc(self, data): + async def send_to_web_rtc(self, data, connection_no=0): + # TODO: figure out networking issue here if self.pc.is_data_channel_open(): - self.pc.send_message(data) - # Yield control back to the event loop for other tasks to execute - await asyncio.sleep(0) + try: + sleep_count = 0 + while (self.pc.data_channel is not None and + self.pc.data_channel.bufferedAmount >= BUFFER_THRESHOLD and + not self.kill_server_event.is_set() and + self.pc.is_data_channel_open()): + self.logger.debug(f"{bcolors.WARNING}Buffered amount is too high ({sleep_count * 100}) " + f"{self.pc.data_channel.bufferedAmount}{bcolors.ENDC}") + await asyncio.sleep(sleep_count) + sleep_count += .01 + self.logger.debug(f'Endpoint {self.endpoint_name}: buffer size: {self.pc.data_channel.bufferedAmount}' + + f', time since start: {datetime.now() - self.connection_time[connection_no]["start_time"]}' if connection_no > 0 else '') + self.pc.send_message(data) + # Yield control back to the event loop for other tasks to execute + await asyncio.sleep(0) + + except Exception as e: + self.logger.error(f'Endpoint {self.endpoint_name}: Error sending message: {e}') + await asyncio.sleep(0.1) else: if self.print_ready_event.is_set(): self.logger.error(f'Endpoint {self.endpoint_name}: Data channel is not open. Data not sent.') @@ -515,7 +620,7 @@ async def forward_data_to_local(self): continue elif isinstance(data, bytes): self.logger.debug(f"Endpoint {self.endpoint_name}: Got data from WebRTC connection " - f"{len(data)} bytes)") + f"{len(data)} bytes") buff += data else: # Yield control back to the event loop for other tasks to execute @@ -581,7 +686,7 @@ async def forward_data_to_tunnel(self, con_no): else: buffer = int.to_bytes(con_no, CONNECTION_NO_LENGTH, byteorder='big') buffer += int.to_bytes(len(data), DATA_LENGTH, byteorder='big') + data + TERMINATOR - await self.send_to_web_rtc(buffer) + await self.send_to_web_rtc(buffer, con_no) else: # Yield control back to the event loop for other tasks to execute await asyncio.sleep(0) @@ -607,6 +712,7 @@ async def handle_connection(self, reader, writer): # type: (asyncio.StreamReade connection_no = self.connection_no self.connection_no += 1 self.connections[connection_no] = (reader, writer) + self.connection_time[connection_no] = {"start_time": datetime.now()} self.logger.debug(f"Endpoint {self.endpoint_name}: Created local connection {connection_no}") @@ -635,84 +741,32 @@ async def start_server(self): # type: (...) -> None except Exception as e: self.logger.error(f"Endpoint {self.endpoint_name}: Error while finding open port: {e}") - await self.print_not_ready() + self.kill_server_event.set() return if not self._port: self.logger.error(f"Endpoint {self.endpoint_name}: No open ports found for local server") - await self.print_not_ready() + self.kill_server_event.set() return try: self.server = await asyncio.start_server(self.handle_connection, family=socket.AF_INET, host=self.host, port=self._port) async with self.server: - self.server_task = await asyncio.create_task(self.print_ready(self.host, self._port, self.print_ready_event)) await self.server.serve_forever() except ConnectionRefusedError as er: self.logger.error(f"Endpoint {self.endpoint_name}: Connection Refused while starting server: {er}") - await self.print_not_ready() + self.kill_server_event.set() return except OSError as er: self.logger.error(f"Endpoint {self.endpoint_name}: OS Error while starting server: {er}") - await self.print_not_ready() + self.kill_server_event.set() return except Exception as e: self.logger.error(f"Endpoint {self.endpoint_name}: Error while starting server: {e}") - await self.print_not_ready() - return - - async def print_not_ready(self): - print(f'\n{bcolors.FAIL}+---------------------------------------------------------{bcolors.ENDC}') - print(f'{bcolors.FAIL}| Endpoint {self.endpoint_name}{bcolors.ENDC} failed to start') - print(f'{bcolors.FAIL}+---------------------------------------------------------{bcolors.ENDC}\n') - self.kill_server_event.set() - - async def print_ready(self, host, # type: str - port, # type: int - print_ready_event, # type: asyncio.Event - ): # type: (...) -> None - """ - pretty prints the endpoint name and host:port after the tunnels are set up - """ - wait_for_server = READ_TIMEOUT * 6 - try: - await asyncio.wait_for(print_ready_event.wait(), wait_for_server) - except TimeoutError: - await self.print_not_ready() + self.kill_server_event.set() return - if not self.server or not self.server.is_serving() if self.server else False: - await self.print_not_ready() - return - - # Sleep a little bit to print out last - await asyncio.sleep(.5) - host = host + ":" if host else '' - # Total length of the dynamic parts (endpoint name, host, and port) - dynamic_length = \ - (len("| Endpoint : Listening on port: ") + len(self.endpoint_name) + len(host) + len(str(port))) - - # Dashed line adjusted to the length of the middle line - dashed_line = '+' + '-' * dynamic_length + '+' - - # Print statements - print(f'\n{bcolors.OKGREEN}{dashed_line}{bcolors.ENDC}') - print( - f'{bcolors.OKGREEN}| Endpoint {bcolors.ENDC}{bcolors.OKBLUE}{self.endpoint_name}{bcolors.ENDC}' - f'{bcolors.OKGREEN}: Listening on port: {bcolors.ENDC}' - f'{bcolors.BOLD}{bcolors.OKBLUE}{host}{port}{bcolors.ENDC}{bcolors.OKGREEN} |{bcolors.ENDC}') - print(f'{bcolors.OKGREEN}{dashed_line}{bcolors.ENDC}') - print(f'{bcolors.OKGREEN}View all open tunnels : {bcolors.ENDC}{bcolors.OKBLUE}pam tunnel list{bcolors.ENDC}') - print(f'{bcolors.OKGREEN}Tail logs on open tunnel: {bcolors.ENDC}' - f'{bcolors.OKBLUE}pam tunnel tail ' + - (f'-- ' if self.endpoint_name[0] == '-' else '') + - f'{self.endpoint_name}{bcolors.ENDC}') - print(f'{bcolors.OKGREEN}Stop a tunnel : {bcolors.ENDC}' - f'{bcolors.OKBLUE}pam tunnel stop ' + - (f'-- ' if self.endpoint_name[0] == '-' else '') + - f'{self.endpoint_name}{bcolors.ENDC}\n') - async def stop_server(self): if self.closing: return @@ -724,7 +778,9 @@ async def stop_server(self): for c in list(self.connections): await self.close_connection(c) - self.kill_server_event.set() + if self.kill_server_event is not None: + if not self.kill_server_event.is_set(): + self.kill_server_event.set() try: # close aiortc data channel await self.pc.close_webrtc_connection() @@ -755,7 +811,8 @@ async def close_connection(self, connection_no): f"Endpoint {self.endpoint_name}: Timed out while trying to close connection " f"{connection_no}") - del self.connections[connection_no] + if connection_no in self.connections: + del self.connections[connection_no] self.logger.info(f"Endpoint {self.endpoint_name}: Closed connection {connection_no}") else: self.logger.info(f"Endpoint {self.endpoint_name}: Connection {connection_no} not found") diff --git a/unit-tests/pam-tunnel/test_pam_tunnel.py b/unit-tests/pam-tunnel/test_pam_tunnel.py index d4bffaac6..bdfac1724 100644 --- a/unit-tests/pam-tunnel/test_pam_tunnel.py +++ b/unit-tests/pam-tunnel/test_pam_tunnel.py @@ -2,7 +2,7 @@ import unittest from unittest import mock -if sys.version_info >= (3, 11): +if sys.version_info >= (3, 8): import datetime import socket import string diff --git a/unit-tests/pam-tunnel/test_private_tunnel.py b/unit-tests/pam-tunnel/test_private_tunnel.py index 7df78c382..4bfba7bbc 100644 --- a/unit-tests/pam-tunnel/test_private_tunnel.py +++ b/unit-tests/pam-tunnel/test_private_tunnel.py @@ -2,13 +2,13 @@ import unittest from unittest import mock -if sys.version_info >= (3, 15): +if sys.version_info >= (3, 8): import asyncio import logging import socket from aiortc import RTCDataChannel - from cryptography.utils import int_to_bytes + from datetime import datetime from keepercommander import utils from keepercommander.commands.tunnel.port_forward.endpoint import (TunnelEntrance, ControlMessage, CONTROL_MESSAGE_NO_LENGTH, CONNECTION_NO_LENGTH, @@ -22,6 +22,8 @@ async def asyncSetUp(self): self.host = 'localhost' self.port = 8080 self.endpoint_name = 'TestEndpoint' + self.kill_server_event = asyncio.Event() + self.connect_task = mock.MagicMock(spec=asyncio.Task) self.private_key, self.private_key_str = new_private_key() self.logger = mock.MagicMock(spec=logging) @@ -29,10 +31,11 @@ async def asyncSetUp(self): self.tunnel_symmetric_key = utils.generate_aes_key() self.pc = mock.MagicMock(sepc=WebRTCConnection) self.pc.data_channel.readyState = 'open' + self.pc.data_channel.bufferedAmount = 0 self.incoming_queue = mock.MagicMock(sepc=asyncio.Queue()) self.print_ready_event = asyncio.Event() self.pte = TunnelEntrance(self.host, self.port, self.endpoint_name, self.pc, - self.print_ready_event, self.logger) + self.print_ready_event, self.logger, self.connect_task, self.kill_server_event) async def set_queue_side_effect(self): data = b'some_data' @@ -54,7 +57,7 @@ async def test_send_control_message(self): self.pte.tls_writer = mock.MagicMock(spec=asyncio.StreamWriter) # Mock write and drain methods - with mock.patch.object(self.pte.pc.data_channel, 'send', new_callable=mock.AsyncMock) as mock_send: + with mock.patch.object(self.pte.pc, 'send_message', new_callable=mock.AsyncMock) as mock_send_message: # Define the control message and optional data control_message = ControlMessage.Ping @@ -71,7 +74,7 @@ async def test_send_control_message(self): expected_data += optional_data + TERMINATOR # Assertions - mock_send.assert_called_once_with(expected_data) + mock_send_message.assert_called_once_with(expected_data) async def test_send_control_message_with_error(self): # Initialize self.pte.tls_writer with a mock object @@ -79,7 +82,7 @@ async def test_send_control_message_with_error(self): self.pte.logger = mock.MagicMock() # Set side effect to raise an exception - self.pte.pc.data_channel.send.side_effect = Exception("Mocked Exception") + self.pte.pc.send_message.side_effect = Exception("Mocked Exception") # Define the control message and optional data control_message = ControlMessage.Ping @@ -89,39 +92,41 @@ async def test_send_control_message_with_error(self): await self.pte.send_control_message(control_message, optional_data) # Prepare the expected error log message - expected_error_message = (f"Endpoint {self.pte.endpoint_name}: Error while sending private control message: " - f"Mocked Exception") + expected_error_message = (f"Endpoint {self.pte.endpoint_name}: Error sending message: Mocked Exception") # Assertions self.pte.logger.error.assert_called_once_with(expected_error_message) async def test_forward_data_to_local_normal(self): await self.set_queue_side_effect() - - self.pte.connections = {1: (None, mock.MagicMock(spec=asyncio.StreamWriter))} + connection = (None, mock.MagicMock(spec=asyncio.StreamWriter)) + self.pte.connections = {1: connection} self.pte.logger = mock.MagicMock() self.pte.kill_server_event = mock.MagicMock(spec=asyncio.Event) - self.pte.kill_server_event.is_set.side_effect = [False, False, True] - - await self.pte.forward_data_to_local() - + self.pte.stop_server = mock.AsyncMock() + self.pte.kill_server_event.is_set.side_effect = [False, False, False, True, True] + self.pte.pc.closed = False + with mock.patch.object(self.pte, 'stop_server', new_callable=mock.AsyncMock) as mock_close: + mock_close.side_effect = mock.MagicMock(spec=asyncio.Task) + await self.pte.forward_data_to_local() + + self.assertTrue(len(self.pte.connections) == 1) self.pte.connections[1][1].write.assert_called_with(b'some_data') self.pte.connections[1][1].drain.assert_called_once() - self.assertTrue(self.pte.logger.method_calls[3] == (mock.call.debug('Endpoint TestEndpoint: Forwarding private ' - 'data to local for connection 1 (9)'))) async def test_forward_data_to_local_error(self): await self.set_queue_side_effect() - self.pte.connections = {1: (None, mock.MagicMock(spec=asyncio.StreamWriter))} + connection = (None, mock.MagicMock(spec=asyncio.StreamWriter)) + self.pte.connections = {1: connection} self.pte.logger = mock.MagicMock() self.pte.kill_server_event = mock.MagicMock(spec=asyncio.Event) - self.pte.kill_server_event.is_set.side_effect = [False, False, True] + self.pte.kill_server_event.is_set.side_effect = [False, False, False, False, True, True] + self.pte.pc.closed = False self.pte.connections[1][1].write.side_effect = Exception("Some error") await self.pte.forward_data_to_local() - self.pte.logger.error.assert_called_with("Endpoint TestEndpoint: Error while forwarding private data to " - "local: Some error") + self.pte.logger.debug.assert_called_with("Endpoint TestEndpoint: Closing tunnel") async def test_process_close_connection_message(self): with mock.patch.object(self.pte, 'close_connection', new_callable=mock.AsyncMock) as mock_close: @@ -133,7 +138,7 @@ async def test_process_pong_message(self): self.pte.logger = mock.MagicMock() await self.pte.process_control_message(ControlMessage.Pong, b'') expected_calls = [ - mock.call('Endpoint TestEndpoint: Received private pong request') + mock.call('Endpoint TestEndpoint: Received pong request') ] self.pte.logger.debug.assert_has_calls(expected_calls) self.assertEqual(self.pte._ping_attempt, 0) @@ -143,7 +148,7 @@ async def test_process_ping_message(self): with mock.patch.object(self.pte, 'send_control_message', new_callable=mock.AsyncMock) as mock_send: self.pte.logger = mock.MagicMock() await self.pte.process_control_message(ControlMessage.Ping, b'') - self.pte.logger.debug.assert_called_with('Endpoint TestEndpoint: Received private ping request') + self.pte.logger.debug.assert_called_with('Endpoint TestEndpoint: Received ping request') mock_send.assert_called_with(ControlMessage.Pong) async def test_start_server(self): @@ -154,19 +159,16 @@ async def test_start_server(self): host='localhost', port=self.port) async def test_start_server_normal(self): - with mock.patch('asyncio.start_server', new_callable=mock.AsyncMock) as mock_open_connection, \ - mock.patch.object(self.pte, 'print_ready', new_callable=mock.AsyncMock) as print_ready: - mock_open_connection.return_value = mock.MagicMock(spec=asyncio.Server) + with mock.patch('asyncio.start_server', new_callable=mock.AsyncMock) as mock_open_connection: + mock_open_connection.return_value = mock.MagicMock(spec=asyncio.AbstractServer) self.pte.logger = mock.MagicMock() await self.pte.start_server() - - print_ready.assert_called_once() + self.assertTrue(self.pte.server is not None) async def test_start_server_connection_refused_error(self): - with mock.patch('asyncio.start_server', new_callable=mock.AsyncMock) as mock_start_server, \ - mock.patch.object(self.pte, 'stop_server', new_callable=mock.AsyncMock) as mock_stop: + with mock.patch('asyncio.start_server', new_callable=mock.AsyncMock) as mock_start_server: mock_start_server.side_effect = ConnectionRefusedError self.pte.logger = mock.MagicMock() @@ -174,24 +176,22 @@ async def test_start_server_connection_refused_error(self): self.pte.logger.error.assert_called_with('Endpoint TestEndpoint: Connection Refused while starting ' 'server: ') - mock_stop.assert_called() + self.assertTrue(self.pte.kill_server_event.is_set()) self.assertTrue(self.pte.server is None) async def test_start_server_timeout_error(self): - with mock.patch('asyncio.start_server', new_callable=mock.AsyncMock) as mock_start_server, \ - mock.patch.object(self.pte, 'stop_server', new_callable=mock.AsyncMock) as mock_stop: + with mock.patch('asyncio.start_server', new_callable=mock.AsyncMock) as mock_start_server: mock_start_server.side_effect = TimeoutError self.pte.logger = mock.MagicMock() await self.pte.start_server() self.pte.logger.error.assert_called_with('Endpoint TestEndpoint: OS Error while starting server: ') - mock_stop.assert_called() + self.assertTrue(self.pte.kill_server_event.is_set()) self.assertTrue(self.pte.server is None) async def test_start_server_os_error(self): - with mock.patch('asyncio.start_server', new_callable=mock.AsyncMock) as mock_start_server, \ - mock.patch.object(self.pte, 'stop_server', new_callable=mock.AsyncMock) as mock_stop: + with mock.patch('asyncio.start_server', new_callable=mock.AsyncMock) as mock_start_server: mock_start_server.side_effect = OSError("Some OS Error") self.pte.logger = mock.MagicMock() @@ -199,12 +199,11 @@ async def test_start_server_os_error(self): self.pte.logger.error.assert_called_with('Endpoint TestEndpoint: OS Error while starting server: ' 'Some OS Error') - mock_stop.assert_called() + self.assertTrue(self.pte.kill_server_event.is_set()) self.assertTrue(self.pte.server is None) async def test_start_server_generic_exception(self): - with mock.patch('asyncio.start_server', new_callable=mock.AsyncMock) as mock_start_server, \ - mock.patch.object(self.pte, 'stop_server', new_callable=mock.AsyncMock) as mock_stop: + with mock.patch('asyncio.start_server', new_callable=mock.AsyncMock) as mock_start_server: mock_start_server.side_effect = Exception("Some generic exception") self.pte.logger = mock.MagicMock() @@ -212,7 +211,7 @@ async def test_start_server_generic_exception(self): self.pte.logger.error.assert_called_with('Endpoint TestEndpoint: Error while starting server: ' 'Some generic exception') - mock_stop.assert_called() + self.assertTrue(self.pte.kill_server_event.is_set()) self.assertTrue(self.pte.server is None) # Test Successful Data Forwarding @@ -241,13 +240,15 @@ async def read_side_effect(*args, **kwargs): self.pte.pc = mock.MagicMock(spec=WebRTCConnection) self.pte.pc.data_channel = mock.MagicMock(spec=RTCDataChannel) self.pte.pc.data_channel.readyState = 'open' + self.pte.pc.data_channel.bufferedAmount = 0 + self.pte.connection_time[1] = {"start_time": datetime.now()} # Run the task and wait for it to complete task = asyncio.create_task(self.pte.forward_data_to_tunnel(1)) await asyncio.sleep(.01) # Give some time for the task to run task.cancel() # Cancel the task to stop it from running indefinitely - self.pte.pc.data_channel.send.assert_called_with(b'\x00\x00\x00\x01\x00\x00\x00\x0bhello world;') + self.pte.pc.send_message.assert_called_with(b'\x00\x00\x00\x01\x00\x00\x00\x0bhello world;') # Test Connection Not Found async def test_forward_data_to_tunnel_no_connection(self): @@ -323,48 +324,28 @@ async def test_handle_connection_exception(self): with self.assertRaises(Exception): await self.pte.handle_connection(mock_reader, mock_writer) - # Test print_not_ready - async def test_print_not_ready(self): - with mock.patch.object(self.pte, 'send_control_message', - new_callable=mock.AsyncMock) as mock_send_control_message: - await self.pte.print_not_ready() - mock_send_control_message.assert_called_with(ControlMessage.CloseConnection, int_to_bytes(0)) - - # Test print_ready - async def test_print_ready(self): - with mock.patch('builtins.print') as mock_print: - await self.pte.print_ready('localhost', 8080, mock.AsyncMock()) - - # Check if print was called (optional) - mock_print.assert_called() - - # Test print_ready with TimeoutError - async def test_print_ready_timeout_error_forwarder(self): - print_event = mock.AsyncMock(spec=asyncio.Event) - print_event.wait.side_effect = asyncio.TimeoutError() - with mock.patch.object(self.pte, 'print_not_ready', new_callable=mock.AsyncMock) as mock_print_not_ready: - await self.pte.print_ready('localhost', 8080, print_event) - - # Check if logger.debug was called - self.pte.logger.debug.assert_called_with("Endpoint TestEndpoint: Timed out waiting for private tunnel to start") - # Check if print was called (optional) - mock_print_not_ready.assert_called() - # Test stop_server async def test_stop_server(self): - self.pte.server = mock.AsyncMock(spec=asyncio.Server) - with mock.patch.object(self.pte.server, 'close', new_callable=mock.AsyncMock) as mock_close, \ - mock.patch.object(self.pte.server, 'wait_closed', new_callable=mock.AsyncMock) as mock_wait_closed: + self.pte.connections = {1: (mock.AsyncMock(), mock.AsyncMock())} + self.pte.kill_server_event = mock.MagicMock(spec=asyncio.Event) + with mock.patch.object(self.pte.connections[1][1], 'close', new_callable=mock.AsyncMock) as mock_close, \ + mock.patch.object(self.pte.connections[1][1], 'wait_closed', new_callable=mock.AsyncMock) as mock_wait_closed: await self.pte.stop_server() mock_close.assert_called() mock_wait_closed.assert_called() + self.assertTrue(self.pte.connections == {}) + self.assertTrue(self.pte.server is None) + self.assertTrue(self.pte.kill_server_event.is_set()) # Test stop_server with Exception async def test_stop_server_exception(self): - self.pte.server = mock.AsyncMock(spec=asyncio.Server) - with mock.patch.object(self.pte.server, 'close', side_effect=Exception("Test Exception")): - with self.assertRaises(Exception): - await self.pte.stop_server() + self.pte.connections = {1: (mock.AsyncMock(), mock.AsyncMock())} + self.pte.kill_server_event = mock.MagicMock(spec=asyncio.Event) + with mock.patch.object(self.pte.connections[1][1], 'close', side_effect=Exception("Test Exception")): + await self.pte.close_connection(1) + self.assertTrue(self.pte.connections == {}) + self.assertTrue(self.pte.server is None) + # Test close_connection async def test_close_connection(self): @@ -377,4 +358,4 @@ async def test_close_connection_not_found(self): await self.pte.close_connection(9999) # 9999 is not in self.connections # Check if logger.info was called - self.pte.logger.info.assert_called_with("Endpoint TestEndpoint: Private tasks for 9999 not found") + self.pte.logger.info.assert_called_with("Endpoint TestEndpoint: Tasks for 9999 not found")