Skip to content

Commit

Permalink
Tunneling add support for socket eof across the tunnel
Browse files Browse the repository at this point in the history
  • Loading branch information
miroberts committed Jan 11, 2024
1 parent 77525a9 commit 0cd3e75
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions keepercommander/commands/tunnel/port_forward/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class ControlMessage(enum.IntEnum):
OpenConnection = 101
CloseConnection = 102
ConnectionOpened = 103
SendEOF = 104


def generate_random_bytes(pass_length=RANDOM_LENGTH): # type: (int) -> bytes
Expand Down Expand Up @@ -463,6 +464,7 @@ def __init__(self,
self.pc = pc
self.print_ready_event = print_ready_event
self.connect_task = connect_task
self.eof_sent = False

@property
def port(self):
Expand Down Expand Up @@ -579,6 +581,14 @@ async def process_control_message(self, message_no, data): # type: (ControlMess
self.logger.error(f"Endpoint {self.endpoint_name}: Error in forwarding data task: {e}")
else:
self.logger.error(f"Endpoint {self.endpoint_name}: Invalid open connection message")
elif message_no == ControlMessage.SendEOF:
if len(data) >= CONNECTION_NO_LENGTH:
con_no = int.from_bytes(data[:CONNECTION_NO_LENGTH], byteorder='big')
if con_no in self.connections:
self.logger.debug(f'Endpoint {self.endpoint_name}: Sending EOF to {con_no}')
self.connections[con_no].writer.write_eof()
else:
self.logger.error(f'Endpoint {self.endpoint_name}: Connection for EOF {con_no} not found')
else:
self.logger.warning(f'Endpoint {self.endpoint_name} Unknown tunnel control message: {message_no}')

Expand Down Expand Up @@ -726,10 +736,15 @@ async def forward_data_to_tunnel(self, con_no):
break
if isinstance(data, bytes):
if c.reader.at_eof() and len(data) == 0:
if not self.eof_sent:
await self.send_control_message(ControlMessage.SendEOF,
int_to_bytes(con_no, CONNECTION_NO_LENGTH))
self.eof_sent = True
# Yield control back to the event loop for other tasks to execute
await asyncio.sleep(0)
continue
else:
self.eof_sent = False
buffer = int.to_bytes(con_no, CONNECTION_NO_LENGTH, byteorder='big')
buffer += int.to_bytes(len(data), DATA_LENGTH, byteorder='big') + data + TERMINATOR
await self.send_to_web_rtc(buffer)
Expand Down Expand Up @@ -849,7 +864,8 @@ async def stop_server(self):
self.logger.warning(f'Endpoint {self.endpoint_name}: hit exception closing data channel {ex}')

try:
self.connect_task.cancel()
if self.connect_task is not None:
self.connect_task.cancel()
finally:
self.closing = True
self.logger.info(f"Endpoint {self.endpoint_name}: Tunnel stopped")
Expand All @@ -873,7 +889,8 @@ async def close_connection(self, connection_no):

if connection_no in self.connections:
try:
self.connections[connection_no].to_tunnel_task.cancel()
if self.connections[connection_no].to_tunnel_task is not None:
self.connections[connection_no].to_tunnel_task.cancel()
except Exception as ex:
self.logger.warning(f'Endpoint {self.endpoint_name}: hit exception canceling tasks {ex}')
del self.connections[connection_no]
Expand Down

0 comments on commit 0cd3e75

Please sign in to comment.