Skip to content

Commit

Permalink
Working on Tunneling network stability, and encryption.
Browse files Browse the repository at this point in the history
  • Loading branch information
miroberts committed Jan 4, 2024
1 parent 8f46c97 commit 01bc395
Show file tree
Hide file tree
Showing 5 changed files with 356 additions and 272 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ keeper.txt
*.csv
Makefile
*.db
dr-logs
dr-logs
/.venv*
273 changes: 159 additions & 114 deletions keepercommander/commands/discoveryrotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import logging
import os.path
import queue
import socket
import sys
import threading
import time
Expand All @@ -38,13 +37,12 @@
from .pam.pam_dto import GatewayActionGatewayInfo, GatewayActionDiscoverInputs, GatewayActionDiscover, \
GatewayActionRotate, \
GatewayActionRotateInputs, GatewayAction, GatewayActionJobInfoInputs, \
GatewayActionJobInfo, GatewayActionJobCancel, GatewayActionWebRTCSession
GatewayActionJobInfo, GatewayActionJobCancel
from .pam.router_helper import router_send_action_to_gateway, print_router_response, \
router_get_connected_gateways, router_set_record_rotation_information, router_get_rotation_schedules, \
get_router_url, router_get_relay_access_creds
get_router_url
from .record_edit import RecordEditMixin
from .tunnel.port_forward.endpoint import establish_symmetric_key, tunnel_encrypt, WebRTCConnection, tunnel_decrypt, \
TunnelEntrance
from .tunnel.port_forward.endpoint import establish_symmetric_key, WebRTCConnection, TunnelEntrance, READ_TIMEOUT
from .. import api, utils, vault_extensions, vault, record_management, attachment, record_facades
from ..display import bcolors
from ..error import CommandError
Expand Down Expand Up @@ -1577,30 +1575,37 @@ def gather_tabel_row_data(thread):
minutes = 0
seconds = 0

if thread.get('started'):
run_time = datetime.now() - thread.get('started')
hours, remainder = divmod(run_time.seconds, 3600)
minutes, seconds = divmod(remainder, 60)
entrance = thread.get('entrance')
#
# row.append(f"{thread.get('name', '')}")
row.append(f"{bcolors.OKBLUE}{thread.get('convo_id', '')}{bcolors.ENDC}")
row.append(f"{thread.get('host')}" if thread.get('host') else '')
row.append(f"{thread.get('host', '')}")

if not thread.get('entrance'):
if entrance is not None and entrance.print_ready_event.is_set():
if thread.get('started'):
run_time = datetime.now() - thread.get('started')
hours, remainder = divmod(run_time.seconds, 3600)
minutes, seconds = divmod(remainder, 60)

row.append(
f"{bcolors.OKBLUE}{entrance._port}{bcolors.ENDC}"
)
else:
row.append(f"{bcolors.WARNING}Connecting...{bcolors.ENDC}")
row.append(f"{thread.get('record_uid', '')}")
if entrance is not None and entrance.print_ready_event.is_set():
text_line = ""
if run_time:
if run_time.days == 1:
text_line += f"{run_time.days} day "
elif run_time.days > 1:
text_line += f"{run_time.days} days "
text_line += f"{hours} hr " if hours > 0 or run_time.days > 0 else ''
text_line += f"{minutes} min "
text_line += f"{seconds} sec"
row.append(text_line)
else:
row.append(f"{bcolors.OKBLUE}{thread.get('entrance')._port}{bcolors.ENDC}" if thread.get('entrance') else '')
row.append(f"{thread.get('record_uid')}" if thread.get('record_uid') else '')
text_line = ""
if run_time:
if run_time.days == 1:
text_line += f"{run_time.days} day "
elif run_time.days > 1:
text_line += f"{run_time.days} days "
text_line += f"{hours} hr " if hours > 0 or run_time.days > 0 else ''
text_line += f"{minutes} min "
text_line += f"{seconds} sec"
row.append(text_line)
row.append(f"{bcolors.WARNING}Connecting...{bcolors.ENDC}")
return row

if not params.tunnel_threads:
Expand All @@ -1618,6 +1623,26 @@ def gather_tabel_row_data(thread):
dump_report_data(table, headers, fmt='table', filename="", row_number=False, column_width=None)


def clean_up_tunnel(params, convo_id):
tunnel_data = params.tunnel_threads.get(convo_id)
if tunnel_data:
kill_server_event = tunnel_data.get("kill_server_event")
if kill_server_event:
if not kill_server_event.is_set():
kill_server_event.set()
# whatever the read timeout is, wait for 2 seconds more
time.sleep(READ_TIMEOUT + 2)
p = tunnel_data.get("process", None)
if p and p.is_alive():
p.join()
if params.tunnel_threads.get(convo_id):
del params.tunnel_threads[convo_id]
if params.tunnel_threads_queue.get(convo_id):
del params.tunnel_threads_queue[convo_id]
else:
print(f"{bcolors.WARNING}No tunnel data found to remove for {convo_id}{bcolors.ENDC}")


class PAMTunnelStopCommand(Command):
pam_cmd_parser = argparse.ArgumentParser(prog='dr-tunnel-stop-command')
pam_cmd_parser.add_argument('uid', type=str, action='store', help='The Tunnel UID')
Expand All @@ -1632,16 +1657,8 @@ def execute(self, params, **kwargs):

tunnel_data = params.tunnel_threads.get(convo_id, None)
if not tunnel_data:
raise CommandError('tunnel stop', f"No tunnel data found for {convo_id}")

connect_task = tunnel_data.get("connect_task", None)
if connect_task:
connect_task.cancel()
count = 0
while params.tunnel_threads.get(convo_id) and count < 10:
count += .1
time.sleep(.1)

raise CommandError('tunnel stop', f"No tunnel data to remove found for {convo_id}")
clean_up_tunnel(params, convo_id)
return


Expand All @@ -1656,6 +1673,8 @@ def execute(self, params, **kwargs):
convo_id = kwargs.get('uid')
if not convo_id:
raise CommandError('tunnel tail', '"uid" argument is required')
if convo_id not in params.tunnel_threads:
raise CommandError('tunnel tail', f"Tunnel UID {convo_id} not found")

log_queue = params.tunnel_threads_queue.get(convo_id)

Expand Down Expand Up @@ -1791,90 +1810,58 @@ async def connect(self, params, record_uid, convo_id, gateway_uid, host, port,
# Get symmetric key
symmetric_key = establish_symmetric_key(client_private_key_pem, gateway_public_key)

response = router_get_relay_access_creds(params=params)

# Set up the pc
print_ready_event = asyncio.Event()
pc = WebRTCConnection(endpoint_name=convo_id, print_ready_event=print_ready_event,
username=response.username, password=response.password, logger=logger)
kill_server_event = asyncio.Event()
pc = WebRTCConnection(endpoint_name=convo_id, params=params, record_uid=record_uid, gateway_uid=gateway_uid,
symmetric_key=symmetric_key, print_ready_event=print_ready_event,
kill_server_event=kill_server_event, logger=logger)

# make webRTC sdp offer
try:
offer = await pc.make_offer()
except socket.gaierror:
print(f"{bcolors.WARNING}Please upgrade Commander to the latest version to use this feature...{bcolors.ENDC}")
return
except Exception as e:
raise CommandError('tunnel start', f'Error making WebRTC offer: {e}')
encrypted_offer = tunnel_encrypt(symmetric_key, offer)
logger.debug("-->. SEND START MESSAGE OVER REST TO GATEWAY")

'''
'inputs': {
'conversationType': ['tunnel', 'guacd']
'kind': ['start', 'disconnect'],
'recordUid': record_uid, <-- this is the record UID of the PAM resource record
with Network information
'listenerName': NAME OF LISTENER, <-- Used in logging (not required)
'offer': encrypted_WebRTC_sdp_offer, <-- WebRTC SDP offer encrypted with symmetric key
'allow_control': True, <-- only for guacd, False = readonly session (default True)
'guacamole_client_id: guacamole_client_id, <-- only for guacd, Connect to an existing guacd session
'userRecordUid': userRecordUid, <-- only for guacd, User record UID to connect for session
}
'''
# TODO create objects for WebRTC inputs
router_response = router_send_action_to_gateway(
params=params,
gateway_action=GatewayActionWebRTCSession(inputs={'listenerName': convo_id, "recordUid": record_uid,
"offer": encrypted_offer, 'kind': 'start',
'conversationType': 'tunnel'}),
message_type=pam_pb2.CMT_GENERAL,
is_streaming=False,
destination_gateway_uid_str=gateway_uid,
gateway_timeout=30000
)
if not router_response:
return
gateway_response = router_response.get('response', {})
if not gateway_response:
raise Exception(f"Error getting response from the Gateway: {router_response}")
try:
payload = json.loads(gateway_response.get('payload', None))
if not payload:
raise Exception(f"Error getting payload from the Gateway response: {gateway_response}")
await pc.signal_channel('start')
except Exception as e:
raise Exception(f"Error getting payload from the Gateway response: {e}")

if payload.get('is_ok', False) is False or payload.get('progress_status') == 'Error':
raise Exception(f"Error getting payload from the Gateway response: {payload.get('data')}")

encrypted_answer = payload.get('data', None)
if not encrypted_answer:
raise Exception(f"Error getting data from the Gateway response payload: {payload}")

# decrypt the sdp answer
answer = tunnel_decrypt(symmetric_key, encrypted_answer)
await pc.accept_answer(answer)
CommandError('tunnel start', f"{e}")

logger.debug("starting private tunnel")

private_tunnel = TunnelEntrance(host=host, port=port, endpoint_name=convo_id, pc=pc,
print_ready_event=print_ready_event, logger=logger,
connect_task = params.tunnel_threads[convo_id].get("connect_task", None))
connect_task=params.tunnel_threads[convo_id].get("connect_task", None),
kill_server_event=kill_server_event)

t1 = asyncio.create_task(private_tunnel.start_server())
params.tunnel_threads[convo_id].update({"server": t1, "entrance": private_tunnel})
params.tunnel_threads[convo_id].update({"server": t1, "entrance": private_tunnel,
"kill_server_event": kill_server_event})

logger.debug("--> START LISTENING FOR MESSAGES FROM GATEWAY --------")
await asyncio.gather(t1, private_tunnel.reader_task)
try:
await asyncio.gather(t1, private_tunnel.reader_task)
except asyncio.CancelledError:
pass
finally:
logger.debug("--> STOP LISTENING FOR MESSAGES FROM GATEWAY --------")

def pre_connect(self, params, record_uid, convo_id, gateway_uid, host, port,
gateway_public_key_bytes, client_private_key):

def custom_exception_handler(loop, context):
# Check if the exception is present in the context
if "exception" in context:
exception = context["exception"]
if isinstance(exception, ConnectionError):
# Handle only ConnectionError
logging.debug(f"Caught ConnectionError in asyncio: {exception}")
else:
# Log the default message if no exception is found
logging.error(context["message"])

loop = None
try:
# Create a new asyncio event loop and set the custom exception handler
loop = asyncio.new_event_loop()
params.tunnel_threads[convo_id].update({"loop": loop})
asyncio.set_event_loop(loop)
loop.set_exception_handler(custom_exception_handler)
output_queue = queue.Queue(maxsize=500)
params.tunnel_threads_queue[convo_id] = output_queue
# Create a Task from the coroutine
Expand All @@ -1899,6 +1886,8 @@ def pre_connect(self, params, record_uid, convo_id, gateway_uid, host, port,
pass
except SocketNotConnectedException as es:
print(f"{bcolors.FAIL}Socket not connected exception in connection {convo_id}: {es}{bcolors.ENDC}")
except KeyboardInterrupt:
print(f"{bcolors.OKBLUE}Exiting: {convo_id}{bcolors.ENDC}")
except Exception as e:
print(f"{bcolors.FAIL}An exception occurred in pre_connect for connection {convo_id}: {e}{bcolors.ENDC}")
finally:
Expand Down Expand Up @@ -1932,13 +1921,14 @@ def pre_connect(self, params, record_uid, convo_id, gateway_uid, host, port,
print(f"{bcolors.OKBLUE}Tunnel {convo_id} closed.{bcolors.ENDC}")

def execute(self, params, **kwargs):
version = [3, 11, 0]
# Check for Python 3.11+
# https://pypi.org/project/aiortc/
# aiortc Requires: Python >=3.8
version = [3, 8, 0]
major_version = sys.version_info.major
minor_version = sys.version_info.minor
micro_version = sys.version_info.micro

if (major_version, minor_version, micro_version) <= (version[0], version[1], version[2]):
if (major_version, minor_version, micro_version) < (version[0], version[1], version[2]):
print(f"{bcolors.FAIL}This code requires Python {version[0]}.{version[1]}.{version[2]} or higher. "
f"You are using {major_version}.{minor_version}.{micro_version}.{bcolors.ENDC}")
return
Expand All @@ -1952,6 +1942,10 @@ def execute(self, params, **kwargs):

gateway_public_key_bytes = retrieve_gateway_public_key(gateway_uid, params, api, utils)

if not gateway_public_key_bytes:
print(f"{bcolors.FAIL}Could not retrieve public key for gateway {gateway_uid}{bcolors.ENDC}")
return

api.sync_down(params)
record = vault.KeeperRecord.load(params, record_uid)
if not isinstance(record, vault.TypedRecord):
Expand Down Expand Up @@ -1991,22 +1985,73 @@ def execute(self, params, **kwargs):
t.daemon = True
t.start()

params.tunnel_threads[convo_id].update({"convo_id": convo_id, "thread": t, "host": host, "port": port,
"started": datetime.now(), "record_uid": record_uid})
if not params.tunnel_threads.get(convo_id):
params.tunnel_threads[convo_id] = {"convo_id": convo_id, "thread": t, "host": host, "port": port,
"started": datetime.now(), "record_uid": record_uid}
else:
params.tunnel_threads[convo_id].update({"convo_id": convo_id, "thread": t, "host": host, "port": port,
"started": datetime.now(), "record_uid": record_uid})
count = 0
wait_time = 120
entrance = None
# run once
while count == 0:
while count < wait_time:
if params.tunnel_threads.get(convo_id):
entrance = params.tunnel_threads[convo_id].get("entrance", None)
if entrance:
break
else:
break
count += .1
time.sleep(.1)

def print_fail():
fail_dynamic_length = len("| Endpoint : failed to start..") + len(convo_id)

clean_up_tunnel(params, convo_id)
time.sleep(.5)
# Dashed line adjusted to the length of the middle line
fail_dashed_line = '+' + '-' * fail_dynamic_length + '+'
print(f'\n{bcolors.FAIL}{fail_dashed_line}{bcolors.ENDC}')
print(f'{bcolors.FAIL}| Endpoint {convo_id}{bcolors.ENDC} failed to start..')
print(f'{bcolors.FAIL}{fail_dashed_line}{bcolors.ENDC}\n')

if entrance is not None:
while not entrance.print_ready_event.is_set() and count < wait_time * 2:
count += .1
time.sleep(.1)

if entrance.print_ready_event.is_set():
# Sleep a little bit to print out last
time.sleep(.5)
host = host + ":" if host else ''
# Total length of the dynamic parts (endpoint name, host, and port)
dynamic_length = \
(len("| Endpoint : Listening on port: ") + len(convo_id) + len(host) + len(str(port)))

# Dashed line adjusted to the length of the middle line
dashed_line = '+' + '-' * dynamic_length + '+'

# Print statements
print(f'\n{bcolors.OKGREEN}{dashed_line}{bcolors.ENDC}')
print(
f'{bcolors.OKGREEN}| Endpoint {bcolors.ENDC}{bcolors.OKBLUE}{convo_id}{bcolors.ENDC}'
f'{bcolors.OKGREEN}: Listening on port: {bcolors.ENDC}'
f'{bcolors.BOLD}{bcolors.OKBLUE}{host}{port}{bcolors.ENDC}{bcolors.OKGREEN} |{bcolors.ENDC}')
print(f'{bcolors.OKGREEN}{dashed_line}{bcolors.ENDC}')
print(
f'{bcolors.OKGREEN}View all open tunnels : {bcolors.ENDC}{bcolors.OKBLUE}pam tunnel list{bcolors.ENDC}')
print(f'{bcolors.OKGREEN}Tail logs on open tunnel: {bcolors.ENDC}'
f'{bcolors.OKBLUE}pam tunnel tail ' +
(f'-- ' if convo_id[0] == '-' else '') +
f'{convo_id}{bcolors.ENDC}')
print(f'{bcolors.OKGREEN}Stop a tunnel : {bcolors.ENDC}'
f'{bcolors.OKBLUE}pam tunnel stop ' +
(f'-- ' if convo_id[0] == '-' else '') +
f'{convo_id}{bcolors.ENDC}\n')
else:
print_fail()
else:
print_fail()

while count < wait_time and not params.tunnel_threads[convo_id].get("entrance"):
count += .1
time.sleep(.1)

if count >= wait_time:
# There could be an error that happened in the thread return
return

entrance = params.tunnel_threads[convo_id].get("entrance")
while not entrance.print_ready_event.is_set() and count < wait_time * 2:
count += .1
time.sleep(.1)
# After it is set to True the print waits for .5 seconds then prints so we need to wait a little longer
time.sleep(1)
Loading

0 comments on commit 01bc395

Please sign in to comment.