diff --git a/keepercommander/commands/discover/__init__.py b/keepercommander/commands/discover/__init__.py new file mode 100644 index 000000000..7927be248 --- /dev/null +++ b/keepercommander/commands/discover/__init__.py @@ -0,0 +1,251 @@ +from __future__ import annotations +import logging +from ..base import Command +from ..pam.config_facades import PamConfigurationRecordFacade +from ..pam import gateway_helper +from ..pam.router_helper import get_response_payload +from ..pam.gateway_helper import get_all_gateways +from ..ksm import KSMCommand +from ... import utils, vault_extensions +from ... import vault +from ...proto import APIRequest_pb2 +from ...crypto import encrypt_aes_v2, decrypt_aes_v2 +from ...display import bcolors +from discovery_common.constants import PAM_USER, PAM_MACHINE, PAM_DATABASE, PAM_DIRECTORY +import json +import base64 + +from typing import List, Optional, Union, TYPE_CHECKING + +if TYPE_CHECKING: + from ...params import KeeperParams + from ...vault import KeeperRecord, ApplicationRecord + from ...proto import pam_pb2 + + +class GatewayContext: + def __init__(self, configuration: KeeperRecord, facade: PamConfigurationRecordFacade, + gateway: pam_pb2.PAMController, application: ApplicationRecord): + self.configuration = configuration + self.facade = facade + self.gateway = gateway + self.application = application + self._shared_folders = None + + @staticmethod + def from_configuration_uid(params: KeeperParams, configuration_uid: str): + + configuration_record = vault.KeeperRecord.load(params, configuration_uid) + if not isinstance(configuration_record, vault.TypedRecord): + print(f'{bcolors.FAIL}PAM Configuration [{configuration_uid}] is not available.{bcolors.ENDC}') + return + + configuration_facade = PamConfigurationRecordFacade() + configuration_facade.record = configuration_record + + gateway_uid = configuration_facade.controller_uid + gateway = next((x for x in gateway_helper.get_all_gateways(params) + if utils.base64_url_encode(x.controllerUid) == gateway_uid), + None) + + if gateway is None: + return + + application_id = utils.base64_url_encode(gateway.applicationUid) + application = KSMCommand.get_app_record(params, application_id) + + return GatewayContext( + configuration=configuration_record, + facade=configuration_facade, + gateway=gateway, + application=application + ) + + @staticmethod + def from_gateway(params: KeeperParams, gateway: str): + # Get all the PAM configuration records + configuration_records = list(vault_extensions.find_records(params, "pam.*Configuration")) + if len(configuration_records) == 0: + print(f"{bcolors.FAIL}Cannot find any PAM configuration records in the Vault{bcolors.ENDC}") + + all_gateways = get_all_gateways(params) + + for record in configuration_records: + + logging.debug(f"checking configuration record {record.title}") + + # Load the configuration record and get the gateway_uid from the facade. 
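+            # The facade exposes typed accessors over the configuration record,
+            # so callers read values such as controller_uid and folder_uid as
+            # attributes instead of digging through raw field dictionaries.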
+ configuration_record = vault.KeeperRecord.load(params, record.record_uid) + configuration_facade = PamConfigurationRecordFacade() + configuration_facade.record = configuration_record + + configuration_gateway_uid = configuration_facade.controller_uid + if configuration_gateway_uid is None: + logging.debug(f"configuration {configuration_record.title} does not have a gateway set, skipping.") + continue + + # Get the gateway for this configuration + found_gateway = next((x for x in all_gateways if utils.base64_url_encode(x.controllerUid) == + configuration_gateway_uid), None) + if found_gateway is None: + logging.debug(f"cannot find gateway for configuration {configuration_record.title}, skipping.") + continue + + application_id = utils.base64_url_encode(found_gateway.applicationUid) + application = KSMCommand.get_app_record(params, application_id) + if application is None: + logging.debug(f"cannot find application for gateway {gateway}, skipping.") + + if (utils.base64_url_encode(found_gateway.controllerUid) == gateway or + found_gateway.controllerName.lower() == gateway.lower()): + return GatewayContext( + configuration=configuration_record, + facade=configuration_facade, + gateway=found_gateway, + application=application + ) + + return None + + @property + def gateway_uid(self) -> str: + return utils.base64_url_encode(self.gateway.controllerUid) + + @property + def configuration_uid(self) -> str: + return self.configuration.record_uid + + @property + def gateway_name(self) -> str: + return self.gateway.controllerName + + @property + def default_shared_folder_uid(self) -> str: + return self.facade.folder_uid + + def is_gateway(self, request_gateway: str) -> bool: + if request_gateway is None or self.gateway_name is None: + return False + return (request_gateway == utils.base64_url_encode(self.gateway.controllerUid) or + request_gateway.lower() == self.gateway_name.lower()) + + def get_shared_folders(self, params: KeeperParams) -> List[dict]: + if self._shared_folders is None: + self._shared_folders = [] + application_uid = utils.base64_url_encode(self.gateway.applicationUid) + app_info = KSMCommand.get_app_info(params, application_uid) + for info in app_info: + if info.shares is None: + continue + for shared in info.shares: + uid_str = utils.base64_url_encode(shared.secretUid) + shared_type = APIRequest_pb2.ApplicationShareType.Name(shared.shareType) + if shared_type == 'SHARE_TYPE_FOLDER': + if uid_str not in params.shared_folder_cache: + continue + cached_shared_folder = params.shared_folder_cache[uid_str] + self._shared_folders.append({ + "uid": uid_str, + "name": cached_shared_folder.get('name_unencrypted'), + "folder": cached_shared_folder + }) + return self._shared_folders + + def decrypt(self, cipher_base64: bytes) -> dict: + ciphertext = base64.b64decode(cipher_base64) + return json.loads(decrypt_aes_v2(ciphertext, self.configuration.record_key)) + + def encrypt(self, data: dict) -> str: + json_data = json.dumps(data) + ciphertext = encrypt_aes_v2(json_data.encode(), self.configuration.record_key) + return base64.b64encode(ciphertext).decode() + + def encrypt_str(self, data: Union[bytes, str]) -> str: + if isinstance(data, str): + data = data.encode() + ciphertext = encrypt_aes_v2(data, self.configuration.record_key) + return base64.b64encode(ciphertext).decode() + + +class PAMGatewayActionDiscoverCommandBase(Command): + + """ + The discover command base. + + Contains static methods to get the configuration record, get and update the discovery store. 
These are methods + used by multiple discover actions. + """ + + # If the discovery data field does not exist, or the field contains no values, use the template to init the + # field. + + STORE_LABEL = "discoveryKey" + FIELD_MAPPING = { + "pamHostname": { + "type": "dict", + "field_input": [ + {"key": "hostName", "prompt": "Hostname"}, + {"key": "port", "prompt": "Port"} + ], + "field_format": [ + {"key": "hostName", "label": "Hostname"}, + {"key": "port", "label": "Port"}, + ] + }, + "alternativeIPs": { + "type": "csv", + }, + "privatePEMKey": { + "type": "multiline", + }, + "operatingSystem": { + "type": "choice", + "values": ["linux", "macos", "windows"] + } + } + + type_name_map = { + PAM_USER: "PAM Users", + PAM_MACHINE: "PAM Machines", + PAM_DATABASE: "PAM Databases", + PAM_DIRECTORY: "PAM Directories", + } + + @staticmethod + def get_response_data(router_response: dict) -> Optional[dict]: + + if router_response is None: + return None + + response = router_response.get("response") + logging.debug(f"Router Response: {response}") + payload = get_response_payload(router_response) + return payload.get("data") + + @staticmethod + def _gr(msg): + return f"{bcolors.OKGREEN}{msg}{bcolors.ENDC}" + + @staticmethod + def _bl(msg): + return f"{bcolors.OKBLUE}{msg}{bcolors.ENDC}" + + @staticmethod + def _h(msg): + return f"{bcolors.HEADER}{msg}{bcolors.ENDC}" + + @staticmethod + def _b(msg): + return f"{bcolors.BOLD}{msg}{bcolors.ENDC}" + + @staticmethod + def _f(msg): + return f"{bcolors.FAIL}{msg}{bcolors.ENDC}" + + @staticmethod + def _p(msg): + return msg + + @staticmethod + def _n(record_type): + return PAMGatewayActionDiscoverCommandBase.type_name_map.get(record_type, "PAM Configuration") diff --git a/keepercommander/commands/discover/debug.py b/keepercommander/commands/discover/debug.py new file mode 100644 index 000000000..c5972a42d --- /dev/null +++ b/keepercommander/commands/discover/debug.py @@ -0,0 +1,268 @@ +from __future__ import annotations +import argparse +import os +from . import PAMGatewayActionDiscoverCommandBase +from ...display import bcolors +from ... import vault +from discovery_common.infrastructure import Infrastructure +from discovery_common.record_link import RecordLink +from discovery_common.types import UserAcl, DiscoveryObject +from keeper_dag import EdgeType +from importlib.metadata import version +from typing import Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from ...vault import TypedRecord + from ...params import KeeperParams + + +class PAMGatewayActionDiscoverDebugCommand(PAMGatewayActionDiscoverCommandBase): + parser = argparse.ArgumentParser(prog='dr-discover-command-debug') + + # The record to base everything on. 
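+    # For pamUser/pamMachine/pamDatabase/pamDirectory records, the PAM
+    # configuration record is resolved through the record's rotation settings
+    # (see execute below); any other UID is treated as the configuration
+    # record itself.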
+ parser.add_argument('--record-uid', '-i', required=False, dest='record_uid', action='store', + help='Keeper record UID.') + + # What to do + parser.add_argument('--info', required=False, dest='info_flag', + action='store_true', help='Display information about the record.') + parser.add_argument('--belongs-to', required=False, dest='belongs_to_flag', + action='store_true', help='Connect the record to the parent record.') + parser.add_argument('--disconnect', required=False, dest='disconnect_flag', + action='store_true', help='Disconnect the record to the parent record.') + parser.add_argument('--render', required=False, dest='render_flag', action='store_true', + help='Render graphs.') + parser.add_argument('--version', required=False, dest='version_flag', action='store_true', + help='Get module versions.') + + # For --belongs-to and --disconnect + parser.add_argument('--parent-record-uid', '-p', required=False, dest='parent_record_uid', + action='store', help='The parent record UID.') + + # For the info command + parser.add_argument('--render-all-edges', required=False, dest='render_all_edges', + action='store_false', help='Render graphs.') + parser.add_argument('--graph-dir', required=False, dest='graph_dir', action='store', + help='Directory to save graphs.') + parser.add_argument('--infra-graph-name', required=False, dest='infra_name', action='store', + default="infra_graph", help='Infrastructure graph name.') + parser.add_argument('--rl-graph-name', required=False, dest='rl_name', action='store', + default="record_linking_graph", help='Record linking graph name.') + parser.add_argument('--graph-type', '-gt', required=False, choices=['dot', 'twopi', 'patchwork'], + dest='graph_type', default="dot", action='store', help='The render graph type.') + + def get_parser(self): + return PAMGatewayActionDiscoverDebugCommand.parser + + @staticmethod + def _versions(): + print("") + print(f"{bcolors.BOLD}keeper-dag version:{bcolors.ENDC} {version('keeper-dag')}") + print(f"{bcolors.BOLD}discovery-common version:{bcolors.ENDC} {version('discovery-common')}") + print("") + + @staticmethod + def _show_info(params: KeeperParams, configuration_record: TypedRecord, record: TypedRecord): + + infra = Infrastructure(record=configuration_record, params=params) + record_link = RecordLink(record=configuration_record, params=params) + + print("") + print(f"{bcolors.BOLD}Configuration UID:{bcolors.ENDC} {configuration_record.record_uid}") + print(f"{bcolors.BOLD}Configuration Key Bytes Hex:{bcolors.ENDC} {configuration_record.record_key.hex()}") + print("") + try: + discovery_vertices = infra.dag.search_content({"record_uid": record.record_uid}) + if len(discovery_vertices) > 0: + + if len(discovery_vertices) > 1: + print(f"{bcolors.FAIL}Found multiple vertices with the record UID of " + f"{record.record_uid}{bcolors.ENDC}") + for vertex in discovery_vertices: + print(f" * Infrastructure Vertex UID: {vertex.uid}") + print("") + + discovery_vertex = discovery_vertices[0] + content = DiscoveryObject.get_discovery_object(discovery_vertex) + + print(f"{bcolors.HEADER}Discovery Object Information{bcolors.ENDC}") + print(f"Vertex UID: {content.uid}") + print(f"Record UID: {content.record_uid}") + print(f"Parent Record UID: {content.parent_record_uid}") + print(f"Shared Folder UID: {content.shared_folder_uid}") + print(f"Record Type: {content.record_type}") + print(f"Object Type: {content.object_type_value}") + print(f"Ignore Object: {content.ignore_object}") + print(f"Rule Engine Result: 
{content.action_rules_result}") + print(f"Discovery ID: {content.id}") + print(f"Discovery Name: {content.name}") + print(f"Discovery Title: {content.title}") + print(f"Discovery Description: {content.description}") + print(f"Discovery Notes:") + for note in content.notes: + print(f" * {note}") + if content.error is not None: + print(f"{bcolors.FAIL}Error: {content.error}{bcolors.ENDC}") + if content.stacktrace is not None: + print(f"{bcolors.FAIL}Stack Trace:{bcolors.ENDC}") + print(f"{bcolors.FAIL}{content.stacktrace}{bcolors.ENDC}") + print("") + print(f"{bcolors.HEADER}Record Type Specifics{bcolors.ENDC}") + + item_dict = content.item + for k, v in item_dict.__dict__.items(): + print(f"{k} = {v}") + + print("") + print(f"{bcolors.HEADER}Belongs To Vertices{bcolors.ENDC}") + vertices = discovery_vertex.belongs_to_vertices() + for vertex in vertices: + content = DiscoveryObject.get_discovery_object(vertex) + print(f" * {content.description} ({vertex.uid})") + for edge_type in [EdgeType.LINK, EdgeType.ACL, EdgeType.KEY, EdgeType.DELETION]: + edge = discovery_vertex.get_edge(vertex, edge_type=edge_type) + if edge is not None: + print(f" . {edge_type}, active: {edge.active}") + + if len(vertices) == 0: + print(f"{bcolors.FAIL} Does not belong to anyone{bcolors.ENDC}") + + print("") + print(f"{bcolors.HEADER}Vertices Belonging To{bcolors.ENDC}") + vertices = discovery_vertex.has_vertices() + for vertex in vertices: + content = DiscoveryObject.get_discovery_object(vertex) + print(f" * {content.description} ({vertex.uid})") + for edge_type in [EdgeType.LINK, EdgeType.ACL, EdgeType.KEY, EdgeType.DELETION]: + edge = vertex.get_edge(discovery_vertex, edge_type=edge_type) + if edge is not None: + print(f" . {edge_type}, active: {edge.active}") + if len(vertices) == 0: + print(f" Does not have any children.") + + print("") + else: + print(f"{bcolors.FAIL}Could not find infrastructure vertex.{bcolors.ENDC}") + except Exception as err: + print(f"{bcolors.FAIL}Could not get information on infrastructure: {err}{bcolors.ENDC}") + + record_vertex = record_link.dag.get_vertex(record.record_uid) + if record_vertex is not None: + print(f"{bcolors.HEADER}Record Linking{bcolors.ENDC}") + for parent_vertex in record_vertex.belongs_to_vertices(): + + description = "Unknown" + discovery_vertices = infra.dag.search_content({"record_uid": parent_vertex.uid}) + if len(discovery_vertices) > 0: + content = DiscoveryObject.get_discovery_object(discovery_vertices[0]) + description = content.description + acl_edge = record_vertex.get_edge(parent_vertex, EdgeType.ACL) + if acl_edge is not None: + acl_content = acl_edge.content_as_object(UserAcl) + print(f" * ACL to {description} ({parent_vertex.uid})") + print(f" . belongs_to = {acl_content.belongs_to}") + print(f" . 
is_admin = {acl_content.is_admin}") + link_edge = record_vertex.get_edge(parent_vertex, EdgeType.LINK) + if link_edge is not None: + print(f" * LINK to {description} ({parent_vertex.uid})") + else: + print(f"{bcolors.FAIL}Cannot find in record linking.{bcolors.ENDC}") + + @staticmethod + def _render(params: KeeperParams, + configuration_record: TypedRecord, + infra_name: str = "infra_name", rl_name: str = "record_link_graph", + graph_type: str = "dot", graph_dir: str = None, render_all_edges: bool = False): + + if graph_dir is None: + graph_dir = os.environ.get("HOME", os.environ.get("PROFILENAME", ".")) + + print(f"Loading graphs for controller {configuration_record.record_uid}.") + + infra = Infrastructure(record=configuration_record, params=params) + record_link = RecordLink(record=configuration_record, params=params) + + print("") + try: + filename = os.path.join(graph_dir, f"{infra_name}.dot") + infra.to_dot( + graph_type=graph_type, + show_only_active_vertices=False, + show_only_active_edges=render_all_edges + ).render(filename) + print(f"Infrastructure graph rendered to {filename}") + except Exception as err: + print(f"{bcolors.FAIL}Could not generate infrastructure graph: {err}{bcolors.ENDC}") + raise err + + try: + filename = os.path.join(graph_dir, f"{rl_name}.dot") + record_link.to_dot( + graph_type=graph_type, + show_only_active_vertices=False, + show_only_active_edges=render_all_edges + ).render(filename) + print(f"Record linking graph rendered to {filename}") + except Exception as err: + print(f"{bcolors.FAIL}Could not generate record linking graph: {err}{bcolors.ENDC}") + raise err + + filename = os.path.join(graph_dir, f"infra_raw.dot") + with open(filename, "w") as fh: + fh.write(str(infra.dag.to_dot())) + fh.close() + + filename = os.path.join(graph_dir, f"record_linking_raw.dot") + with open(filename, "w") as fh: + fh.write(str(record_link.dag.to_dot())) + fh.close() + + def execute(self, params, **kwargs): + + info_flag = kwargs.get("info_flag", False) + belongs_to_flag = kwargs.get("belongs_to_flag", False) + disconnect_flag = kwargs.get("disconnect_flag", False) + render_flag = kwargs.get("render_flag", False) + version_flag = kwargs.get("version_flag", False) + + record_uid = kwargs.get("record_uid") + configuration_record = None + if record_uid is not None: + record = vault.KeeperRecord.load(params, record_uid) # type: Optional[TypedRecord] + if record is None: + print(f"{bcolors.FAIL}Record does not exists.{bcolors.ENDC}") + return + + configuration_record = record + if record.record_type in ["pamUser", "pamMachine", "pamDatabase", "pamDirectory"]: + record_rotation = params.record_rotation_cache.get(record_uid) + if record_rotation is None: + print(f"{bcolors.FAIL}Record does not have rotation settings.{bcolors.ENDC}") + return + + controller_uid = record_rotation.get("configuration_uid") + if controller_uid is None: + print(f"{bcolors.FAIL}Record does not have the PAM Configuration set.{bcolors.ENDC}") + return + + configuration_record = vault.KeeperRecord.load(params, controller_uid) # type: Optional[TypedRecord] + + if version_flag is True: + self._versions() + if render_flag is True: + self._render( + params=params, + configuration_record=configuration_record, + infra_name=kwargs.get("infra_name"), + rl_name=kwargs.get("rl_name"), + graph_type=kwargs.get("graph_type"), + graph_dir=kwargs.get("graph_dir"), + render_all_edges=kwargs.get("render_all_edges"), + ) + if info_flag is True: + self._show_info( + params=params, + 
configuration_record=configuration_record, + record=record + ) + diff --git a/keepercommander/commands/discover/job_remove.py b/keepercommander/commands/discover/job_remove.py new file mode 100644 index 000000000..c065abee0 --- /dev/null +++ b/keepercommander/commands/discover/job_remove.py @@ -0,0 +1,75 @@ +from __future__ import annotations +import argparse +import logging +from . import PAMGatewayActionDiscoverCommandBase, GatewayContext +from ..pam.pam_dto import GatewayActionDiscoverJobRemoveInputs, GatewayActionDiscoverJobRemove, GatewayAction +from ...proto import pam_pb2 +from ..pam.router_helper import router_send_action_to_gateway, router_get_connected_gateways +from ... import vault_extensions +from ...display import bcolors +from discovery_common.jobs import Jobs + + +class PAMGatewayActionDiscoverJobRemoveCommand(PAMGatewayActionDiscoverCommandBase): + parser = argparse.ArgumentParser(prog='dr-discover-command-process') + parser.add_argument('--job-id', '-j', required=True, dest='job_id', action='store', + help='Discovery job id.') + + def get_parser(self): + return PAMGatewayActionDiscoverJobRemoveCommand.parser + + def execute(self, params, **kwargs): + + if not hasattr(params, 'pam_controllers'): + router_get_connected_gateways(params) + + job_id = kwargs.get("job_id") + + # Get all the PAM configuration records + configuration_records = list(vault_extensions.find_records(params, "pam.*Configuration")) + + for configuration_record in configuration_records: + + gateway_context = GatewayContext.from_configuration_uid(params, configuration_record.record_uid) + if gateway_context is None: + continue + + jobs = Jobs(record=configuration_record, params=params) + job_item = jobs.get_job(job_id) + if job_item is not None: + + try: + # First, cancel the running discovery job if it is running. + logging.debug("cancel job on the gateway, if running") + action_inputs = GatewayActionDiscoverJobRemoveInputs( + configuration_uid=gateway_context.configuration_uid, + job_id=job_id + ) + + conversation_id = GatewayAction.generate_conversation_id() + router_response = router_send_action_to_gateway( + params=params, + gateway_action=GatewayActionDiscoverJobRemove( + inputs=action_inputs, + conversation_id=conversation_id), + message_type=pam_pb2.CMT_DISCOVERY, + is_streaming=False, + destination_gateway_uid_str=gateway_context.gateway_uid + ) + + data = self.get_response_data(router_response) + if data is None: + raise Exception("The router returned a failure.") + elif data.get("success") is False: + error = data.get("error") + raise Exception(f"Discovery job was not removed: {error}") + except Exception as err: + logging.debug(f"gateway return error removing discovery job: {err}") + + jobs.cancel(job_id) + + print(f"{bcolors.OKGREEN}Discovery job has been removed or cancelled.{bcolors.ENDC}") + return + + print(f'{bcolors.FAIL}Discovery job not found. Cannot get remove the job.{bcolors.ENDC}') + return diff --git a/keepercommander/commands/discover/job_start.py b/keepercommander/commands/discover/job_start.py new file mode 100644 index 000000000..42c28ef71 --- /dev/null +++ b/keepercommander/commands/discover/job_start.py @@ -0,0 +1,236 @@ +from __future__ import annotations +import argparse +import logging +import json +from . 
import PAMGatewayActionDiscoverCommandBase, GatewayContext
+from .job_status import PAMGatewayActionDiscoverJobStatusCommand
+from ..pam.router_helper import router_send_action_to_gateway, print_router_response, router_get_connected_gateways
+from ..pam.user_facade import PamUserRecordFacade
+from ..pam.pam_dto import GatewayActionDiscoverJobStartInputs, GatewayActionDiscoverJobStart, GatewayAction
+from ... import vault_extensions
+from ... import vault
+from ...proto import pam_pb2
+from ...display import bcolors
+from discovery_common.jobs import Jobs
+from discovery_common.types import CredentialBase
+from typing import List, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from ...params import KeeperParams
+
+
+class PAMGatewayActionDiscoverJobStartCommand(PAMGatewayActionDiscoverCommandBase):
+    parser = argparse.ArgumentParser(prog='dr-discover-start-command')
+    parser.add_argument('--gateway', '-g', required=True, dest='gateway', action='store',
+                        help='Gateway name or UID.')
+    parser.add_argument('--resource', '-r', required=False, dest='resource_uid', action='store',
+                        help='UID of the resource record. Set to discover a specific resource.')
+    parser.add_argument('--lang', required=False, dest='language', action='store', default="en",
+                        help='Language')
+    parser.add_argument('--include-machine-dir-users', required=False, dest='include_machine_dir_users',
+                        action='store_false', default=True, help='Include directory users found on the machine.')
+    parser.add_argument('--inc-azure-aadds', required=False, dest='include_azure_aadds',
+                        action='store_true', help='Include Azure Active Directory Domain Service.')
+    parser.add_argument('--skip-rules', required=False, dest='skip_rules',
+                        action='store_true', help='Skip running the rule engine.')
+    parser.add_argument('--skip-machines', required=False, dest='skip_machines',
+                        action='store_true', help='Skip discovering machines.')
+    parser.add_argument('--skip-databases', required=False, dest='skip_databases',
+                        action='store_true', help='Skip discovering databases.')
+    parser.add_argument('--skip-directories', required=False, dest='skip_directories',
+                        action='store_true', help='Skip discovering directories.')
+    parser.add_argument('--skip-cloud-users', required=False, dest='skip_cloud_users',
+                        action='store_true', help='Skip discovering cloud users.')
+    parser.add_argument('--cred', required=False, dest='credentials',
+                        action='append', help='List resource credentials.')
+    parser.add_argument('--cred-file', required=False, dest='credential_file',
+                        action='store', help='A JSON file containing a list of credentials.')
+
+    def get_parser(self):
+        return PAMGatewayActionDiscoverJobStartCommand.parser
+
+    @staticmethod
+    def make_protobuf_user_map(params: KeeperParams, gateway_context: GatewayContext) -> List[dict]:
+        """
+        Make a user map for PAM Users.
+
+        The map is used to find existing records.
+        Since KSM cannot read the rotation settings using protobuf,
+        it cannot match a vault record to a discovered user.
+        This maps a login/DN and parent UID to a record UID.
+        """
+
+        user_map = []
+        for record in vault_extensions.find_records(params, record_type="pamUser"):
+            user_record = vault.KeeperRecord.load(params, record.record_uid)
+            user_facade = PamUserRecordFacade()
+            user_facade.record = user_record
+
+            info = params.record_rotation_cache.get(user_record.record_uid)
+            if info is None:
+                continue
+
+            # Make sure this user is part of this gateway.
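+            # The rotation cache entry stores the configuration UID the user
+            # rotates under; comparing it to this gateway's configuration UID
+            # keeps the map scoped to a single gateway.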
+ if info.get("configuration_uid") != gateway_context.configuration_uid: + continue + + # If the user Admin Cred Record (i.e., parent) is blank, skip the mapping item + # This will be a UID string, not 16 bytes. + if info.get("resource_uid") is None or info.get("resource_uid") == "": + continue + + user_map.append({ + "user": user_facade.login if user_facade.login != "" else None, + "dn": user_facade.distinguishedName if user_facade.distinguishedName != "" else None, + "record_uid": user_record.record_uid, + "parent_record_uid": info.get("resource_uid") + }) + + logging.debug(f"found {len(user_map)} user map items") + + return user_map + + def execute(self, params, **kwargs): + + if not hasattr(params, 'pam_controllers'): + router_get_connected_gateways(params) + + # Load the configuration record and get the gateway_uid from the facade. + gateway = kwargs.get('gateway') + + gateway_context = GatewayContext.from_gateway(params, gateway) + if gateway_context is None: + print(f"{bcolors.FAIL}Could not find the gateway configuration for {gateway}.") + return + + jobs = Jobs(record=gateway_context.configuration, params=params) + current_job_item = jobs.current_job + removed_prior_job = None + if current_job_item is not None: + if current_job_item.is_running is True: + print("") + print(f"{bcolors.FAIL}A discovery job is currently running. " + f"Cannot start another until it is finished.{bcolors.ENDC}") + print(f"To check the status, use the command " + f"'{bcolors.OKGREEN}pam action discover status{bcolors.ENDC}'.") + print(f"To stop and remove the current job, use the command " + f"'{bcolors.OKGREEN}pam action discover remove -j {current_job_item.job_id}'.") + return + + print(f"{bcolors.FAIL}An active discovery job exists for this gateway.{bcolors.ENDC}") + print("") + status = PAMGatewayActionDiscoverJobStatusCommand() + status.execute(params=params) + print("") + + yn = input("Do you wish to remove the active discovery job and run a new one [Y/N]> ").lower() + while True: + if yn[0] == "y": + jobs.cancel(current_job_item.job_id) + removed_prior_job = current_job_item.job_id + break + elif yn[0] == "n": + print(f"{bcolors.FAIL}Not starting a discovery job.{bcolors.ENDC}") + return + + # Get the credentials passed in via the command line + credentials = [] + creds = kwargs.get('credentials') + if creds is not None: + for cred in creds: + parts = cred.split("|") + c = CredentialBase() + for item in parts: + kv = item.split("=") + if len(kv) != 2: + print(f"{bcolors.FAIL}A '--cred' is invalid. It does not have a value.{bcolors.ENDC}") + return + if hasattr(c, kv[0]) is False: + print(f"{bcolors.FAIL}A '--cred' is invalid. The key '{kv[0]}' is invalid.{bcolors.ENDC}") + return + if hasattr(c, kv[1]) == "": + print(f"{bcolors.FAIL}A '--cred' is invalid. The value is blank.{bcolors.ENDC}") + return + setattr(c, kv[0], kv[1]) + credentials.append(c.model_dump()) + + # Get the credentials passed in via a credential file. 
+        credential_files = kwargs.get('credential_file')
+        if credential_files is not None:
+            try:
+                with open(credential_files, "r") as fh:
+                    creds = json.load(fh)
+            except FileNotFoundError:
+                print(f"{bcolors.FAIL}Could not find the file {credential_files}{bcolors.ENDC}")
+                return
+            except json.JSONDecodeError:
+                print(f"{bcolors.FAIL}The file {credential_files} is not valid JSON.{bcolors.ENDC}")
+                return
+            except Exception as err:
+                print(f"{bcolors.FAIL}The JSON file {credential_files} could not be imported: {err}{bcolors.ENDC}")
+                return
+
+            if isinstance(creds, list) is False:
+                print(f"{bcolors.FAIL}Credential file is invalid. Structure is not an array.{bcolors.ENDC}")
+                return
+            num = 1
+            for obj in creds:
+                c = CredentialBase()
+                for key in obj:
+                    if hasattr(c, key) is False:
+                        print(f"{bcolors.FAIL}Object {num} has the invalid key {key}.{bcolors.ENDC}")
+                        return
+                    setattr(c, key, obj[key])
+                credentials.append(c.model_dump())
+                num += 1
+
+        action_inputs = GatewayActionDiscoverJobStartInputs(
+            configuration_uid=gateway_context.configuration_uid,
+            resource_uid=kwargs.get('resource_uid'),
+            user_map=gateway_context.encrypt(
+                self.make_protobuf_user_map(
+                    params=params,
+                    gateway_context=gateway_context
+                )
+            ),
+
+            shared_folder_uid=gateway_context.default_shared_folder_uid,
+            language=kwargs.get('language'),
+
+            # Settings
+            include_machine_dir_users=kwargs.get('include_machine_dir_users', True),
+            include_azure_aadds=kwargs.get('include_azure_aadds', False),
+            skip_rules=kwargs.get('skip_rules', False),
+            skip_machines=kwargs.get('skip_machines', False),
+            skip_databases=kwargs.get('skip_databases', False),
+            skip_directories=kwargs.get('skip_directories', False),
+            skip_cloud_users=kwargs.get('skip_cloud_users', False),
+            credentials=credentials
+        )
+
+        conversation_id = GatewayAction.generate_conversation_id()
+        router_response = router_send_action_to_gateway(
+            params=params,
+            gateway_action=GatewayActionDiscoverJobStart(
+                inputs=action_inputs,
+                conversation_id=conversation_id),
+            message_type=pam_pb2.CMT_DISCOVERY,
+            is_streaming=False,
+            destination_gateway_uid_str=gateway_context.gateway_uid
+        )
+
+        data = self.get_response_data(router_response)
+        if data is None:
+            print(f"{bcolors.FAIL}The router returned a failure.{bcolors.ENDC}")
+            return
+
+        if "has been queued" in data.get("Response", ""):
+
+            if removed_prior_job is None:
+                print("The discovery job is currently running.")
+            else:
+                print(f"Active discovery job {removed_prior_job} has been removed and a new discovery job is running.")
+            print(f"To check the status, use the command '{bcolors.OKGREEN}pam action discover status{bcolors.ENDC}'.")
+            print(f"To stop and remove the current job, use the command "
+                  f"'{bcolors.OKGREEN}pam action discover remove -j <job id>{bcolors.ENDC}'.")
+        else:
+            print_router_response(router_response, conversation_id)
diff --git a/keepercommander/commands/discover/job_status.py b/keepercommander/commands/discover/job_status.py
new file mode 100644
index 000000000..47bf532be
--- /dev/null
+++ b/keepercommander/commands/discover/job_status.py
@@ -0,0 +1,307 @@
+from __future__ import annotations
+import argparse
+import json
+from . import PAMGatewayActionDiscoverCommandBase, GatewayContext
+from ...
import vault_extensions +from ...display import bcolors +from ..pam.router_helper import router_get_connected_gateways +from discovery_common.jobs import Jobs +from discovery_common.infrastructure import Infrastructure +from discovery_common.constants import DIS_INFRA_GRAPH_ID +from discovery_common.types import DiscoveryDelta, DiscoveryObject +from keeper_dag.dag import DAG +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from discovery_common.jobs import JobItem + + +def _h(text): + return f"{bcolors.HEADER}{text}{bcolors.ENDC}" + + +def _f(text): + return f"{bcolors.FAIL}{text}{bcolors.ENDC}" + + +def _g(text): + return f"{bcolors.OKGREEN}{text}{bcolors.ENDC}" + + +def _b(text): + return f"{bcolors.OKBLUE}{text}{bcolors.ENDC}" + + +class PAMGatewayActionDiscoverJobStatusCommand(PAMGatewayActionDiscoverCommandBase): + parser = argparse.ArgumentParser(prog='dr-discover-status-command') + parser.add_argument('--gateway', '-g', required=False, dest='gateway', action='store', + help='Show only discovery jobs from a specific gateway.') + parser.add_argument('--job-id', '-j', required=False, dest='job_id', action='store', + help='Detailed information for a specific discovery job.') + # parser.add_argument('--file', required=False, dest='json_file', action='store', + # help='Save status to JSON file.') + parser.add_argument('--history', required=False, dest='show_history', action='store_true', + help='Show history') + + def get_parser(self): + return PAMGatewayActionDiscoverJobStatusCommand.parser + + def job_detail(self, job): + pass + + @staticmethod + def print_job_table(jobs, max_gateway_name, show_history=False): + + print("") + print(f"{bcolors.HEADER}{'Job ID'.ljust(14, ' ')} " + f"{'Gateway Name'.ljust(max_gateway_name, ' ')} " + f"{'Gateway UID'.ljust(22, ' ')} " + f"{'Status'.ljust(12, ' ')} " + f"{'Resource UID'.ljust(22, ' ')} " + f"{'Started'.ljust(19, ' ')} " + f"{'Completed'.ljust(19, ' ')} " + f"{'Duration'.ljust(19, ' ')} " + f"{bcolors.ENDC}") + + print(f"{''.ljust(14, '=')} " + f"{''.ljust(max_gateway_name, '=')} " + f"{''.ljust(22, '=')} " + f"{''.ljust(12, '=')} " + f"{''.ljust(22, '=')} " + f"{''.ljust(19, '=')} " + f"{''.ljust(19, '=')} " + f"{''.ljust(19, '=')}") + + completed_jobs = [] + running_jobs = [] + failed_jobs = [] + + for job in jobs: + color = "" + job_id = job['job_id'] + if job['status'] == "COMPLETE": + color = bcolors.OKGREEN + completed_jobs.append(job_id) + elif job['status'] == "RUNNING": + color = bcolors.OKBLUE + running_jobs.append(job_id) + elif job['status'] == "FAILED": + failed_jobs.append(job_id) + color = bcolors.FAIL + print(f"{color}{job_id} " + f"{job['gateway'].ljust(max_gateway_name, ' ')} " + f"{job['gateway_uid']} " + f"{job['status'].ljust(12, ' ')} " + f"{(job.get('resource_uid') or 'NA').ljust(22, ' ')} " + f"{(job.get('start_ts_str') or 'NA').ljust(19, ' ')} " + f"{(job.get('end_ts_str') or 'NA').ljust(19, ' ')} " + f"{(job.get('duration') or 'NA').ljust(19, ' ')} " + f"{bcolors.ENDC}") + + if len(completed_jobs) > 0 and show_history is False: + print("") + if len(completed_jobs) == 1: + print(f"There is one {_g('COMPLETED')} job. To process, use the following command.") + else: + print(f"There are {len(completed_jobs)} {_g('COMPLETED')} jobs. 
" + "To process, use one of the the following commands.") + for job_id in completed_jobs: + print(_g(f" pam action discover process -j {job_id}")) + + if len(running_jobs) > 0 and show_history is False: + print("") + if len(running_jobs) == 1: + print(f"There is one {_b('RUNNING')} job. " + "If there is a problem, use the following command to cancel/remove the job.") + else: + print(f"There are {len(running_jobs)} {_b('RUNNING')} jobs. " + "If there is a problem, use one of the following commands to cancel/remove the job.") + for job_id in running_jobs: + print(_b(f" pam action discover remove -j {job_id}")) + + if len(failed_jobs) > 0 and show_history is False: + print("") + if len(failed_jobs) == 1: + print(f"There is one {_f('FAILED')} job. " + "If there is a problem, use the following command to get more information.") + else: + print(f"There are {len(failed_jobs)} {_f('FAILED')} jobs. " + "If there is a problem, use one of the following commands to get more information.") + for job_id in failed_jobs: + print(_f(f" pam action discover status -j {job_id}")) + print("") + if len(failed_jobs) == 1: + print(f"To remove the job, use the following command.") + else: + print(f"To remove the {_f('FAILED')} job, use one of the following commands.") + for job_id in failed_jobs: + print(_f(f" pam action discover remove -j {job_id}")) + + print("") + + @staticmethod + def print_job_detail(params, gateway_context, jobs, job_id): + + infra = Infrastructure(record=gateway_context.configuration, params=params) + + for job in jobs: + if job_id == job["job_id"]: + gateway_context = job["gateway_context"] + if job['status'] == "COMPLETE": + color = bcolors.OKGREEN + elif job['status'] == "RUNNING": + color = bcolors.OKBLUE + else: + color = bcolors.FAIL + status = f"{color}{job['status']}{bcolors.ENDC}" + + print("") + print(f"{_h('Job ID')}: {job['job_id']}") + print(f"{_h('Sync Point')}: {job['sync_point']}") + print(f"{_h('Gateway Name')}: {job['gateway']}") + print(f"{_h('Gateway UID')}: {job['gateway_uid']}") + print(f"{_h('Configuration UID')}: {gateway_context.configuration_uid}") + print(f"{_h('Status')}: {status}") + print(f"{_h('Resource UID')}: {job.get('resource_uid', 'NA')}") + print(f"{_h('Started')}: {job['start_ts_str']}") + print(f"{_h('Completed')}: {job.get('end_ts_str')}") + print(f"{_h('Duration')}: {job.get('duration')}") + + # If it failed, show the error and stacktrace. + if job['status'] == "FAILED": + print("") + print(f"{_h('Gateway Error')}:") + print(f"{color}{job['error']}{bcolors.ENDC}") + print("") + print(f"{_h('Gateway Stacktrace')}:") + print(f"{color}{job['stacktrace']}{bcolors.ENDC}") + # If it finished, show information about what was discovered. 
+                elif job.get('end_ts') is not None:
+                    job_item = job.get("job_item")  # type: JobItem
+
+                    try:
+                        infra.load(sync_point=0)
+                        print("")
+                        delta_json = job.get('delta')
+                        if delta_json is not None:
+                            delta = DiscoveryDelta.model_validate(delta_json)
+                            print(f"{_h('Added')} - {len(delta.added)} count")
+                            for item in delta.added:
+                                vertex = infra.dag.get_vertex(item.uid)
+                                discovery_object = DiscoveryObject.get_discovery_object(vertex)
+                                print(f" * {discovery_object.description}")
+
+                            print("")
+                            print(f"{_h('Changed')} - {len(delta.changed)} count")
+                            for item in delta.changed:
+                                vertex = infra.dag.get_vertex(item.uid)
+                                discovery_object = DiscoveryObject.get_discovery_object(vertex)
+                                print(f" * {discovery_object.description}")
+                                if item.changes is None:
+                                    print("   no changes; may be an object not added in a prior discovery.")
+                                else:
+                                    for key, value in item.changes.items():
+                                        print(f"   - {key} = {value}")
+
+                            print("")
+                            print(f"{_h('Deleted')} - {len(delta.deleted)} count")
+                            for item in delta.deleted:
+                                print(f" * discovery vertex {item.uid}")
+                        else:
+                            print(f"{_f('There are no available delta changes for this job.')}")
+
+                    except Exception as err:
+                        print(f"{_f('Could not load delta from infrastructure: ' + str(err))}")
+                        print("Falling back to the raw graph.")
+                        print("")
+                        dag = DAG(conn=infra.conn, record=infra.record, graph_id=DIS_INFRA_GRAPH_ID)
+                        print(dag.to_dot_raw(sync_point=job_item.sync_point, rank_dir="RL"))
+
+                return
+
+        print(f"{bcolors.FAIL}Cannot find the job.{bcolors.ENDC}")
+
+    def execute(self, params, **kwargs):
+
+        if not hasattr(params, 'pam_controllers'):
+            router_get_connected_gateways(params)
+
+        gateway_filter = kwargs.get("gateway")
+        job_id = kwargs.get("job_id")
+        show_history = kwargs.get("show_history")
+
+        if job_id is not None:
+            show_history = True
+
+        # Get all the PAM configuration records.
+        configuration_records = list(vault_extensions.find_records(params, "pam.*Configuration"))
+
+        # This is used to format the table. Start with a length of 12 characters for the gateway.
+        max_gateway_name = 12
+
+        all_jobs = []
+
+        # For each configuration/gateway, we are going to get all jobs.
+        # We are going to query the gateway for any updated status.
+        gateway_context = None
+        for configuration_record in configuration_records:
+
+            gateway_context = GatewayContext.from_configuration_uid(params, configuration_record.record_uid)
+            if gateway_context is None:
+                continue
+
+            # If we are using a gateway filter, and this gateway is not the one, then go on to the next
+            # configuration/gateway.
+            if gateway_filter is not None and gateway_context.is_gateway(gateway_filter) is False:
+                continue
+
+            # If the gateway name is longer than the prior, set the max length to this gateway's name.
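+            # max_gateway_name only drives the column widths used by
+            # print_job_table; a long gateway name widens the whole table.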
+ if len(gateway_context.gateway_name) > max_gateway_name: + max_gateway_name = len(gateway_context.gateway_name) + + jobs = Jobs(record=configuration_record, params=params) + if show_history is True: + job_list = reversed(jobs.history) + else: + job_list = [] + if jobs.current_job is not None: + job_list = [jobs.current_job] + + for job_item in job_list: + job = job_item.model_dump() + job["status"] = "RUNNING" + if job_item.start_ts is not None: + job["start_ts_str"] = job_item.start_ts_str + if job_item.end_ts is not None: + job["end_ts_str"] = job_item.end_ts_str + job["status"] = "COMPLETE" + + job["duration"] = job_item.duration_sec_str + + job["gateway"] = gateway_context.gateway_name + job["gateway_uid"] = gateway_context.gateway_uid + + # This is needs for details + job["gateway_context"] = gateway_context + job["job_item"] = job_item + + if job_item.success is False: + job["status"] = "FAILED" + + all_jobs.append(job) + + # Instead of printing a table, save a json file. + if kwargs.get("json_file") is not None: + with open(kwargs.get("json_file"), "w") as fh: + fh.write(json.dumps(all_jobs, indent=4)) + fh.close() + return + + if len(all_jobs) == 0: + print(f"{bcolors.FAIL}There are no discovery jobs. Use 'pam action discover start' to start a " + f"discovery job.{bcolors.ENDC}") + return + + if job_id is not None and gateway_context is not None: + self.print_job_detail(params, gateway_context, all_jobs, job_id) + else: + self.print_job_table(all_jobs, max_gateway_name, show_history) diff --git a/keepercommander/commands/discover/result_get.py b/keepercommander/commands/discover/result_get.py new file mode 100644 index 000000000..4c4ee4fc2 --- /dev/null +++ b/keepercommander/commands/discover/result_get.py @@ -0,0 +1,56 @@ +from __future__ import annotations +import argparse +from . import PAMGatewayActionDiscoverCommandBase, GatewayContext +from ... import vault_extensions +from ...display import bcolors +from ..pam.router_helper import router_get_connected_gateways +from discovery_common.jobs import Jobs +from discovery_common.infrastructure import Infrastructure +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from discovery_common.jobs import JobItem + + +class PAMGatewayActionDiscoverResultGetCommand(PAMGatewayActionDiscoverCommandBase): + parser = argparse.ArgumentParser(prog='dr-discover-command-process') + parser.add_argument('--job-id', '-j', required=True, dest='job_id', action='store', + help='Discovery job id.') + parser.add_argument('--file', required=True, dest='filename', action='store', + help='Save results to file.') + + def get_parser(self): + return PAMGatewayActionDiscoverResultGetCommand.parser + + def execute(self, params, **kwargs): + + job_id = kwargs.get("job_id") + + if not hasattr(params, 'pam_controllers'): + router_get_connected_gateways(params) + + configuration_records = list(vault_extensions.find_records(params, "pam.*Configuration")) + for configuration_record in configuration_records: + + gateway_context = GatewayContext.from_configuration_uid(params, configuration_record.record_uid) + if gateway_context is None: + continue + + jobs = Jobs(record=configuration_record, params=params) + job_item = jobs.get_job(job_id) # type: JobItem + if job_item is None: + continue + + if job_item.end_ts is None: + print(f'{bcolors.FAIL}Discovery job is currently running. Cannot get results.{bcolors.ENDC}') + return + if job_item.success is False: + print(f'{bcolors.FAIL}Discovery job failed. 
Cannot get results.{bcolors.ENDC}') + return + + # TODO - Make a way to serialize the discovery into a form + infra = Infrastructure(record=configuration_record, params=params) + + return + + print(f'{bcolors.FAIL}Discovery job not found. Cannot get results.{bcolors.ENDC}') diff --git a/keepercommander/commands/discover/result_process.py b/keepercommander/commands/discover/result_process.py new file mode 100644 index 000000000..903d67e1f --- /dev/null +++ b/keepercommander/commands/discover/result_process.py @@ -0,0 +1,1351 @@ +from __future__ import annotations +import logging +import argparse +import json +import os.path + +from keeper_secrets_manager_core.utils import url_safe_str_to_bytes +from . import PAMGatewayActionDiscoverCommandBase, GatewayContext +from ..pam.router_helper import router_get_connected_gateways, router_set_record_rotation_information +from ... import api, subfolder, utils, crypto, vault, vault_extensions +from ...display import bcolors +from ...proto import router_pb2, record_pb2 +from discovery_common.jobs import Jobs +from discovery_common.process import Process, QuitException, NoDiscoveryDataException +from discovery_common.types import (DiscoveryObject, UserAcl, PromptActionEnum, PromptResult, + BulkRecordAdd, BulkRecordConvert, BulkProcessResults, BulkRecordSuccess, + BulkRecordFail, DirectoryInfo, NormalizedRecord, RecordField) +from pydantic import BaseModel +from typing import Optional, List, Any, TYPE_CHECKING + +if TYPE_CHECKING: + from ...params import KeeperParams + from ...vault import TypedRecord, KeeperRecord + from keeper_dag.vertex import DAGVertex + from discovery_common.record_link import RecordLink + + +def _h(value: str) -> str: + return f"{bcolors.HEADER}{value}{bcolors.ENDC}" + + +def _b(value: str) -> str: + return f"{bcolors.BOLD}{value}{bcolors.ENDC}" + + +def _f(value: str) -> str: + return f"{bcolors.FAIL}{value}{bcolors.ENDC}" + + +def _ok(value: str) -> str: + return f"{bcolors.OKGREEN}{value}{bcolors.ENDC}" + + +# This is used for the admin user search +class AdminSearchResult(BaseModel): + record: Any + is_directory_user: bool + is_pam_user: bool + + +class PAMGatewayActionDiscoverResultProcessCommand(PAMGatewayActionDiscoverCommandBase): + + """ + Process the discovery data + """ + + parser = argparse.ArgumentParser(prog='dr-discover-command-process') + parser.add_argument('--job-id', '-j', required=True, dest='job_id', action='store', + help='Discovery job to process.') + + # This is not ready yet. + # parser.add_argument('--smart-add', required=False, dest='smart_add', action='store_true', + # help='Automatically add resources with credentials and their users.') + + parser.add_argument('--add-all', required=False, dest='add_all', action='store_true', + help='Respond with ADD for all prompts.') + parser.add_argument('--debug-gs-level', required=False, dest='debug_level', action='store', + help='GraphSync debug level. Default is 0', type=int, default=0) + + EDITABLE = [ + "login", + "password", + "distinguishedName", + "alternativeIPs", + "database", + "privatePEMKey", + "connectDatabase", + "operatingSystem" + ] + + def get_parser(self): + return PAMGatewayActionDiscoverResultProcessCommand.parser + + @staticmethod + def _is_directory_user(record_type: str) -> bool: + # pamAzureConfiguration has tenant users what are like a directory. 
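+        # Resources of these types are treated as holding directory users
+        # during the admin-user search later in this command.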
+ return (record_type == "pamDirectory" or + record_type == "pamAzureConfiguration") + + @staticmethod + def _get_shared_folder(params: KeeperParams, pad: str, gateway_context: GatewayContext) -> str: + while True: + shared_folders = gateway_context.get_shared_folders(params) + index = 0 + for folder in shared_folders: + print(f"{pad}* {_h(str(index+1))} - {folder.get('uid')} {folder.get('name')}") + index += 1 + selected = input(f"{pad}Enter number of the shared folder>") + try: + return shared_folders[int(selected) - 1].get("uid") + except ValueError: + print(f"{pad}{_f('Input was not a number.')}") + + @staticmethod + def get_field_values(record: TypedRecord, field_type: str) -> List[str]: + return next( + (f.value + for f in record.fields + if f.type == field_type), + None + ) + + def get_keys_by_record(self, params: KeeperParams, gateway_context: GatewayContext, + record: TypedRecord) -> List[str]: + """ + For the record, get the values of fields that are key for this record type. + + :param params: + :param gateway_context: + :param record: + :return: + """ + + key_field = Process.get_key_field(record.record_type) + keys = [] + if key_field == "host_port": + values = self.get_field_values(record, "pamHostname") + if len(values) == 0: + return [] + + host = values[0].get("hostName") + port = values[0].get("port") + if port is not None: + if host is not None: + keys.append(f"{host}:{port}".lower()) + + elif key_field == "host": + values = self.get_field_values(record, "pamHostname") + if len(values) == 0: + return [] + + host = values[0].get("hostName") + if host is not None: + keys.append(host.lower()) + + elif key_field == "user": + + # This is user protobuf values. + # We could make this also use record linking if we stop using protobuf. + + record_rotation = params.record_rotation_cache.get(record.record_uid) + if record_rotation is not None: + controller_uid = record_rotation.get("configuration_uid") + if controller_uid is None or controller_uid != gateway_context.configuration_uid: + return [] + + resource_uid = record_rotation.get("resource_uid") + # If the resource uid is None, the Admin Cred Record has not been set. + if resource_uid is None: + return [] + + values = self.get_field_values(record, "login") + if len(values) == 0: + return [] + + keys.append(f"{resource_uid}:{values[0]}".lower()) + + return keys + + @staticmethod + def _record_lookup(record_uid: str, context: Optional[Any] = None) -> Optional[NormalizedRecord]: + + """ + Get the record from the Vault, normalize it, and return it. + + Since common code is using this method we want to flatten/abstract the KeeperRecord/TypedRecord. + """ + + params = context.get("params") + record = vault.TypedRecord.load(params, record_uid) # type: Optional[TypedRecord] + if record is None: + return None + + normalized_record = NormalizedRecord( + record_uid=record.record_uid, + record_type=record.record_type, + title=record.title, + notes=record.notes + ) + for field in record.fields: + normalized_record.fields.append( + RecordField( + type=field.type, + label=field.label, + value=field.value, + ) + ) + if record.custom is not None: + for field in record.custom: + normalized_record.fields.append( + RecordField( + type=field.type, + label=field.label, + value=field.value, + ) + ) + return normalized_record + + def _build_record_cache(self, params: KeeperParams, gateway_context: GatewayContext) -> dict: + + """ + Make a lookup cache for all the records. 
+ + This is used to flag discovered items as existing if the record has already been added. This is used to + prevent duplicate records being added. + """ + + logging.debug(f"building the PAM record cache") + + # Make a cache of existing record by the criteria per record type + cache = { + "pamUser": {}, + "pamMachine": {}, + "pamDirectory": {}, + "pamDatabase": {} + } + + # Set all the PAM Records + records = list(vault_extensions.find_records(params, "pam*")) + for record in records: + # If the record type is not part of the cache, skip the record + if record.record_type not in cache: + continue + + # Load the full record + record = vault.TypedRecord.load(params, record.record_uid) # type: Optional[TypedRecord] + + cache_keys = self.get_keys_by_record( + params=params, + gateway_context=gateway_context, + record=record + ) + if len(cache_keys) == 0: + continue + + for cache_key in cache_keys: + cache[record.record_type][cache_key] = record.record_uid + + return cache + + def _edit_record(self, content: DiscoveryObject, pad: str, editable: List[str]) -> bool: + + edit_label = input(f"{pad}Enter 'title' or the name of the {_ok('Label')} to edit, RETURN to cancel> ") + + # Just pressing return exits the edit + if edit_label == "": + return False + + # If the "title" is entered, then edit the title of the record. + if edit_label.lower() == "title": + new_title = input(f"{pad}Enter new title> ") + content.title = new_title + + # If a field label is entered, and it's in the list of editable fields, then allow the user to edit. + elif edit_label in editable: + new_value = None + if edit_label in self.FIELD_MAPPING: + type_hint = self.FIELD_MAPPING[edit_label].get("type") + if type_hint == "dict": + field_input_format = self.FIELD_MAPPING[edit_label].get("field_input") + new_value = {} + for field in field_input_format: + new_value[field.get('key')] = input(f"{pad}Enter {field_input_format.get('prompt')} value> ") + elif type_hint == "csv": + new_value = input(f"{pad}Enter {edit_label} values, separate with a comma > ") + new_values = map(str.strip, new_value.split(',')) + new_value = "\n".join(new_values) + elif type_hint == "multiline": + print(_b(f"{pad}Enter multilines of text or a path, on the first line, " + "to a file that contains the value.")) + print(_b(f"{pad}To end, type 'END' at the start of a new line. You can paste text.")) + new_value = "" + first_line = True + while True: + line = input(_b(f"> ")).rstrip() + if line == "END": + break + + # If this is the first line, check if line is a path to a file. + if first_line is True: + try: + test_file = line.strip() + logging.debug(f"is first line, check for file path for '{test_file}'") + if os.path.exists(test_file) is True: + with open(test_file, "r") as fh: + new_value = fh.read() + fh.close() + break + else: + logging.debug(f"first line is not a file path") + except Exception as err: + logging.debug(f"exception checking if file: {err}") + first_line = False + new_value += line + "\n" + elif type_hint == "choice": + + values = self.FIELD_MAPPING[edit_label].get("values") + text_values = [_b(x) for x in values] + new_value = input(f"{pad}Enter one of the follow values: {', '.join(text_values)}> ") + new_value = new_value.strip().lower() + if new_value not in values: + print(f"{pad}{_f('The value ' + new_value + ' is not one of the values allowed.')}") + return False + else: + new_value = input(f"{pad}Enter new value, or path to a file that contains the value > ") + + # Is the value a path to a file, i.e., a private key file. 
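+                # The broad except is deliberate: if probing the filesystem
+                # fails for any reason, keep the typed value as-is.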
+            try:
+                if os.path.exists(new_value) is True:
+                    with open(new_value, "r") as fh:
+                        new_value = fh.read()
+            except (Exception,):
+                pass
+
+            for edit_field in content.fields:
+                if edit_field.label == edit_label:
+                    edit_field.value = [new_value]
+
+        # Else, the label they entered cannot be edited.
+        else:
+            print(f"{pad}{_f('The field is not editable.')}")
+            return False
+
+        return True
+
+    @staticmethod
+    def _add_all_preprocess(vertex: DAGVertex, content: DiscoveryObject, parent_vertex: DAGVertex,
+                            acl: Optional[UserAcl] = None) -> Optional[PromptResult]:
+        """
+        This is a client-side check to decide if we should skip prompting the user.
+
+        The checks are:
+        * A directory with the same domain already has a record.
+        """
+
+        _ = vertex
+        _ = acl
+
+        # Check if a directory for the domain exists.
+        # From the parent, find any directory objects.
+        # If they already have a record UID, don't prompt about this one.
+        # Once a directory for the domain exists, the user should not be prompted about this domain anymore.
+        if content.record_type == "pamDirectory":
+            for v in parent_vertex.has_vertices():
+                other_content = DiscoveryObject.get_discovery_object(v)
+                if other_content.record_uid is not None and other_content.name == content.name:
+                    return PromptResult(action=PromptActionEnum.SKIP)
+        return None
+
+    def _prompt_display_fields(self, content: DiscoveryObject, pad: str) -> List[str]:
+
+        editable = []
+        for field in content.fields:
+            has_editable = False
+            if field.label in PAMGatewayActionDiscoverResultProcessCommand.EDITABLE:
+                editable.append(field.label)
+                has_editable = True
+            value = field.value
+
+            # If there is a value, and it's not just [], also make sure the first element is not None.
+            if len(value) > 0 and value[0] is not None:
+                # PAM records will have only 1 item in the value array.
+                value = value[0]
+                if field.label in self.FIELD_MAPPING:
+                    type_hint = self.FIELD_MAPPING[field.label].get("type")
+                    formatted_value = []
+                    if type_hint == "dict":
+                        field_input_format = self.FIELD_MAPPING[field.label].get("field_format")
+                        for format_field in field_input_format:
+                            formatted_value.append(f"{format_field.get('label')}: "
+                                                   f"{value.get(format_field.get('key'))}")
+                    elif type_hint == "csv":
+                        formatted_value.append(", ".join(value.split("\n")))
+                    elif type_hint == "multiline":
+                        formatted_value.append(value)
+                    elif type_hint == "choice":
+                        formatted_value.append(value)
+                    value = ", ".join(formatted_value)
+            else:
+                if has_editable is True:
+                    value = f"{bcolors.FAIL}MISSING{bcolors.ENDC}"
+                else:
+                    value = f"{bcolors.OKBLUE}None{bcolors.ENDC}"
+
+            color = bcolors.HEADER
+            if has_editable is True:
+                color = bcolors.OKGREEN
+
+            rows = str(value).split("\n")
+            if len(rows) > 1:
+                value = rows[0] + _b(f"... {len(rows)} rows.")
+
+            print(f"{pad} "
+                  f"{color}Label:{bcolors.ENDC} {field.label}, "
+                  f"{_h('Type:')} {field.type}, "
+                  f"{_h('Value:')} {value}")
+
+        if len(content.notes) > 0:
+            print("")
+            for note in content.notes:
+                print(f"{pad}* {note}")
+
+        return editable
+
+    @staticmethod
+    def _prompt_display_relationships(vertex: DAGVertex, content: DiscoveryObject, pad: str):
+
+        if vertex is None:
+            return
+
+        if content.record_type == "pamUser":
+            belongs_to = []
+            for v in vertex.belongs_to_vertices():
+                resource_content = DiscoveryObject.get_discovery_object(v)
+                belongs_to.append(resource_content.name)
+            count = len(belongs_to)
+            print("")
+            print(f"{pad}This user is found on {count} resource{'s' if count != 1 else ''}")
+
+    def _prompt(self,
+                content: DiscoveryObject,
+                acl: UserAcl,
+                vertex: Optional[DAGVertex] = None,
+                parent_vertex: Optional[DAGVertex] = None,
+                resource_has_admin: bool = True,
+                item_count: int = 0,
+                items_left: int = 0,
+                indent: int = 0,
+                context: Optional[Any] = None) -> PromptResult:
+
+        if context is None:
+            raise Exception("Context not set for processing the discovery results")
+
+        parent_content = DiscoveryObject.get_discovery_object(parent_vertex)
+
+        print("")
+
+        params = context.get("params")
+        gateway_context = context.get("gateway_context")
+        dry_run = context.get("dry_run", False)
+        add_all = context.get("add_all")
+
+        # If add_all is True, there are sometimes objects we don't want to add.
+        # If the preprocess check returns a result, return it and skip the prompt.
+        if add_all is True and vertex is not None:
+            result = self._add_all_preprocess(vertex, content, parent_vertex, acl)
+            if result is not None:
+                return result
+
+        # If the record type is a pamUser, then include the parent description.
+        if content.record_type == "pamUser" and parent_vertex is not None:
+            parent_pad = ""
+            if indent - 1 > 0:
+                parent_pad = "".ljust(2 * indent, ' ')
+
+            print(f"{parent_pad}{_h(parent_content.description)}")
+
+        pad = ""
+        if indent > 0:
+            pad = "".ljust(2 * indent, ' ')
+
+        print(f"{pad}{_h(content.description)}")
+
+        show_current_object = True
+        while show_current_object is True:
+            print(f"{pad}{bcolors.HEADER}Record Title:{bcolors.ENDC} {content.title}")
+
+            logging.debug(f"Fields: {content.fields}")
+
+            # Display the fields and return a list of the fields that are editable.
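+            # As a hypothetical example, a discovered pamMachine might report
+            # editable labels such as ["hostName", "port"]; only those labels
+            # (and "title") are accepted by the (E)dit action handled below.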
+            editable = self._prompt_display_fields(content=content, pad=pad)
+            if vertex is not None:
+                self._prompt_display_relationships(vertex=vertex, content=content, pad=pad)
+
+            while True:
+
+                shared_folder_uid = content.shared_folder_uid
+                if shared_folder_uid is None:
+                    shared_folder_uid = gateway_context.default_shared_folder_uid
+
+                count_prompt = ""
+                if item_count > 0:
+                    count_prompt = f"{bcolors.HEADER}[{item_count - items_left + 1}/{item_count}]{bcolors.ENDC}"
+                edit_add_prompt = f"{count_prompt} "
+                if len(editable) > 0:
+                    edit_add_prompt += f"({_b('E')})dit, "
+
+                shared_folders = gateway_context.get_shared_folders(params)
+                if dry_run is False:
+                    if len(shared_folders) > 1:
+                        folder_name = next((x['name']
+                                            for x in shared_folders
+                                            if x['uid'] == shared_folder_uid),
+                                           None)
+                        edit_add_prompt += f"({_b('A')})dd to {folder_name}, " \
+                                           f"Add to ({_b('F')})older, "
+                    else:
+                        edit_add_prompt += f"({_b('A')})dd, "
+                prompt = f"{edit_add_prompt}({_b('S')})kip, ({_b('I')})gnore, ({_b('Q')})uit"
+
+                command = "a"
+                if add_all is False:
+                    command = input(f"{pad}{prompt}> ").lower()
+                if (command == "a" or command == "f") and dry_run is False:
+
+                    print(f"{pad}{bcolors.OKGREEN}Adding record to save queue.{bcolors.ENDC}")
+                    print("")
+
+                    if command == "f":
+                        shared_folder_uid = self._get_shared_folder(params, pad, gateway_context)
+
+                    content.shared_folder_uid = shared_folder_uid
+
+                    # This happens when the record is a pamUser and the parent resource record does not have an
+                    # administrator.
+                    # It's like the reverse of creating an admin after adding the resource:
+                    # it would make this user the admin for the parent resource.
+                    # This condition would be really rare, since to get the users, the resource would have to
+                    # have an admin user.
+                    if content.record_type == "pamUser" and resource_has_admin is False:
+
+                        print(_b(f"{parent_content.description} does not have an administrator."))
+                        if (hasattr(parent_content.item, "admin_reason") and
+                                parent_content.item.admin_reason is not None):
+                            print("")
+                            print(parent_content.item.admin_reason)
+                            print("")
+
+                        while True:
+
+                            yn = input("Do you want to make this user the administrator? [Y/N]> ").lower()
+                            if yn == "":
+                                continue
+                            if yn[0] == "n":
+                                break
+                            if yn[0] == "y":
+                                acl.is_admin = True
+                                break
+
+                    return PromptResult(
+                        action=PromptActionEnum.ADD,
+                        acl=acl,
+                        content=content
+                    )
+
+                elif command == "e" and dry_run is False:
+                    self._edit_record(content, pad, editable)
+                    break
+
+                elif command == "i":
+
+                    print(f"{pad}{bcolors.OKBLUE}Creating an ignore rule for the record.{bcolors.ENDC}")
+                    return PromptResult(
+                        action=PromptActionEnum.IGNORE,
+                        acl=acl,
+                        content=content
+                    )
+
+                elif command == "s":
+                    print(f"{pad}{bcolors.OKBLUE}Skipping record.{bcolors.ENDC}")
+
+                    return PromptResult(
+                        action=PromptActionEnum.SKIP,
+                        acl=acl,
+                        content=content
+                    )
+                elif command == "q":
+                    raise QuitException()
+            print()
+
+    def _find_user_record(self, params: KeeperParams, context: Optional[Any] = None) -> Optional[TypedRecord]:
+
+        gateway_context = context.get("gateway_context")  # type: GatewayContext
+        record_link = context.get("record_link")  # type: RecordLink
+
+        # Get the latest records.
+        params.sync_data = True
+
+        # Make a list of all records in the shared folders.
+        # We will use this to check if a selected user is in the shared folders.
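+        # get_shared_folders() returns entries shaped like
+        # {"uid": ..., "name": ..., "folder": ...}; the loop below flattens the
+        # cached folder contents into a plain list of record UIDs so the search
+        # results can be filtered with a simple membership check.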
+        shared_record_uids = []
+        for shared_folder in gateway_context.get_shared_folders(params):
+            folder = shared_folder.get("folder")
+            if "records" in folder:
+                for record in folder["records"]:
+                    shared_record_uids.append(record.get("record_uid"))
+
+        logging.debug(f"shared folder record uids {shared_record_uids}")
+
+        while True:
+            user_search = input("Enter a user to search for [ENTER/RETURN to quit]> ")
+            if user_search == "":
+                return None
+
+            # Search for records with the search string.
+            # Currently, this only works with TypedRecord, version 3.
+            user_records = list(vault_extensions.find_records(
+                params,
+                search_str=user_search,
+                record_version=3
+            ))
+            if len(user_records) == 0:
+                print(f"{bcolors.FAIL}Could not find any record.{bcolors.ENDC}")
+
+            # Find usable admin records.
+            admin_search_results = []  # type: List[AdminSearchResult]
+            for record in user_records:
+
+                user_record = vault.KeeperRecord.load(params, record.record_uid)
+                if user_record.record_type == "pamUser":
+
+                    # Does the record exist in the gateway shared folder?
+                    # We want to filter out other gateways' pamUser records, or it will get overwhelming.
+                    if user_record.record_uid not in shared_record_uids:
+                        logging.debug(f"pamUser {record.title}, {user_record.record_uid} not in shared "
+                                      "folder, skip")
+                        continue
+
+                    # # If a pamUser, make sure the user is part of our configuration
+                    # record_rotation = params.record_rotation_cache.get(record.record_uid)
+                    # if record_rotation is not None:
+                    #     configuration_uid = record_rotation.get("configuration_uid")
+                    #     if configuration_uid is None or configuration_uid == "":
+                    #         logging.debug(f"pamUser {record.title}, {record.record_uid} does not have a "
+                    #                       "controller, skip")
+                    #         continue
+                    #     if configuration_uid != gateway_context.configuration_uid:
+                    #         logging.debug(f"pamUser {record.title}, {record.record_uid} controller is not this "
+                    #                       "controller, skip")
+                    #         continue
+                    # else:
+                    #     logging.debug(f"pamUser {record.title}, {record.record_uid} does not have rotation "
+                    #                   "settings.")
+
+                    # If the record does not exist in the record linking, it's orphaned; accept it.
+                    # If it does exist, then check if it belongs to a directory.
+                    # It is very unlikely that a user belonging to a database or another machine can be used.
+
+                    record_vertex = record_link.get_record_link(user_record.record_uid)
+                    is_directory_user = False
+                    if record_vertex is not None:
+                        parent_record_uid = record_link.get_parent_record_uid(user_record.record_uid)
+                        parent_record = vault.TypedRecord.load(params, parent_record_uid)  # type: Optional[TypedRecord]
+                        if parent_record is not None:
+                            is_directory_user = self._is_directory_user(parent_record.record_type)
+                            if is_directory_user is False:
+                                logging.debug(f"pamUser parent for {user_record.title}, "
+                                              f"{user_record.record_uid} is not a directory, skip")
+                                continue
+                        else:
+                            logging.debug(f"pamUser {user_record.title}, {user_record.record_uid} parent record "
+                                          "could not be loaded.")
+                    else:
+                        logging.debug(f"pamUser {user_record.title}, {user_record.record_uid} does not have a "
+                                      "record linking vertex.")
+
+                    admin_search_results.append(
+                        AdminSearchResult(
+                            record=user_record,
+                            is_directory_user=is_directory_user,
+                            is_pam_user=True
+                        )
+                    )
+
+                # Else this is a non-PAM record.
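+                # For example (illustrative), a general login record holding
+                # login="root" plus a password or a keyPair field qualifies as a
+                # seed for a new admin record; a record with only a login does not.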
+                # Make sure it has a login, and a password or private key.
+                else:
+                    logging.debug(f"{record.record_uid} is not a pamUser")
+                    login_field = next((x for x in record.fields if x.type == "login"), None)
+                    password_field = next((x for x in record.fields if x.type == "password"), None)
+                    private_key_field = next((x for x in record.fields if x.type == "keyPair"), None)
+
+                    if login_field is not None and (password_field is not None or private_key_field is not None):
+                        admin_search_results.append(
+                            AdminSearchResult(
+                                record=record,
+                                is_directory_user=False,
+                                is_pam_user=False
+                            )
+                        )
+                    else:
+                        logging.debug(f"{record.title} is missing full credentials, skip")
+
+            user_index = 1
+
+            admin_search_results = sorted(admin_search_results,
+                                          key=lambda x: x.is_pam_user,
+                                          reverse=True)
+
+            has_local_user = False
+            for admin_search_result in admin_search_results:
+                is_local_user = False
+                if admin_search_result.record.record_type != "pamUser":
+                    has_local_user = True
+                    is_local_user = True
+
+                print(f"{bcolors.HEADER}[{user_index}] {bcolors.ENDC}"
+                      f"{_b('* ') if is_local_user is True else ''}"
+                      f"{admin_search_result.record.title} "
+                      f'{"(Directory User)" if admin_search_result.is_directory_user is True else ""}')
+                user_index += 1
+
+            if has_local_user is True:
+                print(f"{bcolors.BOLD}* Not a PAM User record. "
+                      f"A PAM User record would be generated from this record.{bcolors.ENDC}")
+
+            select = input("Enter the line number of the user record to use, ENTER/RETURN to redo the search, "
+                           f"or {_b('Q')} to quit the search. > ").lower()
+            if select == "":
+                continue
+            elif select[0] == "q":
+                return None
+            else:
+                try:
+                    return admin_search_results[int(select) - 1].record  # type: TypedRecord
+                except (ValueError, IndexError):
+                    # A non-numeric entry raises ValueError; an out-of-range entry raises IndexError.
+                    print(f"{bcolors.FAIL}The entered row index does not exist.{bcolors.ENDC}")
+                    continue
+
+    @staticmethod
+    def _handle_admin_record_from_record(record: TypedRecord, content: DiscoveryObject, context: Optional[Any] = None) \
+            -> Optional[PromptResult]:
+
+        params = context.get("params")  # type: KeeperParams
+        gateway_context = context.get("gateway_context")  # type: GatewayContext
+
+        # Is this a pamUser record?
+        # Return the record UID and set its ACL to be the admin.
+        if record.record_type == "pamUser":
+            return PromptResult(
+                action=PromptActionEnum.ADD,
+                acl=UserAcl(is_admin=True),
+                record_uid=record.record_uid,
+            )
+
+        # If we are here, this was not a pamUser.
+        # We need to duplicate the record, but confirm first.
+
+        # Get the fields from the old record and copy them into the new record's fields.
+        login_field = next((x for x in record.fields if x.type == "login"), None)
+        password_field = next((x for x in record.fields if x.type == "password"), None)
+        private_key_field = next((x for x in record.fields if x.type == "keyPair"), None)
+
+        content.set_field_value("login", login_field.value)
+        if password_field is not None:
+            content.set_field_value("password", password_field.value)
+        if private_key_field is not None:
+            value = private_key_field.value
+            if value is not None and len(value) > 0:
+                value = value[0]
+                private_key = value.get("privateKey")
+                if private_key is not None:
+                    content.set_field_value("private_key", private_key)
+
+        # Check how many shared folders we have.
+        # If there are none, just confirm adding the user to the default folder.
+        # If there are multiple shared folders, allow the user to select which one.
+        shared_folders = gateway_context.get_shared_folders(params)
+        if len(shared_folders) == 0:
+            while True:
+                yn = input(f"Create a PAM User record from {record.title}? [Y/N]> ").lower()
+                if yn == "":
+                    continue
+                elif yn[0] == "n":
+                    return None
+                elif yn[0] == "y":
+                    content.shared_folder_uid = gateway_context.default_shared_folder_uid
+                    # Without this break, the prompt would loop forever.
+                    break
+        else:
+            folder_name = next((x['name']
+                                for x in shared_folders
+                                if x['uid'] == gateway_context.default_shared_folder_uid),
+                               None)
+            while True:
+                shared_folders = gateway_context.get_shared_folders(params)
+                if len(shared_folders) > 1:
+                    afq = input(f"({_b('A')})dd user to {folder_name}, "
+                                f"Add user to ({_b('F')})older, "
+                                f"({_b('Q')})uit > ").lower()
+                else:
+                    afq = input(f"({_b('A')})dd user, "
+                                f"({_b('Q')})uit > ").lower()
+
+                if afq == "":
+                    continue
+                if afq[0] == "a":
+                    content.shared_folder_uid = gateway_context.default_shared_folder_uid
+                    break
+                elif afq[0] == "f":
+                    shared_folder_uid = PAMGatewayActionDiscoverResultProcessCommand._get_shared_folder(
+                        params, "", gateway_context)
+                    if shared_folder_uid is not None:
+                        content.shared_folder_uid = shared_folder_uid
+                        break
+                elif afq[0] == "q":
+                    # The prompt offers (Q)uit, so handle it.
+                    return None
+
+        return PromptResult(
+            action=PromptActionEnum.ADD,
+            acl=UserAcl(is_admin=True),
+            content=content,
+            note=f"This record replaces record {record.title} ({record.record_uid}). "
+                 "The password on that record will not be rotated."
+        )
+
+    def _prompt_admin(self, parent_vertex: DAGVertex, content: DiscoveryObject, acl: UserAcl,
+                      indent: int = 0, context: Optional[Any] = None) -> PromptResult:
+
+        if content is None:
+            raise Exception("The admin content was not passed in to prompt the user.")
+
+        params = context.get("params")
+
+        parent_content = DiscoveryObject.get_discovery_object(parent_vertex)
+
+        print("")
+        while True:
+
+            print(f"{bcolors.BOLD}{parent_content.description} does not have an administrator user.{bcolors.ENDC}")
+            if hasattr(parent_content.item, "admin_reason") is True and parent_content.item.admin_reason is not None:
+                print("")
+                print(parent_content.item.admin_reason)
+                print("")
+
+            action = input("Would you like to "
+                           f"({_b('A')})dd a new administrator user, "
+                           f"({_b('F')})ind an existing admin, or "
+                           f"({_b('S')})kip adding? > ").lower()
+
+            if action == "":
+                continue
+
+            if action[0] == 'a':
+                prompt_result = self._prompt(
+                    vertex=None,
+                    parent_vertex=parent_vertex,
+                    content=content,
+                    acl=acl,
+                    context=context,
+                    indent=indent + 2
+                )
+                login = content.get_field_value("login")
+                if login is None or login == "":
+                    print("")
+                    print(f"{bcolors.FAIL}A value is needed for the login field.{bcolors.ENDC}")
+                    continue
+
+                print(f"{bcolors.OKGREEN}Adding admin record to save queue.{bcolors.ENDC}")
+                return prompt_result
+            elif action[0] == 'f':
+                print("")
+                record = self._find_user_record(params, context=context)
+                if record is not None:
+                    admin_prompt_result = self._handle_admin_record_from_record(
+                        record=record,
+                        content=content,
+                        context=context
+                    )
+                    if admin_prompt_result is not None:
+                        if admin_prompt_result.action == PromptActionEnum.ADD:
+                            print(f"{bcolors.OKGREEN}Adding admin record to save queue.{bcolors.ENDC}")
+                            return admin_prompt_result
+            elif action[0] == 's':
+                return PromptResult(
+                    action=PromptActionEnum.SKIP
+                )
+            print("")
+
+    @staticmethod
+    def _display_auto_add_results(bulk_add_records: List[BulkRecordAdd]):
+
+        """
+        Display the number of records created from the rule engine ADD results and the smart add function.
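+
+        For example, if the rule engine queued three records, this prints:
+        "From the rules, automatically queued 3 records to be added."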
+        """
+
+        add_count = len(bulk_add_records)
+        if add_count > 0:
+            print("")
+            print(f"{bcolors.OKGREEN}From the rules, automatically queued {add_count} "
+                  f"record{'' if add_count == 1 else 's'} to be added.{bcolors.ENDC}")
+
+    @staticmethod
+    def _prompt_confirm_add(bulk_add_records: List[BulkRecordAdd]):
+
+        """
+        If we quit, we want to ask the user if they want to add records for the discovery objects that
+        they selected for addition.
+        """
+
+        print("")
+        count = len(bulk_add_records)
+        if count == 1:
+            msg = (f"{bcolors.BOLD}There is 1 record queued to be added to your vault. "
+                   f"Do you wish to add it? [Y/N]> {bcolors.ENDC}")
+        else:
+            msg = (f"{bcolors.BOLD}There are {count} records queued to be added to your vault. "
+                   f"Do you wish to add them? [Y/N]> {bcolors.ENDC}")
+        while True:
+            yn = input(msg).lower()
+            if yn == "":
+                continue
+            if yn[0] == "y":
+                return True
+            elif yn[0] == "n":
+                return False
+            print(f"{bcolors.FAIL}Did not get 'Y' or 'N'{bcolors.ENDC}")
+
+    @staticmethod
+    def _prepare_record(content: DiscoveryObject, context: Optional[Any] = None) -> (Any, str):
+
+        """
+        Prepare the Vault record side.
+
+        The record is not created here; it will be created in bulk at the end of the processing run.
+        We need to build the record now so that we have a record UID to reference.
+
+        :param content: The discovery object instance.
+        :param context: Optionally contains information set from the run() method.
+        :returns: The unsaved RecordAdd protobuf instance and the new record UID.
+        """
+
+        params = context.get("params")
+
+        # DEFINE V3 RECORD
+
+        # Create an instance of a vault record to structure the data.
+        record = vault.TypedRecord()
+        record.type_name = content.record_type
+        record.record_uid = utils.generate_uid()
+        record.record_key = utils.generate_aes_key()
+        record.title = content.title
+        for field in content.fields:
+            field_args = {
+                "field_type": field.type,
+                "field_value": field.value
+            }
+            if field.type != field.label:
+                field_args["field_label"] = field.label
+            record_field = vault.TypedField.new_field(**field_args)
+            record_field.required = field.required
+            record.fields.append(record_field)
+
+        folder = params.folder_cache.get(content.shared_folder_uid)
+        folder_key = None  # type: Optional[bytes]
+        if isinstance(folder, subfolder.SharedFolderFolderNode):
+            shared_folder_uid = folder.shared_folder_uid
+        elif isinstance(folder, subfolder.SharedFolderNode):
+            shared_folder_uid = folder.uid
+        else:
+            shared_folder_uid = None
+        if shared_folder_uid and shared_folder_uid in params.shared_folder_cache:
+            shared_folder = params.shared_folder_cache.get(shared_folder_uid)
+            folder_key = shared_folder.get('shared_folder_key_unencrypted')
+
+        # DEFINE PROTOBUF FOR RECORD
+
+        record_add_protobuf = record_pb2.RecordAdd()
+        record_add_protobuf.record_uid = utils.base64_url_decode(record.record_uid)
+        record_add_protobuf.record_key = crypto.encrypt_aes_v2(record.record_key, params.data_key)
+        record_add_protobuf.client_modified_time = utils.current_milli_time()
+        record_add_protobuf.folder_type = record_pb2.user_folder
+        if folder:
+            record_add_protobuf.folder_uid = utils.base64_url_decode(folder.uid)
+            if folder.type == 'shared_folder':
+                record_add_protobuf.folder_type = record_pb2.shared_folder
+            elif folder.type == 'shared_folder_folder':
+                record_add_protobuf.folder_type = record_pb2.shared_folder_folder
+            if folder_key:
+                record_add_protobuf.folder_key = crypto.encrypt_aes_v2(record.record_key, folder_key)
+
+        data = vault_extensions.extract_typed_record_data(record)
+        json_data = api.get_record_data_json_bytes(data)
+        record_add_protobuf.data = crypto.encrypt_aes_v2(json_data, record.record_key)
+
+        if params.enterprise_ec_key:
+            audit_data = vault_extensions.extract_audit_data(record)
+            if audit_data:
+                record_add_protobuf.audit.version = 0
+                record_add_protobuf.audit.data = crypto.encrypt_ec(
+                    json.dumps(audit_data).encode('utf-8'), params.enterprise_ec_key)
+
+        return record_add_protobuf, record.record_uid
+
+    @classmethod
+    def _create_records(cls, bulk_add_records: List[BulkRecordAdd], context: Optional[Any] = None) -> (
+            BulkProcessResults):
+
+        if len(bulk_add_records) == 1:
+            print("Adding the record to the Vault ...")
+        else:
+            print(f"Adding {len(bulk_add_records)} records to the Vault ...")
+
+        params = context.get("params")
+        gateway_context = context.get("gateway_context")
+
+        build_process_results = BulkProcessResults()
+
+        # STEP 1 - Batch add the new records.
+
+        # Generate a list of RecordAdd instances.
+        # In BulkRecordAdd they will be the record instance.
+        record_add_list = [r.record for r in bulk_add_records]  # type: List[record_pb2.RecordAdd]
+
+        records_per_request = 999
+
+        add_results = []  # type: List[record_pb2.RecordModifyResult]
+        logging.debug("adding records in batches")
+        while record_add_list:
+            logging.debug("* adding batch")
+            rq = record_pb2.RecordsAddRequest()
+            rq.client_time = utils.current_milli_time()
+            rq.records.extend(record_add_list[:records_per_request])
+            record_add_list = record_add_list[records_per_request:]
+            rs = api.communicate_rest(params, rq, 'vault/records_add', rs_type=record_pb2.RecordsModifyResponse)
+            add_results.extend(rs.records)
+
+        logging.debug(f"add_result: {add_results}")
+
+        if len(add_results) != len(bulk_add_records):
+            logging.debug(f"attempted to batch add {len(bulk_add_records)} record(s), "
+                          f"only have {len(add_results)} results.")
+
+        # STEP 2 - Add the rotation settings.
+        # Use the list we passed in, find the results, and add the settings if the additions were successful.
+
+        # For each record passed in to be created:
+        for bulk_record in bulk_add_records:
+            # Grab the typed Keeper record instance, and the title from that record.
+            pb_add_record = bulk_record.record
+            title = bulk_record.title
+
+            rotation_disabled = False
+
+            # Find the result for this record.
+            result = None
+            for x in add_results:
+                logging.debug(f"{pb_add_record.record_uid} vs {x.record_uid}")
+                if pb_add_record.record_uid == x.record_uid:
+                    result = x
+                    break
+
+            # If we didn't get a result, then don't add the rotation settings.
+            if result is None:
+                build_process_results.failure.append(
+                    BulkRecordFail(
+                        title=title,
+                        error="No status on addition to Vault. Cannot determine if added or not."
+                    )
+                )
+                logging.debug(f"Did not get a result when adding record {title}")
+                continue
+
+            # Check if the addition failed. If it did fail, don't add the rotation settings.
+            success = (result.status == record_pb2.RecordModifyResult.DESCRIPTOR.values_by_name['RS_SUCCESS'].number)
+            status = record_pb2.RecordModifyResult.DESCRIPTOR.values_by_number[result.status].name
+
+            if success is False:
+                build_process_results.failure.append(
+                    BulkRecordFail(
+                        title=title,
+                        error=status
+                    )
+                )
+                logging.debug(f"Had a problem adding the record for {title}: {status}")
+                continue
+
+            rq = router_pb2.RouterRecordRotationRequest()
+            rq.recordUid = url_safe_str_to_bytes(bulk_record.record_uid)
+            rq.revision = 0
+
+            # Set the gateway/configuration that this record should be connected to.
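+            # url_safe_str_to_bytes() converts the base64-url-encoded UID strings
+            # Commander uses into the raw bytes the router protobuf expects; the
+            # same conversion is applied to the record and resource UIDs below.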
+            rq.configurationUid = url_safe_str_to_bytes(gateway_context.configuration_uid)
+
+            # Only set the resource if the record type is a PAM User.
+            # Machines, databases, and directories have a login/password in the record that indicates who
+            # the admin is.
+            if bulk_record.record_type == "pamUser":
+                rq.resourceUid = url_safe_str_to_bytes(bulk_record.parent_record_uid)
+
+            # Right now, the schedule and password complexity are not set. This would be part of a rule engine.
+            rq.schedule = ''
+            rq.pwdComplexity = b''
+            rq.disabled = rotation_disabled
+
+            router_set_record_rotation_information(params, rq)
+
+            build_process_results.success.append(
+                BulkRecordSuccess(
+                    title=title,
+                    record_uid=bulk_record.record_uid
+                )
+            )
+
+        params.sync_data = True
+
+        return build_process_results
+
+    @classmethod
+    def _convert_records(cls, bulk_convert_records: List[BulkRecordConvert], context: Optional[Any] = None):
+
+        params = context.get("params")
+        gateway_context = context.get("gateway_context")
+
+        for bulk_convert_record in bulk_convert_records:
+
+            record = vault.KeeperRecord.load(params, bulk_convert_record.record_uid)
+
+            rotation_disabled = False
+
+            rq = router_pb2.RouterRecordRotationRequest()
+            rq.recordUid = url_safe_str_to_bytes(bulk_convert_record.record_uid)
+            record_rotation_revision = params.record_rotation_cache.get(bulk_convert_record.record_uid)
+            rq.revision = record_rotation_revision.get('revision') if record_rotation_revision else 0
+
+            # Set the gateway/configuration that this record should be connected to.
+            rq.configurationUid = url_safe_str_to_bytes(gateway_context.configuration_uid)
+
+            # Only set the resource if the record type is a PAM User.
+            # Machines, databases, and directories have a login/password in the record that indicates who
+            # the admin is.
+            if record.record_type == "pamUser":
+                rq.resourceUid = url_safe_str_to_bytes(bulk_convert_record.parent_record_uid)
+            # For non-user records, resourceUid is left unset; a freshly created
+            # request already defaults to empty (assigning None to a protobuf
+            # field would raise a TypeError).
+
+            # Right now, the schedule and password complexity are not set. This would be part of a rule engine.
+            rq.schedule = ''
+            rq.pwdComplexity = b''
+            rq.disabled = rotation_disabled
+
+            router_set_record_rotation_information(params, rq)
+
+        params.sync_data = True
+
+    @staticmethod
+    def _get_directory_info(domain: str,
+                            skip_users: bool = False,
+                            context: Optional[Any] = None) -> Optional[DirectoryInfo]:
+        """
+        Get information about the directory records in the vault that match the given domain.
+        """
+
+        params = context.get("params")
+        gateway_context = context.get("gateway_context")
+
+        directory_info = DirectoryInfo()
+
+        # Find all the directory records, for this gateway, that have a domain matching the one we are
+        # looking for.
+        for directory_record in vault_extensions.find_records(params, record_type="pamDirectory"):
+            directory_record = vault.TypedRecord.load(params,
+                                                      directory_record.record_uid)  # type: Optional[TypedRecord]
+
+            info = params.record_rotation_cache.get(directory_record.record_uid)
+            if info is None:
+                continue
+
+            # Make sure this directory record is part of this gateway.
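+            # The rotation cache entry is a plain dict; based on the keys read
+            # here and below, an illustrative shape is:
+            #   {"configuration_uid": "<config record UID>",
+            #    "resource_uid": "<parent resource record UID>", ...}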
+            if info.get("configuration_uid") != gateway_context.configuration_uid:
+                continue
+
+            domain_field = directory_record.get_typed_field("text", label="domainName")
+            if domain_field is None or len(domain_field.value) == 0 or domain_field.value[0] == "":
+                continue
+
+            if domain_field.value[0].lower() != domain.lower():
+                continue
+
+            directory_info.directory_record_uids.append(directory_record.record_uid)
+
+        if directory_info.has_directories is True and skip_users is False:
+
+            for user_record in vault_extensions.find_records(params, record_type="pamUser"):
+                info = params.record_rotation_cache.get(user_record.record_uid)
+                if info is None:
+                    continue
+
+                if info.get("resource_uid") is None or info.get("resource_uid") == "":
+                    continue
+
+                # If the user belongs to one of the found directories, add it to the directory user list.
+                if info.get("resource_uid") in directory_info.directory_record_uids:
+                    directory_info.directory_user_record_uids.append(user_record.record_uid)
+
+        return directory_info
+
+    @staticmethod
+    def remove_job(params: KeeperParams, configuration_record: KeeperRecord, job_id: str):
+
+        try:
+            jobs = Jobs(record=configuration_record, params=params)
+            jobs.cancel(job_id)
+            print(f"{bcolors.OKGREEN}No items left to process. Removing completed discovery job.{bcolors.ENDC}")
+        except Exception as err:
+            logging.error(err)
+            print(f"{bcolors.FAIL}No items left to process. Failed to remove discovery job.{bcolors.ENDC}")
+
+    def execute(self, params: KeeperParams, **kwargs):
+
+        if not hasattr(params, 'pam_controllers'):
+            router_get_connected_gateways(params)
+
+        job_id = kwargs.get("job_id")
+        add_all = kwargs.get("add_all", False)
+        smart_add = kwargs.get("smart_add", False)
+
+        # Right now, keep dry_run False. We might add it back in.
+        dry_run = kwargs.get("dry_run", False)
+        debug_level = kwargs.get("debug_level", 0)
+
+        configuration_records = list(vault_extensions.find_records(params, "pam.*Configuration"))
+        for configuration_record in configuration_records:
+
+            gateway_context = GatewayContext.from_configuration_uid(params, configuration_record.record_uid)
+            if gateway_context is None:
+                continue
+
+            record_cache = self._build_record_cache(
+                params=params,
+                gateway_context=gateway_context
+            )
+
+            # Get the current job.
+            # There can only be one active job.
+            # This will give us the sync point for the delta.
+            jobs = Jobs(record=configuration_record, params=params, logger=logging, debug_level=debug_level)
+            job_item = jobs.current_job
+            if job_item is None:
+                continue
+
+            # If this is not the job we are looking for, continue to the next gateway.
+            if job_item.job_id != job_id:
+                continue
+
+            if job_item.end_ts is None:
+                print(f'{bcolors.FAIL}Discovery job is currently running. Cannot process.{bcolors.ENDC}')
+                return
+            if job_item.success is False:
+                print(f'{bcolors.FAIL}Discovery job failed. Cannot process.{bcolors.ENDC}')
+                return
+
+            process = Process(
+                record=configuration_record,
+                job_id=job_item.job_id,
+                params=params,
+                logger=logging,
+                debug_level=debug_level,
+            )
+
+            if dry_run is True:
+                if add_all is True:
+                    logging.debug("dry run has been set, disable auto add.")
+                    add_all = False
+
+                print(f"{bcolors.HEADER}The DRY RUN flag has been set. The rule engine will not add any records. "
+                      f"You will not be prompted to edit or add records.{bcolors.ENDC}")
+                print("")
+
+            if add_all is True:
+                print(f"{bcolors.HEADER}The ADD ALL flag has been set. "
+                      f"All found items will be added.{bcolors.ENDC}")
+                print("")
+
+            try:
+                results = process.run(
+
+                    # This method can get a record using the record UID.
+                    record_lookup_func=self._record_lookup,
+
+                    # Prompt the user about adding records.
+                    prompt_func=self._prompt,
+
+                    # Flag to auto add resources with credentials, and all their users.
+                    smart_add=smart_add,
+
+                    # Prompt the user for an admin for a resource.
+                    prompt_admin_func=self._prompt_admin,
+
+                    # On quit, confirm whether the user wants to add the queued records.
+                    prompt_confirm_add_func=self._prompt_confirm_add,
+
+                    # Prepare records and place them in the queue; does not add records to the vault.
+                    record_prepare_func=self._prepare_record,
+
+                    # Add records to the vault, protobuf, and the record-linking graph.
+                    record_create_func=self._create_records,
+
+                    # This function will take existing pamUser records and make them belong to this
+                    # gateway.
+                    record_convert_func=self._convert_records,
+
+                    # A function to get directory users.
+                    directory_info_func=self._get_directory_info,
+
+                    # Pass the method that will display auto-added records.
+                    auto_add_result_func=self._display_auto_add_results,
+
+                    # Provides a cache of the record key to record UID.
+                    record_cache=record_cache,
+
+                    # Commander-specific context.
+                    # The record link will be added by Process.run() as "record_link".
+                    context={
+                        "params": params,
+                        "gateway_context": gateway_context,
+                        "dry_run": dry_run,
+                        "add_all": add_all
+                    }
+                )
+
+                logging.debug(f"Results: {results}")
+
+                print("")
+                if results is not None and results.num_results > 0:
+                    print(f"{bcolors.OKGREEN}Successfully added {results.success_count} "
+                          f"record{'s' if results.success_count != 1 else ''}.{bcolors.ENDC}")
+                    if results.has_failures is True:
+                        print(f"{bcolors.FAIL}There were {results.failure_count} "
+                              f"failure{'s' if results.failure_count != 1 else ''}.{bcolors.ENDC}")
+                        for fail in results.failure:
+                            print(f"  * {fail.title}: {fail.error}")
+
+                    if process.no_items_left is True:
+                        self.remove_job(params=params, configuration_record=configuration_record, job_id=job_id)
+                else:
+                    print(f"{bcolors.FAIL}No records have been added.{bcolors.ENDC}")
+
+            except NoDiscoveryDataException:
+                print(f"{bcolors.OKGREEN}All items have been added for this discovery job.{bcolors.ENDC}")
+                self.remove_job(params=params, configuration_record=configuration_record, job_id=job_id)
+
+            except Exception as err:
+                print(f"{bcolors.FAIL}Could not process discovery: {err}{bcolors.ENDC}")
+                raise err
+
+            return
+
+        print(f"{bcolors.HEADER}Could not find the Discovery job.{bcolors.ENDC}")
+        print("")
diff --git a/keepercommander/commands/discover/rule_add.py b/keepercommander/commands/discover/rule_add.py
new file mode 100644
index 000000000..8801114c2
--- /dev/null
+++ b/keepercommander/commands/discover/rule_add.py
@@ -0,0 +1,114 @@
+from __future__ import annotations
+import argparse
+import logging
+from . import PAMGatewayActionDiscoverCommandBase, GatewayContext
+from ..pam.pam_dto import GatewayActionDiscoverRuleValidateInputs, GatewayActionDiscoverRuleValidate, GatewayAction
+from ..pam.router_helper import router_send_action_to_gateway, router_get_connected_gateways
+from ...display import bcolors
+from ...proto import pam_pb2
+from discovery_common.rule import Rules
+from discovery_common.types import ActionRuleItem
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from ...params import KeeperParams
+
+
+class PAMGatewayActionDiscoverRuleAddCommand(PAMGatewayActionDiscoverCommandBase):
+    parser = argparse.ArgumentParser(prog='dr-discover-rule-add')
+    parser.add_argument('--gateway', '-g', required=True, dest='gateway', action='store',
+                        help='Gateway name or UID.')
+    parser.add_argument('--action', '-a', required=True, choices=['add', 'ignore', 'prompt'],
+                        dest='rule_action', action='store', help='Action to take if the rule matches')
+    parser.add_argument('--priority', '-p', required=True, dest='priority', action='store', type=int,
+                        help='Rule execution priority')
+    parser.add_argument('--ignore-case', required=False, dest='ignore_case', action='store_true',
+                        help='Ignore value case. The rule value must be in lowercase.')
+    parser.add_argument('--shared-folder-uid', required=False, dest='shared_folder_uid',
+                        action='store', help='Folder in which to place the record.')
+    parser.add_argument('--statement', '-s', required=True, dest='statement', action='store',
+                        help='Rule statement')
+
+    def get_parser(self):
+        return PAMGatewayActionDiscoverRuleAddCommand.parser
+
+    @staticmethod
+    def validate_rule_statement(params: KeeperParams, gateway_context: GatewayContext, statement: str):
+
+        # Send the rule to the gateway to be validated. The rule is encrypted; it might contain
+        # sensitive information.
+        action_inputs = GatewayActionDiscoverRuleValidateInputs(
+            configuration_uid=gateway_context.configuration_uid,
+            statement=gateway_context.encrypt_str(statement)
+        )
+        conversation_id = GatewayAction.generate_conversation_id()
+        router_response = router_send_action_to_gateway(
+            params=params,
+            gateway_action=GatewayActionDiscoverRuleValidate(
+                inputs=action_inputs,
+                conversation_id=conversation_id),
+            message_type=pam_pb2.CMT_DISCOVERY,
+            is_streaming=False,
+            destination_gateway_uid_str=gateway_context.gateway_uid
+        )
+
+        data = PAMGatewayActionDiscoverCommandBase.get_response_data(router_response)
+
+        if data is None:
+            raise Exception("The router returned a failure.")
+        elif data.get("success") is False:
+            error = data.get("error")
+            raise Exception(f"The rule does not appear valid: {error}")
+
+        statement_struct = data.get("statementStruct")
+        logging.debug(f"Rule Structure = {statement_struct}")
+        if isinstance(statement_struct, list) is False:
+            raise Exception("The structured rule statement is not a list.")
+
+        return statement_struct
+
+    def execute(self, params, **kwargs):
+
+        if not hasattr(params, 'pam_controllers'):
+            router_get_connected_gateways(params)
+
+        try:
+            gateway = kwargs.get("gateway")
+            gateway_context = GatewayContext.from_gateway(params, gateway)
+            if gateway_context is None:
+                print(f'{bcolors.FAIL}Discovery job gateway [{gateway}] was not found.{bcolors.ENDC}')
+                return
+
+            # If we are setting the shared_folder_uid, make sure it exists.
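+            # Hypothetical CLI usage for this command (the statement grammar is
+            # defined by the gateway's rule engine, not by Commander):
+            #   pam action discover rule add -g MyGateway -a ignore -p 10 \
+            #       -s 'hostname matches "test-*"'
+            # Validation is delegated to the gateway via validate_rule_statement(),
+            # so a malformed statement fails before any rule is stored.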
+            shared_folder_uid = kwargs.get("shared_folder_uid")
+            if shared_folder_uid is not None:
+                shared_folder_uids = gateway_context.get_shared_folders(params)
+                exists = next((x for x in shared_folder_uids if x["uid"] == shared_folder_uid), None)
+                if exists is None:
+                    print(f"{bcolors.FAIL}The shared folder UID {shared_folder_uid} is not part of this "
+                          f"application/gateway. Valid shared folder UIDs are:{bcolors.ENDC}")
+                    for item in shared_folder_uids:
+                        print(f"* {item['uid']} - {item['name']}")
+                    return
+
+            statement = kwargs.get("statement")
+            statement_struct = self.validate_rule_statement(
+                params=params,
+                gateway_context=gateway_context,
+                statement=statement
+            )
+
+            # If the rule passes its validation, then add it to the control DAG.
+            rules = Rules(record=gateway_context.configuration, params=params)
+            new_rule = ActionRuleItem(
+                action=kwargs.get("rule_action"),
+                priority=kwargs.get("priority"),
+                case_sensitive=not kwargs.get("ignore_case", False),
+                shared_folder_uid=kwargs.get("shared_folder_uid"),
+                statement=statement_struct,
+                enabled=True
+            )
+            rules.add_rule(new_rule)
+
+            print(f"{bcolors.OKGREEN}Rule has been added{bcolors.ENDC}")
+        except Exception as err:
+            print(f"{bcolors.FAIL}Rule was not added: {err}{bcolors.ENDC}")
diff --git a/keepercommander/commands/discover/rule_list.py b/keepercommander/commands/discover/rule_list.py
new file mode 100644
index 000000000..28aa8fa54
--- /dev/null
+++ b/keepercommander/commands/discover/rule_list.py
@@ -0,0 +1,81 @@
+from __future__ import annotations
+import argparse
+from . import PAMGatewayActionDiscoverCommandBase, GatewayContext
+from ...display import bcolors
+from ..pam.router_helper import router_get_connected_gateways
+from discovery_common.rule import Rules
+from discovery_common.types import RuleTypeEnum
+from typing import List, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from discovery_common.types import RuleItem
+
+
+class PAMGatewayActionDiscoverRuleListCommand(PAMGatewayActionDiscoverCommandBase):
+    parser = argparse.ArgumentParser(prog='dr-discover-rule-list')
+    parser.add_argument('--gateway', '-g', required=True, dest='gateway', action='store',
+                        help='Gateway name or UID.')
+    parser.add_argument('--search', '-s', required=False, dest='search', action='store',
+                        help='Search for rules.')
+
+    def get_parser(self):
+        return PAMGatewayActionDiscoverRuleListCommand.parser
+
+    @staticmethod
+    def print_rule_table(rule_list: List[RuleItem]):
+
+        print("")
+        print(f"{bcolors.HEADER}{'Rule ID'.ljust(15, ' ')} "
+              f"{'Action'.ljust(6, ' ')} "
+              f"{'Priority'.ljust(8, ' ')} "
+              f"{'Case'.ljust(12, ' ')} "
+              f"{'Added'.ljust(19, ' ')} "
+              f"{'Shared Folder UID'.ljust(22, ' ')} "
+              "Rule"
+              f"{bcolors.ENDC}")
+
+        print(f"{''.ljust(15, '=')} "
+              f"{''.ljust(6, '=')} "
+              f"{''.ljust(8, '=')} "
+              f"{''.ljust(12, '=')} "
+              f"{''.ljust(19, '=')} "
+              f"{''.ljust(22, '=')} "
+              f"{''.ljust(10, '=')} ")
+
+        for rule in rule_list:
+            if rule.case_sensitive is True:
+                ignore_case_str = "Sensitive"
+            else:
+                ignore_case_str = "Insensitive"
+
+            shared_folder_uid = ""
+            if rule.shared_folder_uid is not None:
+                shared_folder_uid = rule.shared_folder_uid
+            # Pad the rule ID to the same width as the header so the columns line up.
+            print(f"{bcolors.OKGREEN}{rule.rule_id.ljust(15, ' ')}{bcolors.ENDC} "
+                  f"{rule.action.value.ljust(6, ' ')} "
+                  f"{str(rule.priority).rjust(8, ' ')} "
+                  f"{ignore_case_str.ljust(12, ' ')} "
+                  f"{rule.added_ts_str.ljust(19, ' ')} "
+                  f"{shared_folder_uid.ljust(22, ' ')} "
+                  f"{Rules.make_action_rule_statement_str(rule.statement)}")
+
+    def execute(self, params, **kwargs):
+
+        if not hasattr(params, 'pam_controllers'):
+            router_get_connected_gateways(params)
+
+        gateway = kwargs.get("gateway")
+        gateway_context = GatewayContext.from_gateway(params, gateway)
+        if gateway_context is None:
+            print(f'{bcolors.FAIL}Discovery job gateway [{gateway}] was not found.{bcolors.ENDC}')
+            return
+
+        rules = Rules(record=gateway_context.configuration, params=params)
+        rule_list = rules.rule_list(rule_type=RuleTypeEnum.ACTION,
+                                    search=kwargs.get("search"))  # type: List[RuleItem]
+        if len(rule_list) == 0:
+            print(f"{bcolors.FAIL}There are no rules. Use 'pam action discover rule add' "
+                  f"to create rules.{bcolors.ENDC}")
+            return
+
+        self.print_rule_table(rule_list=rule_list)
diff --git a/keepercommander/commands/discover/rule_remove.py b/keepercommander/commands/discover/rule_remove.py
new file mode 100644
index 000000000..e2e0b21d0
--- /dev/null
+++ b/keepercommander/commands/discover/rule_remove.py
@@ -0,0 +1,40 @@
+import argparse
+from . import PAMGatewayActionDiscoverCommandBase, GatewayContext
+from ..pam.router_helper import router_get_connected_gateways
+from ...display import bcolors
+from discovery_common.rule import Rules
+from discovery_common.types import RuleTypeEnum
+
+
+class PAMGatewayActionDiscoverRuleRemoveCommand(PAMGatewayActionDiscoverCommandBase):
+    parser = argparse.ArgumentParser(prog='dr-discover-rule-remove')
+    parser.add_argument('--gateway', '-g', required=True, dest='gateway', action='store',
+                        help='Gateway name or UID')
+    parser.add_argument('--rule-id', '-i', required=True, dest='rule_id', action='store',
+                        help='Identifier for the rule')
+
+    def get_parser(self):
+        return PAMGatewayActionDiscoverRuleRemoveCommand.parser
+
+    def execute(self, params, **kwargs):
+
+        if not hasattr(params, 'pam_controllers'):
+            router_get_connected_gateways(params)
+
+        gateway = kwargs.get("gateway")
+        gateway_context = GatewayContext.from_gateway(params, gateway)
+        if gateway_context is None:
+            print(f'{bcolors.FAIL}Discovery job gateway [{gateway}] was not found.{bcolors.ENDC}')
+            return
+
+        try:
+            rule_id = kwargs.get("rule_id")
+            rules = Rules(record=gateway_context.configuration, params=params)
+            rule_item = rules.get_rule_item(rule_type=RuleTypeEnum.ACTION, rule_id=rule_id)
+            if rule_item is None:
+                raise ValueError("Rule Id does not exist.")
+            rules.remove_rule(rule_item)
+
+            print(f"{bcolors.OKGREEN}Rule has been removed.{bcolors.ENDC}")
+        except Exception as err:
+            print(f"{bcolors.FAIL}Rule was not removed: {err}{bcolors.ENDC}")
diff --git a/keepercommander/commands/discover/rule_update.py b/keepercommander/commands/discover/rule_update.py
new file mode 100644
index 000000000..b54eafd34
--- /dev/null
+++ b/keepercommander/commands/discover/rule_update.py
@@ -0,0 +1,72 @@
+from __future__ import annotations
+import argparse
+from . import PAMGatewayActionDiscoverCommandBase, GatewayContext
+from .rule_add import PAMGatewayActionDiscoverRuleAddCommand
+from ..pam.router_helper import router_get_connected_gateways
+from ...display import bcolors
+from discovery_common.rule import Rules, RuleTypeEnum
+
+
+class PAMGatewayActionDiscoverRuleUpdateCommand(PAMGatewayActionDiscoverCommandBase):
+    parser = argparse.ArgumentParser(prog='dr-discover-rule-update')
+    parser.add_argument('--gateway', '-g', required=True, dest='gateway', action='store',
+                        help='Gateway name or UID.')
+    parser.add_argument('--rule-id', '-i', required=True, dest='rule_id', action='store',
+                        help='Identifier for the rule')
+    parser.add_argument('--action', '-a', required=False, choices=['add', 'ignore', 'prompt'],
+                        dest='rule_action', action='store', help='Update the action to take if the rule matches')
+    parser.add_argument('--priority', '-p', required=False, dest='priority', action='store', type=int,
+                        help='Update the rule execution priority')
+    parser.add_argument('--ignore-case', required=False, dest='ignore_case', action='store_true',
+                        help='Update the rule to ignore case')
+    parser.add_argument('--no-ignore-case', required=False, dest='ignore_case', action='store_false',
+                        help='Update the rule to not ignore case')
+    parser.add_argument('--shared-folder-uid', required=False, dest='shared_folder_uid',
+                        action='store', help='Update the folder in which to place the record.')
+    parser.add_argument('--statement', '-s', required=False, dest='statement', action='store',
+                        help='Update the rule statement')
+
+    def get_parser(self):
+        return PAMGatewayActionDiscoverRuleUpdateCommand.parser
+
+    def execute(self, params, **kwargs):
+
+        if not hasattr(params, 'pam_controllers'):
+            router_get_connected_gateways(params)
+
+        gateway = kwargs.get("gateway")
+        gateway_context = GatewayContext.from_gateway(params, gateway)
+        if gateway_context is None:
+            print(f'{bcolors.FAIL}Discovery job gateway [{gateway}] was not found.{bcolors.ENDC}')
+            return
+
+        try:
+            rule_id = kwargs.get("rule_id")
+            rules = Rules(record=gateway_context.configuration, params=params)
+            rule_item = rules.get_rule_item(rule_type=RuleTypeEnum.ACTION, rule_id=rule_id)
+            if rule_item is None:
+                raise ValueError("Rule Id does not exist.")
+
+            rule_action = kwargs.get("rule_action")
+            if rule_action is not None:
+                rule_item.action = RuleTypeEnum.find_enum(rule_action)
+            priority = kwargs.get("priority")
+            if priority is not None:
+                rule_item.priority = priority
+            ignore_case = kwargs.get("ignore_case")
+            if ignore_case is not None:
+                rule_item.case_sensitive = not ignore_case
+            shared_folder_uid = kwargs.get("shared_folder_uid")
+            if shared_folder_uid is not None:
+                rule_item.shared_folder_uid = shared_folder_uid
+            statement = kwargs.get("statement")
+            if statement is not None:
+                rule_item.statement = PAMGatewayActionDiscoverRuleAddCommand.validate_rule_statement(
+                    params=params,
+                    gateway_context=gateway_context,
+                    statement=statement
+                )
+            rules.update_rule(rule_item)
+            print(f"{bcolors.OKGREEN}Rule has been updated{bcolors.ENDC}")
+        except Exception as err:
+            print(f"{bcolors.FAIL}Rule was not updated: {err}{bcolors.ENDC}")
diff --git a/keepercommander/commands/discoveryrotation.py b/keepercommander/commands/discoveryrotation.py
index 12ada21c8..bf5742e46 100644
--- a/keepercommander/commands/discoveryrotation.py
+++ b/keepercommander/commands/discoveryrotation.py
@@ -22,37 +22,60 @@ from datetime import datetime
 from typing import Dict, Optional, Any, Set, List
+
 import requests
 from
cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import serialization -from cryptography.hazmat.primitives.asymmetric import ec -from keeper_secrets_manager_core.utils import url_safe_str_to_bytes +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.ciphers.aead import AESGCM +from cryptography.hazmat.primitives.kdf.hkdf import HKDF +from keeper_secrets_manager_core.utils import url_safe_str_to_bytes, bytes_to_base64, base64_to_bytes -from .base import Command, GroupCommand, user_choice, dump_report_data, report_output_parser, field_to_title, FolderMixin +from .base import (Command, GroupCommand, user_choice, dump_report_data, report_output_parser, field_to_title, + FolderMixin) from .folder import FolderMoveCommand from .ksm import KSMCommand from .pam import gateway_helper, router_helper from .pam.config_facades import PamConfigurationRecordFacade -from .pam.config_helper import pam_configurations_get_all, pam_configuration_remove, pam_configuration_create_record_v6, record_rotation_get, \ +from .pam.config_helper import configuration_controller_get +from .pam.config_helper import pam_configurations_get_all, \ + pam_configuration_remove, pam_configuration_create_record_v6, record_rotation_get, \ pam_decrypt_configuration_data -from .pam.pam_dto import GatewayActionGatewayInfo, GatewayActionDiscoverInputs, GatewayActionDiscover, \ - GatewayActionRotate, \ - GatewayActionRotateInputs, GatewayAction, GatewayActionJobInfoInputs, \ - GatewayActionJobInfo, GatewayActionJobCancel +from .pam.pam_dto import ( + GatewayActionGatewayInfo, + GatewayActionRotate, + GatewayActionRotateInputs, GatewayAction, GatewayActionJobInfoInputs, + GatewayActionJobInfo, + GatewayActionJobCancel) + from .pam.router_helper import router_send_action_to_gateway, print_router_response, \ router_get_connected_gateways, router_set_record_rotation_information, router_get_rotation_schedules, \ get_router_url from .record_edit import RecordEditMixin -from .tunnel.port_forward.endpoint import establish_symmetric_key, WebRTCConnection, TunnelEntrance, READ_TIMEOUT, \ - find_open_port, CloseConnectionReasons +from .tunnel.port_forward.endpoint import WebRTCConnection, TunnelEntrance, READ_TIMEOUT, \ + find_open_port, CloseConnectionReasons, SOCKS5Server, TunnelDAG, get_config_uid, MAIN_NONCE_LENGTH, \ + SYMMETRIC_KEY_LENGTH, get_keeper_tokens from .. 
import api, utils, vault_extensions, crypto, vault, record_management, attachment, record_facades from ..display import bcolors from ..error import CommandError, KeeperApiError from ..params import KeeperParams, LAST_RECORD_UID from ..proto import pam_pb2, router_pb2, record_pb2 -from ..proto.APIRequest_pb2 import GetKsmPublicKeysRequest, GetKsmPublicKeysResponse from ..subfolder import find_parent_top_folder, try_resolve_path, BaseFolderNode from ..vault import TypedField +from .discover.job_start import PAMGatewayActionDiscoverJobStartCommand +from .discover.job_status import PAMGatewayActionDiscoverJobStatusCommand +from .discover.job_remove import PAMGatewayActionDiscoverJobRemoveCommand +from .discover.result_process import PAMGatewayActionDiscoverResultProcessCommand +from .discover.rule_add import PAMGatewayActionDiscoverRuleAddCommand +from .discover.rule_list import PAMGatewayActionDiscoverRuleListCommand +from .discover.rule_remove import PAMGatewayActionDiscoverRuleRemoveCommand +from .discover.rule_update import PAMGatewayActionDiscoverRuleUpdateCommand +from .pam_debug.acl import PAMDebugACLCommand +from .pam_debug.alter import PAMDebugAlterCommand +from .pam_debug.graph import PAMDebugGraphCommand +from .pam_debug.info import PAMDebugInfoCommand +from .pam_debug.verify import PAMDebugVerifyCommand +from .pam_debug.version import PAMDebugVersionCommand +from .pam_debug.gateway import PAMDebugGatewayCommand def register_commands(commands): @@ -94,9 +117,8 @@ def __init__(self): self.register_command('list', PAMTunnelListCommand(), 'List all Tunnels', 'l') self.register_command('stop', PAMTunnelStopCommand(), 'Stop Tunnel to the server', 'x') self.register_command('tail', PAMTunnelTailCommand(), 'View Tunnel Log', 't') - self.register_command('disable', PAMTunnelDisableCommand(), 'Disable Tunnel', 'd') - self.register_command('enable', PAMTunnelEnableCommand(), 'Enable Tunnel', 'e') - # self.default_verb = 'list' + self.register_command('edit', PAMTunnelEditCommand(), 'Edit Tunnel settings', 'e') + self.default_verb = 'list' class PAMConfigurationsCommand(GroupCommand): @@ -115,29 +137,67 @@ class PAMRotationCommand(GroupCommand): def __init__(self): super(PAMRotationCommand, self).__init__() - self.register_command('set', PAMCreateRecordRotationCommand(), 'Sets Record Rotation configuration', 'new') + self.register_command('edit', PAMCreateRecordRotationCommand(), 'Edits Record Rotation configuration', 'new') self.register_command('list', PAMListRecordRotationCommand(), 'List Record Rotation configuration', 'l') self.register_command('info', PAMRouterGetRotationInfo(), 'Get Rotation Info', 'i') self.register_command('script', PAMRouterScriptCommand(), 'Add, delete, or edit script field') self.default_verb = 'list' +class PAMDiscoveryCommand(GroupCommand): + + def __init__(self): + super(PAMDiscoveryCommand, self).__init__() + self.register_command('start', PAMGatewayActionDiscoverJobStartCommand(), 'Start a discovery process', 's') + self.register_command('status', PAMGatewayActionDiscoverJobStatusCommand(), 'Status of discovery jobs', 'st') + self.register_command('remove', PAMGatewayActionDiscoverJobRemoveCommand(), 'Cancel or remove of discovery jobs', 'r') + self.register_command('process', PAMGatewayActionDiscoverResultProcessCommand(), 'Process discovered items', 'p') + self.register_command('rule', PAMDiscoveryRuleCommand(), 'Manage discovery rules') + + self.default_verb = 'status' + + +class PAMDiscoveryRuleCommand(GroupCommand): + + def __init__(self): + 
super(PAMDiscoveryRuleCommand, self).__init__() + self.register_command('add', PAMGatewayActionDiscoverRuleAddCommand(), 'Add a rule', 'a') + self.register_command('list', PAMGatewayActionDiscoverRuleListCommand(), 'List all rules', 'l') + self.register_command('remove', PAMGatewayActionDiscoverRuleRemoveCommand(), 'Remove a rule', 'r') + self.register_command('update', PAMGatewayActionDiscoverRuleUpdateCommand(), 'Update a rule', 'u') + self.default_verb = 'list' + + class GatewayActionCommand(GroupCommand): def __init__(self): super(GatewayActionCommand, self).__init__() self.register_command('gateway-info', PAMGatewayActionServerInfoCommand(), 'Info command', 'i') - self.register_command('unreleased-discover', PAMGatewayActionDiscoverCommand(), 'Discover command') + self.register_command('discover', PAMDiscoveryCommand(), 'Discover command', 'd') self.register_command('rotate', PAMGatewayActionRotateCommand(), 'Rotate command', 'r') self.register_command('job-info', PAMGatewayActionJobCommand(), 'View Job details', 'ji') self.register_command('job-cancel', PAMGatewayActionJobCommand(), 'View Job details', 'jc') + self.register_command('debug', PAMDebugCommand(), 'PAM debug information') # self.register_command('job-list', DRCmdListJobs(), 'List Running jobs') +class PAMDebugCommand(GroupCommand): + + def __init__(self): + super(PAMDebugCommand, self).__init__() + self.register_command('info', PAMDebugInfoCommand(), 'Debug a record', 'i') + self.register_command('gateway', PAMDebugGatewayCommand(), 'Debug a getway', 'g') + self.register_command('graph', PAMDebugGraphCommand(), 'Render graphs', 'r') + self.register_command('verify', PAMDebugVerifyCommand(), 'Verify graphs', 'v') + self.register_command('alter', PAMDebugAlterCommand(), 'Alter graph information', 'a') + self.register_command('acl', PAMDebugACLCommand(), 'Control ACL of PAM Users', 'c') + self.register_command('version', PAMDebugVersionCommand(), 'Version modules versions') + + class PAMCmdListJobs(Command): parser = argparse.ArgumentParser(prog='pam action job-list') - parser.add_argument('--jobId', '-j', required=False, dest='job_id', action='store', help='ID of the Job running') + parser.add_argument('--jobId', '-j', required=False, dest='job_id', action='store', help='ID of the Job running') def get_parser(self): return PAMCmdListJobs.parser @@ -163,21 +223,35 @@ def execute(self, params, **kwargs): class PAMCreateRecordRotationCommand(Command): - parser = argparse.ArgumentParser(prog='pam rotation set') + parser = argparse.ArgumentParser(prog='pam rotation edit') record_group = parser.add_mutually_exclusive_group(required=True) - record_group.add_argument('--record', dest='record_name', action='store', help='Record UID, name, or pattern to be rotated manually or via schedule') - record_group.add_argument('--folder', dest='folder_name', action='store', help='Folder UID or name that holds records to be rotated manually or via schedule') + record_group.add_argument('--record', '-r', dest='record_name', action='store', + help='Record UID, name, or pattern to be rotated manually or via schedule') + record_group.add_argument('--folder', '-fd', dest='folder_name', action='store', + help='Used for bulk rotation setup. 
The folder UID or name that holds records to be ' + 'configured') parser.add_argument('--force', '-f', dest='force', action='store_true', help='Do not ask for confirmation') - parser.add_argument('--config', dest='config_uid', action='store', help='UID of the PAM Configuration') - parser.add_argument('--resource', dest='resource_uid', action='store', help='UID of the resource record.') + parser.add_argument('--config', '-c', required=False, dest='config_uid', action='store', + help='UID of the configuration record.') + parser.add_argument('--iam-aad-config', '-iac', dest='iam_aad_config_uid', action='store', + help='UID of a PAM Configuration. Used for an IAM or Azure AD user in place of --resource.') + parser.add_argument('--resource', '-rs', dest='resource_uid', action='store', help='UID of the resource record.') schedule_group = parser.add_mutually_exclusive_group() - schedule_group.add_argument('--schedulejson', '-sj', required=False, dest='schedule_json_data', action='append', help='Json of the scheduler. Example: -sj \'{"type": "WEEKLY", "utcTime": "15:44", "weekday": "SUNDAY", "intervalCount": 1}\'') - schedule_group.add_argument('--schedulecron', '-sc', required=False, dest='schedule_cron_data', action='append', help='Cron tab string of the scheduler. Example: to run job daily at 5:56PM UTC enter following cron -sc "56 17 * * *"') - schedule_group.add_argument('--on-demand', '-sm', required=False, dest='on_demand', action='store_true', help='Schedule On Demand') - parser.add_argument('--complexity', '-x', required=False, dest='pwd_complexity', action='store', help='Password complexity: length, upper, lower, digits, symbols. Ex. 32,5,5,5,5') + schedule_group.add_argument('--schedulejson', '-sj', required=False, dest='schedule_json_data', + action='append', help='Json of the scheduler. Example: -sj \'{"type": "WEEKLY", ' + '"utcTime": "15:44", "weekday": "SUNDAY", "intervalCount": 1}\'') + schedule_group.add_argument('--schedulecron', '-sc', required=False, dest='schedule_cron_data', + action='append', help='Cron tab string of the scheduler. Example: to run job daily at ' + '5:56PM UTC enter following cron -sc "56 17 * * *"') + schedule_group.add_argument('--on-demand', '-od', required=False, dest='on_demand', + action='store_true', help='Schedule On Demand') + parser.add_argument('--complexity', '-x', required=False, dest='pwd_complexity', action='store', + help='Password complexity: length, upper, lower, digits, symbols. Ex. 
32,5,5,5,5') + parser.add_argument('--admin-user', '-a', required=False, dest='admin', action='store', + help='UID for the PAMUser record to configure the admin credential on the PAM Resource as the Admin when rotating') state_group = parser.add_mutually_exclusive_group() - state_group.add_argument('--enable', dest='enable', action='store_true', help='Enable rotation') - state_group.add_argument('--disable', dest='disable', action='store_true', help='Disable rotation') + state_group.add_argument('--enable', '-e', dest='enable', action='store_true', help='Enable rotation') + state_group.add_argument('--disable', '-d', dest='disable', action='store_true', help='Disable rotation') def get_parser(self): return PAMCreateRecordRotationCommand.parser @@ -188,6 +262,8 @@ def execute(self, params, **kwargs): folder_uids = set() record_pattern = '' record_name = kwargs.get('record_name') + + encrypted_session_token, encrypted_transmission_key, transmission_key = get_keeper_tokens(params) if record_name: if record_name in params.record_cache: record_uids.add(record_name) @@ -227,6 +303,9 @@ def add_folders(sub_folder): # type: (BaseFolderNode) -> None else: logging.warning('Folder \"%s\" not found. Skipping.', folder_name) + if record_name and folder_name: + raise CommandError('', 'Cannot use both --record and --folder at the same time.') + if folder_uids: regex = re.compile(fnmatch.translate(record_pattern), re.IGNORECASE).match if record_pattern else None for folder_uid in folder_uids: @@ -245,7 +324,7 @@ def add_folders(sub_folder): # type: (BaseFolderNode) -> None record_uids.add(record_uid) pam_records = [] # type: List[vault.TypedRecord] - valid_record_types = {'pamDatabase', 'pamDirectory', 'pamMachine', 'pamUser'} + valid_record_types = ['pamDatabase', 'pamDirectory', 'pamMachine', 'pamUser', 'pamRemoteBrowser'] for record_uid in record_uids: record = vault.KeeperRecord.load(params, record_uid) if record and isinstance(record, vault.TypedRecord) and record.record_type in valid_record_types: @@ -301,46 +380,94 @@ def add_folders(sub_folder): # type: (BaseFolderNode) -> None pwd_complexity_rule_list = {} resource_uid = kwargs.get('resource_uid') - if isinstance(resource_uid, str) and len(resource_uid) > 0: - if pam_config is None: - raise CommandError('', '"--resource" parameter requires "--config" parameter to be set as well.') - resource_field = pam_config.get_typed_field('pamResources') - if resource_field and isinstance(resource_field.value, list) and len(resource_field.value) > 0: - resources = resource_field.value[0] - if isinstance(resources, dict): - resource_uids = resources.get('resourceRef') - if isinstance(resource_uids, list): - if resource_uid not in resource_uids: - raise CommandError('', f'PAM Configuration "{pam_config.record_uid}" does not have admin credential for UID "{resource_uid}"') - else: - raise CommandError('', f'PAM Configuration "{pam_config.record_uid}'" does not have admin credentials") skipped_header = ['record_uid', 'record_title', 'problem', 'description'] skipped_records = [] valid_header = ['record_uid', 'record_title', 'enabled', 'configuration_uid', 'resource_uid', 'schedule', 'complexity'] valid_records = [] - requests = [] # type: List[router_pb2.RouterRecordRotationRequest] - for record in pam_records: - current_record_rotation = params.record_rotation_cache.get(record.record_uid) + r_requests = [] # type: List[router_pb2.RouterRecordRotationRequest] + + def config_resource(_dag, target_record, target_config_uid): + if not _dag.linking_dag.has_graph: + # 
Add DAG for resource + if target_config_uid: + _dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, target_config_uid) + _dag.edit_tunneling_config(rotation=True) + else: + raise CommandError('', f'{bcolors.FAIL}Resource "{target_record.record_uid}" is not associated ' + f'with any configuration. ' + f'{bcolors.OKBLUE}pam rotation edit -rs {target_record.record_uid} ' + f'--config CONFIG_UID{bcolors.ENDC}') + resource_dag = None + if not _dag.resource_belongs_to_config(target_record.record_uid): + # Change DAG to this new configuration. + resource_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, + target_record.record_uid) + _dag.link_resource_to_config(target_record.record_uid) + + admin = kwargs.get('admin') + if admin and target_record.record_type != 'pamRemoteBrowser': + _dag.link_user_to_resource(admin, target_record.record_uid, is_admin=True) + + _rotation_enabled = True if kwargs.get('enable') else False if kwargs.get('disable') else None + + if _rotation_enabled is not None: + _dag.set_resource_allowed(target_record, rotation=_rotation_enabled, + allowed_settings_name="rotation") + + if resource_dag is not None and resource_dag.linking_dag.has_graph: + # TODO: Make sure this doesn't remove everything from the new dag too + resource_dag.remove_from_dag(target_record.record_uid) + + _dag.print_tunneling_config(target_record .record_uid, config_uid=target_config_uid) + + def config_iam_aad_user(_dag, target_record, target_iam_aad_config_uid): + if _dag and not _dag.linking_dag.has_graph: + _dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, target_iam_aad_config_uid) + if not _dag or not _dag.linking_dag.has_graph: + _dag.edit_tunneling_config(rotation=True) + old_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, target_record.record_uid) + if old_dag.linking_dag.has_graph and old_dag.record.record_uid != target_iam_aad_config_uid: + old_dag.remove_from_dag(target_record.record_uid) + + # with IAM users the user is at the level the resource is usually at, + if not _dag.user_belongs_to_config(target_record.record_uid): + old_resource_uid = _dag.get_resource_uid(target_record.record_uid) + if old_resource_uid is not None: + print( + f'{bcolors.WARNING}User "{target_record.record_uid}" is associated with another resource: ' + f'{old_resource_uid}. ' + f'Now moving it to {target_iam_aad_config_uid} and it will no longer be rotated on {old_resource_uid}.' + f'{bcolors.ENDC}') + if old_resource_uid == _dag.record.record_uid: + _dag.unlink_user_from_resource(target_record.record_uid) + _dag.link_user_to_resource(target_record.record_uid, old_resource_uid, belongs_to=False) + _dag.link_user_to_config(target_record.record_uid) + + + current_record_rotation = params.record_rotation_cache.get(target_record.record_uid) # 1. 
PAM Configuration UID - record_config_uid = config_uid + record_config_uid = _dag.record.record_uid record_pam_config = pam_config if not record_config_uid: if current_record_rotation: record_config_uid = current_record_rotation.get('configuration_uid') pc = vault.KeeperRecord.load(params, record_config_uid) if pc is None: - skipped_records.append([record.record_uid, record.title, 'PAM Configuration was deleted', 'Specify a configuration UID parameter [--config]']) - continue + skipped_records.append([target_record.record_uid, target_record.title, 'PAM Configuration was deleted', + 'Specify a configuration UID parameter [--config]']) + return if not isinstance(pc, vault.TypedRecord) or pc.version != 6: - skipped_records.append([record.record_uid, record.title, 'PAM Configuration is invalid', 'Specify a configuration UID parameter [--config]']) - continue + skipped_records.append([target_record.record_uid, target_record.title, 'PAM Configuration is invalid', + 'Specify a configuration UID parameter [--config]']) + return record_pam_config = pc else: - skipped_records.append([record.record_uid, record.title, 'No current PAM Configuration', 'Specify a configuration UID parameter [--config]']) - continue + skipped_records.append([target_record.record_uid, target_record.title, 'No current PAM Configuration', + 'Specify a configuration UID parameter [--config]']) + return # 2. Schedule record_schedule_data = schedule_data @@ -366,12 +493,13 @@ def add_folders(sub_folder): # type: (BaseFolderNode) -> None pwd_complexity_rule_list_encrypted = b'' else: if len(pwd_complexity_rule_list) > 0: - pwd_complexity_rule_list_encrypted = router_helper.encrypt_pwd_complexity(pwd_complexity_rule_list, record.record_key) + pwd_complexity_rule_list_encrypted = router_helper.encrypt_pwd_complexity(pwd_complexity_rule_list, + record.record_key) else: pwd_complexity_rule_list_encrypted = b'' # 4. Resource record - record_resource_uid = resource_uid + record_resource_uid = iam_aad_config_uid if record_resource_uid is None: if current_record_rotation: record_resource_uid = current_record_rotation.get('resourceUid') @@ -385,15 +513,192 @@ def add_folders(sub_folder): # type: (BaseFolderNode) -> None if len(resource_uids) == 1: record_resource_uid = resource_uids[0] else: - skipped_records.append([record.record_uid, record.title, f'PAM Configuration: {len(resource_uids)} admin resources', + skipped_records.append([target_record.record_uid, target_record.title, + f'PAM Configuration: {len(resource_uids)} admin resources', 'Specify both configuration UID and resource UID [--config, --resource]']) - continue + return + + disabled = False + # 5. 
Enable rotation
+            if kwargs.get('enable'):
+                _dag.set_resource_allowed(iam_aad_config_uid, rotation=True, is_config=bool(target_iam_aad_config_uid))
+            elif kwargs.get('disable'):
+                _dag.set_resource_allowed(iam_aad_config_uid, rotation=False, is_config=bool(target_iam_aad_config_uid))
+                disabled = True
+
+            schedule = 'On-Demand'
+            if isinstance(record_schedule_data, list) and len(record_schedule_data) > 0:
+                if isinstance(record_schedule_data[0], dict):
+                    schedule = record_schedule_data[0].get('type')
+            complexity = ''
+            if pwd_complexity_rule_list_encrypted:
+                try:
+                    decrypted_complexity = crypto.decrypt_aes_v2(pwd_complexity_rule_list_encrypted, target_record.record_key)
+                    c = json.loads(decrypted_complexity.decode())
+                    complexity = f"{c.get('length', 0)},{c.get('caps', 0)},{c.get('lowercase', 0)},{c.get('digits', 0)},{c.get('special', 0)}"
+                except:
+                    pass
+            valid_records.append(
+                [target_record.record_uid, target_record.title, not disabled, record_config_uid, record_resource_uid, schedule,
+                 complexity])
+
+            # 6. Construct Request object
+            rq = router_pb2.RouterRecordRotationRequest()
+            if current_record_rotation:
+                rq.revision = current_record_rotation.get('revision', 0)
+            rq.recordUid = utils.base64_url_decode(target_record.record_uid)
+            rq.configurationUid = utils.base64_url_decode(record_config_uid)
+            rq.resourceUid = utils.base64_url_decode(record_resource_uid) if record_resource_uid else b''
+            rq.schedule = json.dumps(record_schedule_data) if record_schedule_data else ''
+            rq.pwdComplexity = pwd_complexity_rule_list_encrypted
+            rq.disabled = disabled
+            r_requests.append(rq)
+
+        def config_user(_dag, target_record, target_resource_uid, target_config_uid=None):
+
+            if _dag and _dag.linking_dag:
+                admin_record_uids = _dag.get_all_admins()
+                if folder_name and target_record.record_uid in admin_record_uids:
+                    # If iterating through a folder, skip admin records
+                    skipped_records.append([target_record.record_uid, target_record.title, 'Admin Credential',
+                                            'This record is used as Admin credentials on a PAM Configuration. Skipped'])
+                    return
+
+            if isinstance(target_resource_uid, str) and len(target_resource_uid) > 0:
+                _dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, target_resource_uid)
+                if not _dag or not _dag.linking_dag.has_graph:
+                    if target_config_uid and target_resource_uid:
+                        config_resource(_dag, target_record, target_config_uid)
+                    if not _dag or not _dag.linking_dag.has_graph:
+                        raise CommandError('', f'{bcolors.FAIL}Resource "{target_resource_uid}" is not associated '
+                                               f'with any configuration. '
+                                               f'{bcolors.OKBLUE}pam rotation edit -rs {target_resource_uid} '
+                                               f'--config CONFIG_UID{bcolors.ENDC}')
+
+                if not _dag.check_if_resource_has_admin(target_resource_uid):
+                    raise CommandError('', f'PAM Resource "{target_resource_uid}" does not have '
+                                           'admin credentials. Please link an admin credential to this resource. '
+                                           f'{bcolors.OKBLUE}pam rotation edit -rs {target_resource_uid} '
+                                           f'--admin-user ADMIN_UID{bcolors.ENDC}')
+            current_record_rotation = params.record_rotation_cache.get(target_record.record_uid)
+
+            if not _dag or not _dag.linking_dag.has_graph:
+                _dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, target_resource_uid)
+                if not _dag.linking_dag.has_graph:
+                    raise CommandError('', f'{bcolors.FAIL}Resource "{target_resource_uid}" is not associated '
+                                           f'with any configuration.' 
+ f'{bcolors.OKBLUE}pam rotation edit -rs {target_resource_uid} ' + f'--config CONFIG_UID{bcolors.ENDC}') + if not target_resource_uid: + # Get the resource configuration from DAG + resource_uids = _dag.get_all_owners(target_record.record_uid) + if len(resource_uids) > 1: + raise CommandError('', f'{bcolors.FAIL}Record "{target_record.record_uid}" is ' + f'associated with multiple resources so you must supply ' + f'{bcolors.OKBLUE}"--resource/-rs RESOURCE_UID".{bcolors.ENDC}') + elif len(resource_uids) == 0: + raise CommandError('', + f'{bcolors.FAIL}Record "{target_record.record_uid}" is not associated with' + f' any resource. Please use {bcolors.OKBLUE}"pam rotation user ' + f'{target_record.record_uid} --resource RESOURCE_UID" {bcolors.FAIL}to associate ' + f'it.{bcolors.ENDC}') + target_resource_uid = resource_uids[0] + + if not _dag.resource_belongs_to_config(target_resource_uid): + raise CommandError('', + f'{bcolors.FAIL}Resource "{target_resource_uid}" is not associated with the ' + f'configuration of the user "{target_record.record_uid}". To associated the resources ' + f'to this config run {bcolors.OKBLUE}"pam rotation resource {target_resource_uid} ' + f'--config {_dag.record.record_uid}"{bcolors.ENDC}') + if not _dag.user_belongs_to_resource(target_record.record_uid, target_resource_uid): + old_resource_uid = _dag.get_resource_uid(target_record.record_uid) + if old_resource_uid is not None and old_resource_uid != target_resource_uid: + print( + f'{bcolors.WARNING}User "{target_record.record_uid}" is associated with another resource: ' + f'{old_resource_uid}. ' + f'Now moving it to {target_resource_uid} and it will no longer be rotated on {old_resource_uid}.' + f'{bcolors.ENDC}') + _dag.link_user_to_resource(target_record.record_uid, old_resource_uid, belongs_to=False) + _dag.link_user_to_resource(target_record.record_uid, target_resource_uid, belongs_to=True) + + # 1. PAM Configuration UID + record_config_uid = _dag.record.record_uid + record_pam_config = pam_config + if not record_config_uid: + if current_record_rotation: + record_config_uid = current_record_rotation.get('configuration_uid') + pc = vault.KeeperRecord.load(params, record_config_uid) + if pc is None: + skipped_records.append([target_record.record_uid, target_record.title, 'PAM Configuration was deleted', + 'Specify a configuration UID parameter [--config]']) + return + if not isinstance(pc, vault.TypedRecord) or pc.version != 6: + skipped_records.append([target_record.record_uid, target_record.title, 'PAM Configuration is invalid', + 'Specify a configuration UID parameter [--config]']) + return + record_pam_config = pc + else: + skipped_records.append([target_record.record_uid, target_record.title, 'No current PAM Configuration', + 'Specify a configuration UID parameter [--config]']) + return + + # 2. Schedule + record_schedule_data = schedule_data + if record_schedule_data is None: + if current_record_rotation: + try: + current_schedule = current_record_rotation.get('schedule') + if current_schedule: + record_schedule_data = json.loads(current_schedule) + except: + pass + elif record_pam_config: + schedule_field = record_pam_config.get_typed_field('schedule', 'defaultRotationSchedule') + if schedule_field and isinstance(schedule_field.value, list) and len(schedule_field.value) > 0: + if isinstance(schedule_field.value[0], dict): + record_schedule_data = [schedule_field.value[0]] + else: + record_schedule_data = [] + + # 3. 
Password complexity
+            if pwd_complexity_rule_list is None:
+                if current_record_rotation:
+                    pwd_complexity_rule_list_encrypted = utils.base64_url_decode(current_record_rotation['pwd_complexity'])
+                else:
+                    pwd_complexity_rule_list_encrypted = b''
+            else:
+                if len(pwd_complexity_rule_list) > 0:
+                    pwd_complexity_rule_list_encrypted = router_helper.encrypt_pwd_complexity(pwd_complexity_rule_list,
+                                                                                              target_record.record_key)
+                else:
+                    pwd_complexity_rule_list_encrypted = b''
+
+            # 4. Resource record
+            record_resource_uid = target_resource_uid
+            if record_resource_uid is None:
+                if current_record_rotation:
+                    record_resource_uid = current_record_rotation.get('resourceUid')
+            if record_resource_uid is None:
+                resource_field = record_pam_config.get_typed_field('pamResources')
+                if resource_field and isinstance(resource_field.value, list) and len(resource_field.value) > 0:
+                    resources = resource_field.value[0]
+                    if isinstance(resources, dict):
+                        resource_uids = resources.get('resourceRef')
+                        if isinstance(resource_uids, list) and len(resource_uids) > 0:
+                            if len(resource_uids) == 1:
+                                record_resource_uid = resource_uids[0]
+                            else:
+                                skipped_records.append([target_record.record_uid, target_record.title,
+                                                        f'PAM Configuration: {len(resource_uids)} admin resources',
+                                                        'Specify both configuration UID and resource UID [--config, --resource]'])
+                                return
+
             disabled = False
             # 5. Enable rotation
-        disabled = current_record_rotation.get('disabled') if current_record_rotation else False
-        if kwargs.get('enable') is True:
-            disabled = False
-        elif kwargs.get('disable') is True:
+            if kwargs.get('enable'):
+                _dag.set_resource_allowed(target_resource_uid, rotation=True)
+            elif kwargs.get('disable'):
+                _dag.set_resource_allowed(target_resource_uid, rotation=False)
                 disabled = True

             schedule = 'On-Demand'
@@ -403,24 +708,46 @@ def add_folders(sub_folder):  # type: (BaseFolderNode) -> None
             complexity = ''
             if pwd_complexity_rule_list_encrypted:
                 try:
-                    decrypted_complexity = crypto.decrypt_aes_v2(pwd_complexity_rule_list_encrypted, record.record_key)
+                    decrypted_complexity = crypto.decrypt_aes_v2(pwd_complexity_rule_list_encrypted, target_record.record_key)
                     c = json.loads(decrypted_complexity.decode())
                     complexity = f"{c.get('length', 0)},{c.get('caps', 0)},{c.get('lowercase', 0)},{c.get('digits', 0)},{c.get('special', 0)}"
                 except:
                     pass
-            valid_records.append([record.record_uid, record.title, not disabled, record_config_uid, record_resource_uid, schedule, complexity])
+            valid_records.append(
+                [target_record.record_uid, target_record.title, not disabled, record_config_uid, record_resource_uid, schedule,
+                 complexity])

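# --- Editor's aside (illustrative, not part of this patch) ---------------------
# The complexity string rendered above ("length,caps,lowercase,digits,special",
# e.g. "32,5,5,5,5") is the decrypted form of rq.pwdComplexity. A minimal sketch
# of the reverse direction, using the crypto helper this module already relies
# on; encode_complexity is a hypothetical name, and the exact rule-dict shape
# used by router_helper.encrypt_pwd_complexity may differ:
#
#     import json
#     from keepercommander import crypto
#
#     def encode_complexity(spec, record_key):
#         length, caps, lowercase, digits, special = (int(x) for x in spec.split(','))
#         rules = {'length': length, 'caps': caps, 'lowercase': lowercase,
#                  'digits': digits, 'special': special}
#         return crypto.encrypt_aes_v2(json.dumps(rules).encode(), record_key)
# -------------------------------------------------------------------------------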
+            # 6. Construct Request object
             rq = router_pb2.RouterRecordRotationRequest()
             if current_record_rotation:
-                rq.revision = current_record_rotation.get('revision')
-            rq.recordUid = utils.base64_url_decode(record.record_uid)
+                rq.revision = current_record_rotation.get('revision', 0)
+            rq.recordUid = utils.base64_url_decode(target_record.record_uid)
             rq.configurationUid = utils.base64_url_decode(record_config_uid)
             rq.resourceUid = utils.base64_url_decode(record_resource_uid) if record_resource_uid else b''
             rq.schedule = json.dumps(record_schedule_data) if record_schedule_data else ''
             rq.pwdComplexity = pwd_complexity_rule_list_encrypted
             rq.disabled = disabled
-            requests.append(rq)
+            r_requests.append(rq)
+
+        for _record in pam_records:
+            tmp_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, _record.record_uid)
+            if _record.record_type in ['pamMachine', 'pamDatabase', 'pamDirectory', 'pamRemoteBrowser']:
+                config_resource(tmp_dag, _record, config_uid)
+            elif _record.record_type == 'pamUser':
+                iam_aad_config_uid = kwargs.get('iam_aad_config_uid')
+
+                if iam_aad_config_uid and iam_aad_config_uid not in pam_configurations:
+                    raise CommandError('', f'Record UID {iam_aad_config_uid} is not a PAM Configuration record.')
+
+                if resource_uid and iam_aad_config_uid:
+                    raise CommandError('', f'Cannot use both --resource and --iam-aad-config at once.'
+                                           f' --resource is used to configure users found on a resource.'
+                                           f' --iam-aad-config is used to configure AWS IAM or Azure AD users')
+
+                if iam_aad_config_uid:
+                    config_iam_aad_user(tmp_dag, _record, iam_aad_config_uid)
+                else:
+                    config_user(tmp_dag, _record, resource_uid, config_uid)

         force = kwargs.get('force') is True
@@ -428,12 +755,12 @@ def add_folders(sub_folder):  # type: (BaseFolderNode) -> None
             skipped_header = [field_to_title(x) for x in skipped_header]
             dump_report_data(skipped_records, skipped_header, title='The following record(s) were skipped')

-        if len(requests) > 0 and not force:
+        if len(r_requests) > 0 and not force:
             answer = user_choice('\nDo you want to cancel password rotation?', 'Yn', 'Y')
             if answer.lower().startswith('y'):
                 return

-        if len(requests) > 0:
+        if len(r_requests) > 0:
             valid_header = [field_to_title(x) for x in valid_header]
             dump_report_data(valid_records, valid_header, title='The following record(s) will be updated')
             if not force:
@@ -441,18 +768,21 @@ def add_folders(sub_folder):  # type: (BaseFolderNode) -> None
                 if answer.lower().startswith('n'):
                     return

-        for rq in requests:
+        for rq in r_requests:
             record_uid = utils.base64_url_encode(rq.recordUid)
             try:
-                router_set_record_rotation_information(params, rq)
+                router_set_record_rotation_information(params, rq, transmission_key, encrypted_transmission_key,
+                                                       encrypted_session_token)
             except KeeperApiError as kae:
-                logging.warning('Record "%s": Set rotation error "%s": %s', record_uid, kae.result_code, kae.message)
+                logging.warning('Record "%s": Set rotation error "%s": %s',
+                                record_uid, kae.result_code, kae.message)

         params.sync_data = True


 class PAMListRecordRotationCommand(Command):
     parser = argparse.ArgumentParser(prog='pam rotation list')
-    parser.add_argument('--verbose', '-v', dest='is_verbose', action='store_true', help='Verbose output')
+    parser.add_argument('--verbose', '-v', required=False, default=False, dest='is_verbose', action='store_true',
+                        help='Verbose output')

     def get_parser(self):
         return PAMListRecordRotationCommand.parser
@@ -495,19 +825,23 @@ def execute(self, params, **kwargs):
             record_uid = 
utils.base64_url_encode(s.recordUid) controller_uid = s.controllerUid - controller_details = next((ctr for ctr in enterprise_all_controllers if ctr.controllerUid == controller_uid), None) + controller_details = next( + (ctr for ctr in enterprise_all_controllers if ctr.controllerUid == controller_uid), None) configuration_uid = s.configurationUid configuration_uid_str = utils.base64_url_encode(configuration_uid) - pam_configuration = next((pam_config for pam_config in all_pam_config_records if pam_config.get('record_uid') == configuration_uid_str), None) + pam_configuration = next((pam_config for pam_config in all_pam_config_records if + pam_config.get('record_uid') == configuration_uid_str), None) - is_controller_online = any((poc for poc in enterprise_controllers_connected_uids_bytes if poc == controller_uid)) + is_controller_online = any( + (poc for poc in enterprise_controllers_connected_uids_bytes if poc == controller_uid)) row_color = '' if record_uid in params.record_cache: row_color = bcolors.HIGHINTENSITYWHITE rec = params.record_cache[record_uid] - data_json = rec['data_unencrypted'].decode('utf-8') if isinstance(rec['data_unencrypted'], bytes) else rec['data_unencrypted'] + data_json = rec['data_unencrypted'].decode('utf-8') if isinstance(rec['data_unencrypted'], bytes) else \ + rec['data_unencrypted'] data = json.loads(data_json) record_title = data.get('title') @@ -518,6 +852,10 @@ def execute(self, params, **kwargs): record_title = '[record inaccessible]' record_type = '[record inaccessible]' + if record_type != "pamUser": + # only pamUser records are supported for rotation + continue + row.append(f'{row_color}{record_uid}') row.append(record_title) row.append(record_type) @@ -546,17 +884,15 @@ def execute(self, params, **kwargs): enterprise_controllers_connected = router_get_connected_gateways(params) connected_controller = None if enterprise_controllers_connected and controller_details: - # Find connected controller (TODO: Optimize, don't search for controllers every time, no N^n) - router_controllers = [x.controllerUid for x in enterprise_controllers_connected.controllers] - connected_controller = next( - (x for x in router_controllers if x == controller_details.controllerUid), None) + router_controllers = {controller.controllerUid: controller for controller in + list(enterprise_controllers_connected.controllers)} + connected_controller = router_controllers.get(controller_details.controllerUid) if connected_controller: controller_stat_color = bcolors.OKGREEN else: controller_stat_color = bcolors.WHITE - controller_color = bcolors.WHITE if is_controller_online: controller_color = bcolors.OKGREEN @@ -573,7 +909,8 @@ def execute(self, params, **kwargs): if not is_verbose: row.append(f"{bcolors.FAIL}[No config found]{bcolors.ENDC}") else: - row.append(f"{bcolors.FAIL}[No config found. Looks like configuration {configuration_uid_str} was removed but rotation schedule was not modified{bcolors.ENDC}") + row.append( + f"{bcolors.FAIL}[No config found. 
Looks like configuration {configuration_uid_str} was removed but rotation schedule was not modified{bcolors.ENDC}") else: pam_data_decrypted = pam_decrypt_configuration_data(pam_configuration) @@ -597,8 +934,10 @@ def execute(self, params, **kwargs): class PAMGatewayListCommand(Command): parser = argparse.ArgumentParser(prog='dr-gateway') - parser.add_argument('--force', '-f', required=False, default=False, dest='is_force', action='store_true', help='Force retrieval of gateways') - parser.add_argument('--verbose', '-v', required=False, default=False, dest='is_verbose', action='store_true', help='Verbose output') + parser.add_argument('--force', '-f', required=False, default=False, dest='is_force', action='store_true', + help='Force retrieval of gateways') + parser.add_argument('--verbose', '-v', required=False, default=False, dest='is_verbose', action='store_true', + help='Verbose output') def get_parser(self): return PAMGatewayListCommand.parser @@ -656,9 +995,9 @@ def execute(self, params, **kwargs): connected_controller = None if enterprise_controllers_connected: - # Find connected controller (TODO: Optimize, don't search for controllers every time, no N^n) - router_controllers = list(enterprise_controllers_connected.controllers) - connected_controller = next((x for x in router_controllers if x.controllerUid == c.controllerUid), None) + router_controllers = {controller.controllerUid: controller for controller in + list(enterprise_controllers_connected.controllers)} + connected_controller = router_controllers.get(c.controllerUid) row_color = '' if not is_router_down: @@ -698,8 +1037,8 @@ def execute(self, params, **kwargs): if is_verbose: row.append(f'{row_color}{c.deviceName}{bcolors.ENDC}') row.append(f'{row_color}{c.deviceToken}{bcolors.ENDC}') - row.append(f'{row_color}{datetime.fromtimestamp(c.created/1000)}{bcolors.ENDC}') - row.append(f'{row_color}{datetime.fromtimestamp(c.lastModified/1000)}{bcolors.ENDC}') + row.append(f'{row_color}{datetime.fromtimestamp(c.created / 1000)}{bcolors.ENDC}') + row.append(f'{row_color}{datetime.fromtimestamp(c.lastModified / 1000)}{bcolors.ENDC}') row.append(f'{row_color}{c.nodeId}{bcolors.ENDC}') table.append(row) @@ -726,11 +1065,16 @@ def execute(self, params, **kwargs): pam_configuration_uid = kwargs.get('pam_configuration') is_verbose = kwargs.get('verbose') - if not pam_configuration_uid: # Print ALL root level configs + if not pam_configuration_uid: # Print ALL root level configs PAMConfigurationListCommand.print_root_rotation_setting(params, is_verbose) - else: # Print element configs (config that is not a root) + else: # Print element configs (config that is not a root) PAMConfigurationListCommand.print_pam_configuration_details(params, pam_configuration_uid, is_verbose) + encrypted_session_token, encrypted_transmission_key, transmission_key = get_keeper_tokens(params) + tmp_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, pam_configuration_uid, + is_config=True) + tmp_dag.print_tunneling_config(pam_configuration_uid, None) + @staticmethod def print_pam_configuration_details(params, config_uid, is_verbose=False): configuration = vault.KeeperRecord.load(params, config_uid) @@ -778,7 +1122,7 @@ def print_root_rotation_setting(params, is_verbose=False): configurations = list(vault_extensions.find_records(params, record_version=6)) facade = PamConfigurationRecordFacade() - for c in configurations: # type: vault.TypedRecord + for c in configurations: # type: vault.TypedRecord if c.record_type in 
('pamAwsConfiguration', 'pamAzureConfiguration', 'pamNetworkConfiguration'): facade.record = c shared_folder_parents = find_parent_top_folder(params, c.record_uid) @@ -799,24 +1143,24 @@ def print_root_rotation_setting(params, is_verbose=False): table.append(row) else: - logging.warning(f'Following configuration is not in the shared folder: UID: %s, Title: %s', c.record_uid, c.title) + logging.warning(f'Following configuration is not in the shared folder: UID: %s, Title: %s', + c.record_uid, c.title) else: - logging.warning(f'Following configuration has unsupported type: UID: %s, Title: %s', c.record_uid, c.title) + logging.warning(f'Following configuration has unsupported type: UID: %s, Title: %s', c.record_uid, + c.title) table.sort(key=lambda x: (x[1] or '')) dump_report_data(table, headers, fmt='table', filename="", row_number=False, column_width=None) common_parser = argparse.ArgumentParser(add_help=False) -common_parser.add_argument('--config-type', '-ct', dest='config_type', action='store', - choices=['network', 'aws', 'azure'], help='PAM Configuration Type', ) +common_parser.add_argument('--environment', '-env', dest='config_type', action='store', + choices=['local', 'aws', 'azure'], help='PAM Configuration Type', ) common_parser.add_argument('--title', '-t', dest='title', action='store', help='Title of the PAM Configuration') -common_parser.add_argument('--gateway', '-g', dest='gateway', action='store', help='Gateway UID or Name') -common_parser.add_argument('--shared-folder', '-sf', dest='shared_folder', action='store', +common_parser.add_argument('--gateway', '-g', dest='gateway_uid', action='store', help='Gateway UID or Name') +common_parser.add_argument('--shared-folder', '-sf', dest='shared_folder_uid', action='store', help='Share Folder where this PAM Configuration is stored. 
Should be one of the folders to ' 'which the gateway has access to.') -common_parser.add_argument('--resource-record', '-rr', dest='resource_records', action='append', - help='Resource Record UID') common_parser.add_argument('--schedule', '-sc', dest='default_schedule', action='store', help='Default Schedule: Use CRON syntax') common_parser.add_argument('--port-mapping', '-pm', dest='port_mapping', action='append', help='Port Mapping') network_group = common_parser.add_argument_group('network', 'Local network configuration') @@ -869,8 +1213,8 @@ def parse_pam_configuration(self, params, record, **kwargs): field.value.append(dict()) value = field.value[0] - gateway_uid = None # type: Optional[str] - gateway = kwargs.get('gateway') # type: Optional[str] + gateway_uid = None # type: Optional[str] + gateway = kwargs.get('gateway_uid') # type: Optional[str] if gateway: gateways = gateway_helper.get_all_gateways(params) gateway_uid = next((utils.base64_url_encode(x.controllerUid) for x in gateways @@ -887,24 +1231,29 @@ def parse_pam_configuration(self, params, record, **kwargs): # if len(shares) == 0: # raise Exception(f'Gateway %s has no shared folders', gateway.controllerName) - shared_folder_uid = None # type: Optional[str] - folder_name = kwargs.get('shared_folder') # type: Optional[str] + shared_folder_uid = None # type: Optional[str] + folder_name = kwargs.get('shared_folder_uid') # type: Optional[str] if folder_name: if folder_name in params.shared_folder_cache: shared_folder_uid = folder_name else: for sf_uid in params.shared_folder_cache: sf = api.get_shared_folder(params, sf_uid) - if sf: - if sf.name.casefold() == folder_name.casefold(): - shared_folder_uid = sf_uid - break + if sf and sf.name.casefold() == folder_name.casefold(): + shared_folder_uid = sf_uid + break if shared_folder_uid: value['folderUid'] = shared_folder_uid + else: + for f in record.fields: + if f.type == 'pamResources' and f.value and len(f.value) > 0 and 'folderUid' in f.value[0]: + shared_folder_uid = f.value[0]['folderUid'] + break + if not shared_folder_uid: + raise CommandError('pam config edit', 'Shared Folder not found') - rr = kwargs.get('resource_records') rrr = kwargs.get('remove_records') - if rr or rrr: + if rrr: pam_record_lookup = {} rti = PamConfigurationEditMixin.get_pam_record_types(params) for r in vault_extensions.find_records(params, record_type=rti): @@ -924,15 +1273,6 @@ def parse_pam_configuration(self, params, record, **kwargs): record_uids.remove(r_l) continue logging.warning(f'Failed to find PAM record: {r}') - if isinstance(rr, list): - for r in rr: - if r in pam_record_lookup: - record_uids.add(r) - continue - r_l = r.lower() - if r_l in pam_record_lookup: - record_uids.add(r_l) - self.warnings.append(f'Failed to find PAM record: {r}') value['resourceRef'] = list(record_uids) @@ -946,6 +1286,8 @@ def parse_properties(self, params, record, **kwargs): # type: (KeeperParams, va schedule = kwargs.get('default_schedule') if schedule: extra_properties.append(f'schedule.defaultRotationSchedule={schedule}') + else: + extra_properties.append(f'schedule.defaultRotationSchedule=On-Demand') if record.record_type == 'pamNetworkConfiguration': network_id = kwargs.get('network_id') @@ -991,15 +1333,13 @@ def parse_properties(self, params, record, **kwargs): # type: (KeeperParams, va if extra_properties: self.assign_typed_fields(record, [RecordEditMixin.parse_field(x) for x in extra_properties]) - def verify_required(self, record): # type: (vault.TypedRecord) -> None + def verify_required(self, 
record): # type: (vault.TypedRecord) -> None for field in record.fields: if field.required: if len(field.value) == 0: if field.type == 'schedule': field.value = [{ - 'type': 'RUN_ONCE', - 'time': '2000-01-01T00:00:00', - 'tz': 'Etc/UTC', + 'type': 'ON_DEMAND' }] else: self.warnings.append(f'Empty required field: "{field.get_field_name()}"') @@ -1010,6 +1350,16 @@ def verify_required(self, record): # type: (vault.TypedRecord) -> None class PAMConfigurationNewCommand(Command, PamConfigurationEditMixin): parser = argparse.ArgumentParser(prog='pam config new', parents=[common_parser]) + parser.add_argument('--enable-connections', '-ec', dest='enable_connections', action='store_true', + help='Enable connections') + parser.add_argument('--enable-tunneling', '-et', dest='enable_tunneling', + action='store_true', help='Enable tunneling') + parser.add_argument('--enable-rotation', '-er', dest='enable_rotation', action='store_true', + help='Enable rotation') + parser.add_argument('--enable-connections-recording', '-ecr', required=False, dest='recordingenabled', + action='store_true', help='Enable recording connections for the resource') + parser.add_argument('--enable-typescripts-recording', '-etcr', required=False, dest='typescriptrecordingenabled', + action='store_true', help='Enable TypeScript recording for the resource') def __init__(self): super().__init__() @@ -1022,13 +1372,16 @@ def execute(self, params, **kwargs): config_type = kwargs.get('config_type') if not config_type: - raise CommandError('pam-config-new', '--config-type parameter is required') + raise CommandError('pam-config-new', '--environment parameter is required') if config_type == 'aws': record_type = 'pamAwsConfiguration' elif config_type == 'azure': record_type = 'pamAzureConfiguration' - else: + elif config_type == 'local': record_type = 'pamNetworkConfiguration' + else: + raise CommandError('pam-config-new', f'--environment {config_type} is not supported' + f' supported options are aws, azure, or local') title = kwargs.get('title') if not title: @@ -1056,12 +1409,25 @@ def execute(self, params, **kwargs): shared_folder_uid = value.get('folderUid') if not shared_folder_uid: - raise CommandError('pam-config-new', '--shared_folder parameter is required to create a PAM configuration') + raise CommandError('pam-config-new', '--shared-folder parameter is required to create a PAM configuration') self.verify_required(record) pam_configuration_create_record_v6(params, record, shared_folder_uid) + encrypted_session_token, encrypted_transmission_key, transmission_key = get_keeper_tokens(params) + # Add DAG for configuration + tmp_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, record_uid=record.record_uid, + is_config=True) + tmp_dag.edit_tunneling_config( + bool(kwargs.get('enable_connections')), + bool(kwargs.get('enable_tunneling')), + bool(kwargs.get('enable_rotation')), + bool(kwargs.get('recordingenabled')), + bool(kwargs.get('typescriptrecordingenabled')) + ) + tmp_dag.print_tunneling_config(record.record_uid, None) + # Moving v6 record into the folder api.sync_down(params) FolderMoveCommand().execute(params, src=record.record_uid, dst=shared_folder_uid, force=True) @@ -1078,15 +1444,33 @@ def execute(self, params, **kwargs): for w in self.warnings: logging.warning(w) + params.environment_variables[LAST_RECORD_UID] = record.record_uid return record.record_uid class PAMConfigurationEditCommand(Command, PamConfigurationEditMixin): parser = argparse.ArgumentParser(prog='pam config edit', 
parents=[common_parser])
+    parser.add_argument('uid', type=str, action='store', help='The Config UID to edit')
     parser.add_argument('--remove-resource-record', '-rrr', dest='remove_records', action='append',
                         help='Resource Record UID to remove')
-    parser.add_argument('--config', '-c', required=True, dest='config', action='store',
-                        help='PAM Configuration UID or Title')
+    parser.add_argument('--enable-rotation', '-er', required=False, action='store_true', help='Enable rotation')
+    parser.add_argument('--disable-rotation', '-dr', required=False, action='store_true', help='Disable rotation')
+    parser.add_argument('--enable-tunneling', '-et', required=False, dest='enable_tunneling', action='store_true',
+                        help='Enable tunneling')
+    parser.add_argument('--disable-tunneling', '-dt', required=False, dest='disable_tunneling', action='store_true',
+                        help='Disable tunneling')
+    parser.add_argument('--enable-connections', '-ec', required=False, dest='enable_connections', action='store_true',
+                        help='Enable connections')
+    parser.add_argument('--disable-connections', '-dc', required=False, dest='disable_connections', action='store_true',
+                        help='Disable connections')
+    parser.add_argument('--enable-connections-recording', '-ecr', required=False, dest='enable_connections_recording',
+                        action='store_true', help='Enable connections recording')
+    parser.add_argument('--disable-connections-recording', '-dcr', required=False, dest='disable_connections_recording',
+                        action='store_true', help='Disable connections recording')
+    parser.add_argument('--enable-typescripts-recording', '-etsr', required=False, dest='enable_typescripts_recording',
+                        action='store_true', help='Enable typescripts recording')
+    parser.add_argument('--disable-typescripts-recording', '-dtsr', required=False, dest='disable_typescripts_recording',
+                        action='store_true', help='Disable typescripts recording')

     def __init__(self):
         super(PAMConfigurationEditCommand, self).__init__()
@@ -1098,7 +1482,10 @@ def execute(self, params, **kwargs):
         self.warnings.clear()

         configuration = None
-        config_name = kwargs.get('config')
+
+        config_name = kwargs.get('uid')
+        if not config_name:
+            raise CommandError('pam config edit', 'PAM Configuration UID or Title is required')
         if config_name in params.record_cache:
             configuration = vault.KeeperRecord.load(params, config_name)
         else:
@@ -1115,12 +1502,12 @@
         config_type = kwargs.get('config_type')
         if config_type:
             if not config_type:
-                raise CommandError('pam-config-new', '--config-type parameter is required')
+                raise CommandError('pam-config-new', '--environment parameter is required')
             if config_type == 'aws':
                 record_type = 'pamAwsConfiguration'
             elif config_type == 'azure':
                 record_type = 'pamAzureConfiguration'
-            elif config_type == 'network':
+            elif config_type == 'local':
                 record_type = 'pamNetworkConfiguration'
             else:
                 record_type = configuration.record_type
@@ -1163,22 +1550,49 @@
         if shared_folder_uid != orig_shared_folder_uid:
             FolderMoveCommand().execute(params, src=configuration.record_uid, dst=shared_folder_uid)

+        if ((kwargs.get('enable_connections') and kwargs.get('disable_connections')) or
+                (kwargs.get('enable_tunneling') and kwargs.get('disable_tunneling')) or
+                (kwargs.get('enable_rotation') and kwargs.get('disable_rotation')) or
+                (kwargs.get('enable_connections_recording') and kwargs.get('disable_connections_recording')) or
+                (kwargs.get('enable_typescripts_recording') and kwargs.get('disable_typescripts_recording'))):
+            raise 
CommandError('pam-config-edit', 'Cannot enable and disable the same feature at the same time') + + # First check if enabled is true then check if disabled is true. if not then set it to None + _connections = True if kwargs.get('enable_connections') \ + else False if kwargs.get('disable_connections') else None + _tunneling = True if kwargs.get('enable_tunneling') else False if kwargs.get('disable_tunneling') else None + _rotation = True if kwargs.get('enable_rotation') else False if kwargs.get('disable_rotation') else None + _recording = True if kwargs.get('enable_connections_recording') \ + else False if kwargs.get('disable_connections_recording') else None + _typescript_recording = (True if kwargs.get('enable_typescripts_recording') else False if + kwargs.get('disable_typescripts_recording') else None) + + if (_connections is not None or _tunneling is not None or _rotation is not None or _recording is not None or + _typescript_recording is not None): + encrypted_session_token, encrypted_transmission_key, transmission_key = get_keeper_tokens(params) + tmp_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, + configuration.record_uid, is_config=True) + tmp_dag.edit_tunneling_config(_connections, _tunneling, _rotation, _recording, _typescript_recording) + tmp_dag.print_tunneling_config(configuration.record_uid, None) for w in self.warnings: logging.warning(w) + params.sync_data = True class PAMConfigurationRemoveCommand(Command): parser = argparse.ArgumentParser(prog='pam config remove') - parser.add_argument('--config', '-c', required=True, dest='pam_config', action='store', - help='PAM Configuration UID. To view all rotation settings with their UIDs, ' - 'use command `pam config list`') + parser.add_argument('uid', type=str, action='store', + help='PAM Configuration UID. 
To view all rotation settings with their UIDs, use command '
+                             '`pam config list`')

     def get_parser(self):
         return PAMConfigurationRemoveCommand.parser

     def execute(self, params, **kwargs):
-        pam_config_name = kwargs.get('pam_config')
+        pam_config_name = kwargs.get('uid')
+        if not pam_config_name:
+            raise CommandError('pam config remove', 'PAM Configuration UID is required')
         pam_config_uid = None
         for config in vault_extensions.find_records(params, record_version=6):
             if config.record_uid == pam_config_name:
@@ -1188,7 +1602,14 @@ def execute(self, params, **kwargs):
             pass
         if not pam_config_uid:
             raise Exception(f'Configuration "{pam_config_name}" not found')
-
+        pam_config = vault.KeeperRecord.load(params, pam_config_uid)
+        if not pam_config:
+            raise Exception(f'Configuration "{pam_config_uid}" not found')
+        encrypted_session_token, encrypted_transmission_key, transmission_key = get_keeper_tokens(params)
+        tmp_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, pam_config.record_uid,
+                            is_config=True)
+        if tmp_dag.linking_dag.has_graph:
+            tmp_dag.remove_from_dag(pam_config_uid)
         pam_configuration_remove(params, pam_config_uid)
         params.sync_data = True
@@ -1217,20 +1638,31 @@ def execute(self, params, **kwargs):
             print(f"Gateway Name where the rotation will be performed: {bcolors.OKBLUE}{(rri.controllerName if rri.controllerName else '-')}{bcolors.ENDC}")
             print(f"Gateway Uid: {bcolors.OKBLUE}{(utils.base64_url_encode(rri.controllerUid) if rri.controllerUid else '-') } {bcolors.ENDC}")
+
+            def is_resource_ok(resource_id, params, configuration_uid):
+                if resource_id not in params.record_cache:
+                    return False
+
+                configuration = vault.KeeperRecord.load(params, configuration_uid)
+                if not isinstance(configuration, vault.TypedRecord):
+                    return False
+
+                field = configuration.get_typed_field('pamResources')
+                if not (field and isinstance(field.value, list) and len(field.value) == 1):
+                    return False
+
+                rv = field.value[0]
+                if not isinstance(rv, dict):
+                    return False
+
+                resources = rv.get('resourceRef')
+                return isinstance(resources, list) and resource_id in resources
+
             if rri.resourceUid:
                 resource_id = utils.base64_url_encode(rri.resourceUid)
-                resource_ok = False
-                if resource_id in params.record_cache:
-                    configuration = vault.KeeperRecord.load(params, configuration_uid)
-                    if isinstance(configuration, vault.TypedRecord):
-                        field = configuration.get_typed_field('pamResources')
-                        if field and isinstance(field.value, list) and len(field.value) == 1:
-                            rv = field.value[0]
-                            if isinstance(rv, dict):
-                                resources = rv.get('resourceRef')
-                                if isinstance(resources, list):
-                                    resource_ok = resource_id in resources
-                print(f"Admin Resource Uid: {bcolors.OKBLUE if resource_ok else bcolors.FAIL}{resource_id}{bcolors.ENDC}")
+                resource_ok = is_resource_ok(resource_id, params, configuration_uid)
+                print(f"Admin Resource Uid: {bcolors.OKBLUE if resource_ok else bcolors.FAIL}{resource_id}"
+                      f"{bcolors.ENDC}")

         # print(f"Router Cookie: {bcolors.OKBLUE}{(rri.cookie if rri.cookie else '-')}{bcolors.ENDC}")
         # print(f"scriptName: {bcolors.OKGREEN}{rri.scriptName}{bcolors.ENDC}")
@@ -1256,7 +1688,7 @@
 class PAMRouterScriptCommand(GroupCommand):
     def __init__(self):
         super().__init__()
-        self.register_command('list', PAMScriptListCommand(),  'List script fields')
+        self.register_command('list', PAMScriptListCommand(), 'List script fields')
         self.register_command('add', PAMScriptAddCommand(), 'Add script field')
         self.register_command('edit', 
PAMScriptEditCommand(), 'Add, delete, or edit script field') self.register_command('delete', PAMScriptDeleteCommand(), 'Delete script field') @@ -1386,7 +1818,7 @@ def execute(self, params, **kwargs): if not record_name: raise CommandError('rotate script', '"record" argument is required') - script_name = kwargs.get('script') # type: Optional[str] + script_name = kwargs.get('script') # type: Optional[str] if not script_name: raise CommandError('rotate script', '"script" argument is required') @@ -1461,7 +1893,7 @@ def execute(self, params, **kwargs): if not record_name: raise CommandError('rotate script', '"record" argument is required') - script_name = kwargs.get('script') # type: Optional[str] + script_name = kwargs.get('script') # type: Optional[str] if not script_name: raise CommandError('rotate script', '"script" argument is required') @@ -1508,7 +1940,6 @@ def get_parser(self): return PAMGatewayActionJobCancelCommand.parser def execute(self, params, **kwargs): - job_id = kwargs.get('job_id') print(f"Job id to cancel [{job_id}]") @@ -1535,7 +1966,6 @@ def get_parser(self): return PAMGatewayActionJobCommand.parser def execute(self, params, **kwargs): - job_id = kwargs.get('job_id') gateway_uid = kwargs.get('gateway_uid') @@ -1561,6 +1991,7 @@ class PAMGatewayActionRotateCommand(Command): parser = argparse.ArgumentParser(prog='dr-rotate-command') parser.add_argument('--record-uid', '-r', required=True, dest='record_uid', action='store', help='Record UID to rotate') + # parser.add_argument('--config', '-c', required=True, dest='configuration_uid', action='store', # help='Rotation configuration UID') @@ -1593,34 +2024,54 @@ def execute(self, params, **kwargs): # rule_list_json = crypto.decrypt_aes_v2(utils.base64_url_decode(ri_pwd_complexity_encrypted), record.record_key) # complexity = json.loads(rule_list_json.decode()) - ri_rotation_setting_uid = utils.base64_url_encode(ri.configurationUid) # Configuration on the UI is "Rotation Setting" - resource_uid = utils.base64_url_encode(ri.resourceUid) + resource_uid = None + + encrypted_session_token, encrypted_transmission_key, transmission_key = get_keeper_tokens(params) + config_uid = get_config_uid(params, encrypted_session_token, encrypted_transmission_key, record_uid) + if not config_uid: + # Still try it the old way + # Configuration on the UI is "Rotation Setting" + ri_rotation_setting_uid = utils.base64_url_encode(ri.configurationUid) + resource_uid = utils.base64_url_encode(ri.resourceUid) + pam_config = vault.KeeperRecord.load(params, ri_rotation_setting_uid) + if not isinstance(pam_config, vault.TypedRecord): + print(f'{bcolors.FAIL}PAM Configuration [{ri_rotation_setting_uid}] is not available.{bcolors.ENDC}') + return + facade = PamConfigurationRecordFacade() + facade.record = pam_config - pam_config = vault.KeeperRecord.load(params, ri_rotation_setting_uid) - if not isinstance(pam_config, vault.TypedRecord): - print(f'{bcolors.FAIL}PAM Configuration [{ri_rotation_setting_uid}] is not available.{bcolors.ENDC}') - return - facade = PamConfigurationRecordFacade() - facade.record = pam_config + config_uid = facade.controller_uid + + if not resource_uid: + tmp_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, record.record_uid) + resource_uid = tmp_dag.get_resource_uid(record_uid) + if not resource_uid: + print(f'{bcolors.FAIL}Resource UID not found for record [{record_uid}]. 
please configure it ' + f'{bcolors.OKBLUE}"pam rotation user {record_uid} --resource RESOURCE_UID"{bcolors.ENDC}') + return + + controller = configuration_controller_get(params, url_safe_str_to_bytes(config_uid)) + if not controller.controllerUid: + raise CommandError('', f'{bcolors.FAIL}Gateway UID not found for configuration ' + f'{config_uid}.') # Find connected controllers enterprise_controllers_connected = router_get_connected_gateways(params) + controller_from_config_bytes = controller.controllerUid + gateway_uid = utils.base64_url_encode(controller.controllerUid) if enterprise_controllers_connected: - # Find connected controller (TODO: Optimize, don't search for controllers every time, no N^n) - router_controllers = list(enterprise_controllers_connected.controllers) - controller_from_config_bytes = utils.base64_url_decode(facade.controller_uid) - connected_controller = next((x.controllerUid for x in router_controllers - if x.controllerUid == controller_from_config_bytes), None) + router_controllers = {controller.controllerUid: controller for controller in + list(enterprise_controllers_connected.controllers)} + connected_controller = router_controllers.get(controller_from_config_bytes) if not connected_controller: - print(f'{bcolors.WARNING}The Gateway "{facade.controller_uid}" is down.{bcolors.ENDC}') + print(f'{bcolors.WARNING}The Gateway "{gateway_uid}" is down.{bcolors.ENDC}') return else: print(f'{bcolors.WARNING}There are no connected gateways.{bcolors.ENDC}') return - # rrs = RouterRotationStatus.Name(ri.status) # if rrs == 'RRS_NO_ROTATION': # print(f'{bcolors.FAIL}Record [{record_uid}] does not have rotation associated with it.{bcolors.ENDC}') @@ -1642,7 +2093,7 @@ def execute(self, params, **kwargs): action_inputs = GatewayActionRotateInputs( record_uid=record_uid, - configuration_uid=ri_rotation_setting_uid, + configuration_uid=config_uid, pwd_complexity_encrypted=ri_pwd_complexity_encrypted, resource_uid=resource_uid ) @@ -1651,8 +2102,9 @@ def execute(self, params, **kwargs): router_response = router_send_action_to_gateway( params=params, gateway_action=GatewayActionRotate(inputs=action_inputs, conversation_id=conversation_id, - gateway_destination=facade.controller_uid), - message_type=pam_pb2.CMT_ROTATE, is_streaming=False) + gateway_destination=gateway_uid), + message_type=pam_pb2.CMT_ROTATE, is_streaming=False, encrypted_transmission_key=encrypted_transmission_key, + encrypted_session_token=encrypted_session_token) print_router_response(router_response, conversation_id) @@ -1679,33 +2131,81 @@ def execute(self, params, **kwargs): print_router_response(router_response, response_type='gateway_info', is_verbose=is_verbose) -class PAMGatewayActionDiscoverCommand(Command): - parser = argparse.ArgumentParser(prog='dr-discover-command') - parser.add_argument('--shared-folder', '-f', required=True, dest='shared_folder_uid', action='store', - help='UID of the Shared Folder where results will be stored') - parser.add_argument('--provider-record', '-p', required=True, dest='provider_record_uid', action='store', - help='Provider Record UID that defines network') - # parser.add_argument('--destinations', '-d', required=False, dest='destinations', action='store', - # help='Controller id') +class PAMGatewayActionDiscoverCommandBase(Command): - def get_parser(self): - return PAMGatewayActionDiscoverCommand.parser + """ + The discover command base. - def execute(self, params, **kwargs): + Contains static methods to get the configuration record, get and update the discovery store. 
These are methods
+    used by multiple discover actions.
+    """
+
+    # If the discovery data field does not exist, or the field contains no values, use the template to init the
+    # field.
+    STORE_VALUE_TEMPLATE = {
+        "ignore_list": [],
+        "jobs": []
+    }
+
+    STORE_LABEL = "discoveryStore"
+
+    @staticmethod
+    def get_configuration(params, configuration_uid):
+
+        configuration_record = vault.KeeperRecord.load(params, configuration_uid)
+        if not isinstance(configuration_record, vault.TypedRecord):
+            print(f'{bcolors.FAIL}PAM Configuration [{configuration_uid}] is not available.{bcolors.ENDC}')
+            return
+
+        configuration_facade = PamConfigurationRecordFacade()
+        configuration_facade.record = configuration_record
+
+        return configuration_record, configuration_facade
+
+    @staticmethod
+    def get_discovery_store(configuration_record):
+
+        # Get the discovery store. It contains information about discovery jobs for a configuration. It is on the
+        # custom fields.
+        discovery_field = None
+        if configuration_record.custom is not None:
+            discovery_field = next((field
+                                    for field in configuration_record.custom
+                                    if field.label == PAMGatewayActionDiscoverCommandBase.STORE_LABEL),
+                                   None)
+
+        discovery_field_exists = True
+        if discovery_field is None:
+            logging.debug("discovery store field does not exist, creating")
+            discovery_field = TypedField.new_field("_hidden",
+                                                   [PAMGatewayActionDiscoverCommandBase.STORE_VALUE_TEMPLATE],
+                                                   PAMGatewayActionDiscoverCommandBase.STORE_LABEL)
+            discovery_field_exists = False
+        else:
+            logging.debug("discovery store record exists")
+
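# --- Editor's aside (illustrative, not part of this patch) ---------------------
# get_discovery_store and update_discovery_store (below) round-trip the store
# dict through this hidden custom field, so a caller can mutate it and write it
# back. A minimal usage sketch; the 'job_id' key is hypothetical:
#
#     store, field, existed = PAMGatewayActionDiscoverCommandBase.get_discovery_store(configuration_record)
#     store['jobs'].append({'job_id': 'example-job'})
#     PAMGatewayActionDiscoverCommandBase.update_discovery_store(
#         params, configuration_record, store, field, existed)
# -------------------------------------------------------------------------------

+        # The value should not be [], if it is, init with the defaults.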
+ if len(discovery_field.value) == 0: + logging.debug("discovery store does not have a value, set to the default value") + discovery_field.value = [PAMGatewayActionDiscoverCommandBase.STORE_VALUE_TEMPLATE] + + # TODO - REMOVE ME, this is just so we have one job + # discovery_field.value = [PAMGatewayActionDiscoverCommandBase.STORE_VALUE_TEMPLATE] + + return discovery_field.value[0], discovery_field, discovery_field_exists + + @staticmethod + def update_discovery_store(params, configuration_record, discovery_store, discovery_field, discovery_field_exists): + + discovery_field.value = [discovery_store] + if discovery_field_exists is False: + if configuration_record.custom is None: + configuration_record.custom = [] + configuration_record.custom.append(discovery_field) + + # Update the record here to prevent a race-condition + record_management.update_record(params, configuration_record) + params.sync_data = True class PAMGatewayRemoveCommand(Command): @@ -1731,7 +2231,6 @@ def execute(self, params, **kwargs): class PAMCreateGatewayCommand(Command): - dr_create_controller_parser = argparse.ArgumentParser(prog='dr-create-gateway') dr_create_controller_parser.add_argument('--name', '-n', required=True, dest='gateway_name', help='Name of the Gateway', @@ -1746,7 +2245,8 @@ class PAMCreateGatewayCommand(Command): dr_create_controller_parser.add_argument('--return_value', '-r', dest='return_value', action='store_true', help='Return value from the command for automation purposes') dr_create_controller_parser.add_argument('--config-init', '-c', type=str, dest='config_init', action='store', - choices=['json', 'b64'], help='Initialize client config and return configuration string.') # json, b64, file + choices=['json', 'b64'], + help='Initialize client config and return configuration string.') # json, b64, file def get_parser(self): return PAMCreateGatewayCommand.dr_create_controller_parser @@ -1785,16 +2285,7 @@ def execute(self, params, **kwargs): print('-----------------------------------------------') - - - - - - - - - -############################################## TUNNELING ############################################################### +# TUNNELING class PAMTunnelListCommand(Command): pam_cmd_parser = argparse.ArgumentParser(prog='pam tunnel list') @@ -1804,21 +2295,19 @@ def get_parser(self): def execute(self, params, **kwargs): def gather_tabel_row_data(thread): # {"thread": t, "host": host, "port": port, "started": datetime.now(), - row = [] + r_row = [] run_time = None hours = 0 minutes = 0 seconds = 0 entrance = thread.get('entrance') - # - # row.append(f"{thread.get('name', '')}") if entrance is not None: - row.append(f"{bcolors.OKBLUE}{entrance.pc.endpoint_name}{bcolors.ENDC}") + r_row.append(f"{bcolors.OKBLUE}{entrance.pc.endpoint_name}{bcolors.ENDC}") else: - row.append(f"{bcolors.WARNING}Connecting..{bcolors.ENDC}") + r_row.append(f"{bcolors.WARNING}Connecting..{bcolors.ENDC}") - row.append(f"{thread.get('host', '')}") + r_row.append(f"{thread.get('host', '')}") if entrance is not None and entrance.print_ready_event.is_set(): if thread.get('started'): @@ -1826,12 +2315,12 @@ def gather_tabel_row_data(thread): hours, remainder = divmod(run_time.seconds, 3600) minutes, seconds = divmod(remainder, 60) - row.append( - f"{bcolors.OKBLUE}{entrance._port}{bcolors.ENDC}" + r_row.append( + f"{bcolors.OKBLUE}{entrance.port}{bcolors.ENDC}" ) else: - row.append(f"{bcolors.WARNING}Connecting...{bcolors.ENDC}") - row.append(f"{thread.get('record_uid', '')}") + 
r_row.append(f"{bcolors.WARNING}Connecting...{bcolors.ENDC}") + r_row.append(f"{thread.get('record_uid', '')}") if entrance is not None and entrance.print_ready_event.is_set(): text_line = "" if run_time: @@ -1842,10 +2331,10 @@ def gather_tabel_row_data(thread): text_line += f"{hours} hr " if hours > 0 or run_time.days > 0 else '' text_line += f"{minutes} min " text_line += f"{seconds} sec" - row.append(text_line) + r_row.append(text_line) else: - row.append(f"{bcolors.WARNING}Connecting...{bcolors.ENDC}") - return row + r_row.append(f"{bcolors.WARNING}Connecting...{bcolors.ENDC}") + return r_row if not params.tunnel_threads: logging.warning(f"{bcolors.OKBLUE}No Tunnels running{bcolors.ENDC}") @@ -1973,149 +2462,226 @@ class SocketNotConnectedException(Exception): pass -def retrieve_gateway_public_key(gateway_uid, params, api, utils) -> bytes: - gateway_uid_bytes = utils.base64_url_decode(gateway_uid) - get_ksm_pubkeys_rq = GetKsmPublicKeysRequest() - get_ksm_pubkeys_rq.controllerUids.append(gateway_uid_bytes) - get_ksm_pubkeys_rs = api.communicate_rest(params, get_ksm_pubkeys_rq, 'vault/get_ksm_public_keys', - rs_type=GetKsmPublicKeysResponse) - - if len(get_ksm_pubkeys_rs.keyResponses) == 0: - # No keys found - print(f"{bcolors.FAIL}No keys found for gateway {gateway_uid}{bcolors.ENDC}") - return b'' - try: - gateway_public_key_bytes = get_ksm_pubkeys_rs.keyResponses[0].publicKey - except Exception as e: - # No public key found - print(f"{bcolors.FAIL}Error getting public key for gateway {gateway_uid}: {e}{bcolors.ENDC}") - gateway_public_key_bytes = b'' - - return gateway_public_key_bytes - - -class PAMTunnelEnableCommand(Command): - pam_cmd_parser = argparse.ArgumentParser(prog='pam tunnel enable') +class PAMTunnelEditCommand(Command): + pam_cmd_parser = argparse.ArgumentParser(prog='pam tunnel edit') pam_cmd_parser.add_argument('uid', type=str, action='store', help='The Record UID of the PAM ' 'resource record with network information to use ' 'for tunneling') pam_cmd_parser.add_argument('--configuration', '-c', required=False, dest='config', action='store', help='The PAM Configuration UID to use for tunneling. ' 'Use command `pam config list` to view available PAM Configurations.') + # pam_cmd_parser.add_argument('--enable-connections', '-ec', required=False, dest='enable_connections', action='store_true', + # help='Enable connections on the record') + pam_cmd_parser.add_argument('--enable-tunneling', '-et', required=False, dest='enable_tunneling', action='store_true', + help='Enable tunneling on the record') + pam_cmd_parser.add_argument('--tunneling-override-port', '-top', required=False, dest='tunneling_override_port', + action='store', help='Port to use for tunneling. If not provided, ' + 'the port from the record will be used.') + # pam_cmd_parser.add_argument('--connections-override-port', '-cop', required=False, dest='connections_override_port', + # action='store', help='Port to use for connections. 
If not provided, '
+    #                                                  'the port from the record will be used.')
+    # pam_cmd_parser.add_argument('--disable-connections', '-dc', required=False, dest='disable_connections',
+    #                             action='store_true', help='Disable connections on the record')
+    pam_cmd_parser.add_argument('--disable-tunneling', '-dt', required=False, dest='disable_tunneling',
+                                action='store_true', help='Disable tunneling on the record')
+    pam_cmd_parser.add_argument('--remove-tunneling-override-port', '-rtop', required=False,
+                                dest='remove_tunneling_override_port', action='store_true',
+                                help='Remove tunneling override port')
+    # pam_cmd_parser.add_argument('--remove-connections-override-port', '-rcop', required=False,
+    #                             dest='remove_tunneling_override_port', action='store_true',
+    #                             help='Remove connections override port')
+    # pam_cmd_parser.add_argument('--enable-connections-recording', '-ecr', required=False,
+    #                             dest='enable_connections_recording', action='store_true',
+    #                             help='Enable connections recording')
+    # pam_cmd_parser.add_argument('--disable-connections-recording', '-dcr', required=False,
+    #                             dest='disable_connections_recording', action='store_true',
+    #                             help='Disable connections recording')
+    # pam_cmd_parser.add_argument('--enable-typescripts-recording', '-etsr', required=False,
+    #                             dest='enable_typescripts_recording', action='store_true',
+    #                             help='Enable typescripts recording')
+    # pam_cmd_parser.add_argument('--disable-typescripts-recording', '-dtsr', required=False,
+    #                             dest='disable_typescripts_recording', action='store_true',
+    #                             help='Disable typescripts recording')

     def get_parser(self):
-        return PAMTunnelEnableCommand.pam_cmd_parser
+        return PAMTunnelEditCommand.pam_cmd_parser

     def execute(self, params, **kwargs):
         record_uid = kwargs.get('uid')
         config_uid = kwargs.get('config')
+        # connection_override_port = kwargs.get('connections_override_port')
+        tunneling_override_port = kwargs.get('tunneling_override_port')
+
+        if ((kwargs.get('enable_tunneling') and kwargs.get('disable_tunneling')) or
+                (kwargs.get('enable_rotation') and kwargs.get('disable_rotation')) or
+                (kwargs.get('tunneling_override_port') and kwargs.get('remove_tunneling_override_port'))):
+            raise CommandError('tunnel edit', 'Cannot enable and disable the same feature at the same time')
+
+        # if ((kwargs.get('enable_connections') and kwargs.get('disable_connections')) or
+        #         (kwargs.get('enable_tunneling') and kwargs.get('disable_tunneling')) or
+        #         (kwargs.get('enable_rotation') and kwargs.get('disable_rotation')) or
+        #         (kwargs.get('enable_connections_recording') and kwargs.get('disable_connections_recording')) or
+        #         (kwargs.get('enable_typescripts_recording') and kwargs.get('disable_typescripts_recording')) or
+        #         (kwargs.get('tunneling-override-port') and kwargs.get('remove_tunneling_override_port')) or
+        #         (kwargs.get('connections-override-port') and kwargs.get('remove_connections_override_port'))):
+        #     raise CommandError('pam-config-edit', 'Cannot enable and disable the same feature at the same time')
+
+        # First check if enabled is true, then check if disabled is true.
if not then set it to None + # _connections = True if kwargs.get('enable_connections') else False if kwargs.get('disable_connections') else None + _tunneling = True if kwargs.get('enable_tunneling') else False if kwargs.get('disable_tunneling') else None + # _rotation = True if kwargs.get('enable_rotation') else False if kwargs.get('disable_rotation') else None + # _recording = True if kwargs.get('enable_connections_recording') else False if kwargs.get('disable_connections_recording') else None + # _typescript_recording = (True if kwargs.get('enable_typescripts_recording') else False if + # kwargs.get('disable_typescripts_recording') else None) + _remove_tunneling_override_port = kwargs.get('remove_tunneling_override_port') + # _remove_connections_override_port = kwargs.get('remove_connections_override_port') + if not record_uid: - raise CommandError('tunnel Enable', '"record UID" argument is required') - dirty = False + raise CommandError('tunnel edit', '"record UID" argument is required') + + if tunneling_override_port: + try: + tunneling_override_port = int(tunneling_override_port) + except ValueError: + raise CommandError('tunnel edit', 'tunneling-override-port must be an integer') + # if connection_override_port: + # try: + # connection_override_port = int(connection_override_port) + # except ValueError: + # raise CommandError('tunnel edit', 'connection-override-port must be an integer') record = vault.KeeperRecord.load(params, record_uid) if not isinstance(record, vault.TypedRecord): - print(f"{bcolors.FAIL}Record {record_uid} not found.{bcolors.ENDC}") - return + raise CommandError('', f"{bcolors.FAIL}Record {record_uid} not found.{bcolors.ENDC}") record_type = record.record_type - if record_type not in "pamMachine pamDatabase pamDirectory".split(): - print(f"{bcolors.FAIL}This record's type is not supported for tunnels. " - f"Tunnels are only supported on Pam Machine, Pam Database, and Pam Directory records{bcolors.ENDC}") - return - - if config_uid: - configuration = vault.KeeperRecord.load(params, config_uid) - if not isinstance(configuration, vault.TypedRecord): - print(f"{bcolors.FAIL}Configuration {config_uid} not found.{bcolors.ENDC}") - return - if (configuration.record_type not in - 'pamNetworkConfiguration pamAwsConfiguration pamAzureConfiguration'.split()): - print(f"{bcolors.FAIL}The record {config_uid} is not a Pam Configuration.{bcolors.ENDC}") - return - - pam_settings = record.get_typed_field('pamSettings') - if not pam_settings: - pre_settings = {"portForward": {"enabled": True}} - if config_uid: - pre_settings["configUid"] = config_uid - pam_settings = vault.TypedField.new_field('pamSettings', pre_settings, "") - record.custom.append(pam_settings) - dirty = True + if record_type not in ("pamMachine pamDatabase pamDirectory pamNetworkConfiguration pamAwsConfiguration " + "pamRemoteBrowser pamAzureConfiguration").split(): + raise CommandError('', f"{bcolors.FAIL}This record's type is not supported for tunnels. 
" + f"Tunnels are only supported on pamMachine, pamDatabase, pamDirectory, " + f"pamRemoteBrowser, pamNetworkConfiguration pamAwsConfiguration, and " + f"pamAzureConfiguration records{bcolors.ENDC}") + + encrypted_session_token, encrypted_transmission_key, transmission_key = get_keeper_tokens(params) + if record_type in "pamNetworkConfiguration pamAwsConfiguration pamAzureConfiguration".split(): + tmp_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, record_uid, is_config=True) + tmp_dag.edit_tunneling_config(tunneling=_tunneling) + tmp_dag.print_tunneling_config(record_uid, None) else: - if config_uid: - if pam_settings.value[0].get('configUid') != config_uid: - pam_settings.value[0]['configUid'] = config_uid - dirty = True - if not pam_settings.value[0]['portForward']['enabled']: - pam_settings.value[0]['portForward']['enabled'] = True + traffic_encryption_key = record.get_typed_field('trafficEncryptionSeed') + # Generate a 256-bit (32-byte) random seed + seed = os.urandom(32) + dirty = False + if not traffic_encryption_key.value: + base64_seed = bytes_to_base64(seed) + record_seed = vault.TypedField.new_field('trafficEncryptionSeed', base64_seed, "") + record.custom.append(record_seed) dirty = True - if not pam_settings.value[0].get('configUid'): - print(f"{bcolors.FAIL}No PAM Configuration UID set. This must be set for tunneling to work. " - f"This can be done by running 'pam tunnel enable {record_uid} --config [ConfigUID]' " - f"The ConfigUID can be found by running 'pam config list'{bcolors.ENDC}") - return - - client_private_key = record.get_typed_field('trafficEncryptionKey') - if not client_private_key: - # Generate an EC private key - # TODO: maybe try to use keeper method to generate key - # private_key, _ = crypto.generate_ec_key() - # client_private_key_value = crypto.unload_ec_private_key(private_key).decode('utf-8') - private_key = ec.generate_private_key( - ec.SECP256R1(), # Using P-256 curve - backend=default_backend() - ) - # Serialize to PEM format - client_private_key_value = private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption() - ).decode('utf-8') - client_private_key = vault.TypedField.new_field('trafficEncryptionKey', - client_private_key_value, "") - record.custom.append(client_private_key) - dirty = True - - if dirty: - record_management.update_record(params, record) - api.sync_down(params) - if pam_settings.value[0].get('configUid'): - print(f"{bcolors.OKGREEN}Tunneling enabled for {record_uid} using configuration " - f"{pam_settings.value[0].get('configUid')} {bcolors.ENDC}") - else: - print(f"{bcolors.OKGREEN}Tunneling enabled for {record_uid}{bcolors.ENDC}") - - -class PAMTunnelDisableCommand(Command): - pam_cmd_parser = argparse.ArgumentParser(prog='pam tunnel disable') - pam_cmd_parser.add_argument('uid', type=str, action='store', help='The Record UID of the PAM ' - 'resource record with network information to use ' - 'for tunneling') - - def get_parser(self): - return PAMTunnelDisableCommand.pam_cmd_parser + if dirty: + record_management.update_record(params, record) + api.sync_down(params) - def execute(self, params, **kwargs): + traffic_encryption_key = record.get_typed_field('trafficEncryptionSeed') + if not traffic_encryption_key: + raise CommandError('', f"{bcolors.FAIL}Unable to add Seed to record {record_uid}. 
" + f"Please make sure you have edit rights to record {record_uid} {bcolors.ENDC}") + dirty = False + + existing_config_uid = get_config_uid(params, encrypted_session_token, encrypted_transmission_key, record_uid) + + tmp_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, config_uid) + old_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, existing_config_uid) + + if config_uid and existing_config_uid != config_uid: + old_dag.remove_from_dag(record_uid) + tmp_dag.link_resource_to_config(record_uid) + + if tmp_dag is None or not tmp_dag.linking_dag.has_graph: + raise CommandError('', f"{bcolors.FAIL}No PAM Configuration UID set. " + f"This must be set or supplied for tunneling to work. This can be done by adding " + f"{bcolors.OKBLUE}' --config [ConfigUID] " + f" {bcolors.FAIL}The ConfigUID can be found by running " + f"{bcolors.OKBLUE}'pam config list'{bcolors.ENDC}") + + if not tmp_dag.check_tunneling_enabled_config(enable_tunneling=_tunneling): + tmp_dag.print_tunneling_config(config_uid, None) + command = f"{bcolors.OKBLUE}'pam tunnel edit {config_uid}" + # if _connections and not tmp_dag.check_tunneling_enabled_config(enable_connections=_connections): + # command += f" --enable-connections" if _connections else "" + if _tunneling and not tmp_dag.check_tunneling_enabled_config( + enable_tunneling=_tunneling): + command += f" --enable-tunneling" if _tunneling else "" + # if _recording and not tmp_dag.check_tunneling_enabled_config( + # enable_session_recording=_recording): + # command += f" --enable-connections-recording" if _recording else "" + # if _typescript_recording and not tmp_dag.check_tunneling_enabled_config( + # enable_typescript_recording=_typescript_recording): + # command += f" --enable-typescripts-recording" if _typescript_recording else "" + + print(f"{bcolors.FAIL}The settings are denied by PAM Configuration: {config_uid}. 
" + f"Please enable settings for the configuration by running\n" + f"{command}'{bcolors.ENDC}") + return - record_uid = kwargs.get('uid') - if not record_uid: - raise CommandError('tunnel Disable', '"record" argument is required') + if not tmp_dag.is_tunneling_config_set_up(record_uid): + tmp_dag.link_resource_to_config(record_uid) + + pam_settings = record.get_typed_field('pamSettings') + if not pam_settings: + pre_settings = {} + if _tunneling and tunneling_override_port: + pre_settings["portForward"]["port"] = tunneling_override_port + # if _connections and connection_override_port: + # pre_settings["connection"]["port"] = connection_override_port + if pre_settings: + pam_settings = vault.TypedField.new_field('pamSettings', pre_settings, "") + # TODO follow template + record.custom.append(pam_settings) + dirty = True + else: + if not tmp_dag.is_tunneling_config_set_up(record_uid): + tmp_dag.link_resource_to_config(record_uid) + if not pam_settings.value: + pam_settings.value.append({"connection": {}, "portForward": {}}) + # if _connections and connection_override_port: + # pam_settings.value[0]['connection']['port'] = connection_override_port + # dirty = True + if _tunneling and tunneling_override_port: + pam_settings.value[0]['portForward']['port'] = tunneling_override_port + dirty = True - record = vault.KeeperRecord.load(params, record_uid) + if _remove_tunneling_override_port and pam_settings.value[0]['portForward'].get('port'): + pam_settings.value[0]['portForward'].pop('port') + dirty = True + # if _remove_connections_override_port and pam_settings.value[0]['connection'].get('port'): + # pam_settings.value[0]['connection'].pop('port') + # dirty = True + if not tmp_dag.is_tunneling_config_set_up(record_uid): + print(f"{bcolors.FAIL}No PAM Configuration UID set. This must be set for tunneling to work. 
" + f"This can be done by running " + f"{bcolors.OKBLUE}'pam tunnel edit {record_uid} --config [ConfigUID] --enable-tunneling' " + f"{bcolors.FAIL}The ConfigUID can be found by running " + f"{bcolors.OKBLUE}'pam config list'{bcolors.ENDC}") + return + allowed_settings_name = "allowedSettings" + if record.record_type == "pamRemoteBrowser": + allowed_settings_name = "pamRemoteBrowserSettings" + + # if _recording is not None and tmp_dag.check_if_resource_allowed(record_uid, "sessionRecording") != _recording: + # dirty = True + # if _typescript_recording is not None and tmp_dag.check_if_resource_allowed(record_uid, "typescriptRecording") != _typescript_recording: + # dirty = True + # if _connections is not None and tmp_dag.check_if_resource_allowed(record_uid, "connections") != _connections: + # dirty = True + if _tunneling is not None and tmp_dag.check_if_resource_allowed(record_uid, "portForwards") != _tunneling: + dirty = True - if not isinstance(record, vault.TypedRecord): - print(f"{bcolors.FAIL}Record {record_uid} not found.{bcolors.ENDC}") - return + if dirty: + tmp_dag.set_resource_allowed(resource_uid=record_uid, tunneling=_tunneling, allowed_settings_name=allowed_settings_name) - pam_settings = record.get_typed_field('pamSettings') - if pam_settings: - if pam_settings.value[0]['portForward']['enabled']: - pam_settings.value[0]['portForward']['enabled'] = False - record_management.update_record(params, record) - api.sync_down(params) - print(f"{bcolors.OKGREEN}Tunneling disabled for {record_uid}{bcolors.ENDC}") + # Print out the tunnel settings + tmp_dag.print_tunneling_config(record_uid, record.get_typed_field('pamSettings'), config_uid) class PAMTunnelStartCommand(Command): @@ -2162,60 +2728,64 @@ def setup_logging(self, convo_id, log_queue, logging_level): logger.debug("Logging setup complete.") return logger - async def connect(self, params, record_uid, convo_num, gateway_uid, host, port, - log_queue, gateway_public_key_bytes, client_private_key): + async def connect(self, params, record_uid, convo_num, host, port, + log_queue, seed, target_host, target_port, socks): # Setup custom logging to put logs into log_queue logger = self.setup_logging(str(convo_num), log_queue, logging.getLogger().getEffectiveLevel()) - print(f"{bcolors.HIGHINTENSITYWHITE}Establishing tunnel between Commander and Gateway. Please wait...{bcolors.ENDC}") - # get the keys - gateway_public_key = ec.EllipticCurvePublicKey.from_encoded_point(ec.SECP256R1(), gateway_public_key_bytes) + print(f"{bcolors.HIGHINTENSITYWHITE}Establishing tunnel between Commander and Gateway. Please wait..." 
+ f"{bcolors.ENDC}") + # Symmetric key """ -# Generate an EC private key -private_key = ec.generate_private_key( - ec.SECP256R1(), # Using P-256 curve - backend=default_backend() -) -# Serialize to PEM format -private_key_str = private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption() -).decode('utf-8') + Generate a 256-bit (32-byte) random seed + seed = os.urandom(32) """ - - client_private_key_pem = serialization.load_pem_private_key( - client_private_key.encode(), - password=None, + if isinstance(seed, str): + seed = base64_to_bytes(seed) + # Generate a 128-bit (16-byte) random nonce + nonce = os.urandom(MAIN_NONCE_LENGTH) + # Derive the encryption key using HKDF + hkdf = HKDF( + algorithm=hashes.SHA256(), + length=SYMMETRIC_KEY_LENGTH, # 256-bit key + salt=nonce, + info=b"KEEPER_TUNNEL_ENCRYPT_AES_GCM_128", backend=default_backend() - ) - - # Get symmetric key - symmetric_key = establish_symmetric_key(client_private_key_pem, gateway_public_key) + ).derive(seed) + symmetric_key = AESGCM(hkdf) # Set up the pc print_ready_event = asyncio.Event() kill_server_event = asyncio.Event() - pc = WebRTCConnection(params=params, record_uid=record_uid, gateway_uid=gateway_uid, - symmetric_key=symmetric_key, print_ready_event=print_ready_event, - kill_server_event=kill_server_event, logger=logger, server=params.server) + pc = WebRTCConnection(params=params, record_uid=record_uid, symmetric_key=symmetric_key, + print_ready_event=print_ready_event, kill_server_event=kill_server_event, + logger=logger, server=params.server) try: - await pc.signal_channel('start') + await pc.signal_channel('start', bytes_to_base64(nonce)) except Exception as e: raise CommandError('Tunnel Start', f"{e}") logger.debug("starting private tunnel") - private_tunnel = TunnelEntrance(host=host, port=port, pc=pc, print_ready_event=print_ready_event, logger=logger, - connect_task=params.tunnel_threads[convo_num].get("connect_task", None), - kill_server_event=kill_server_event) + if socks: + private_tunnel = SOCKS5Server(host=host, port=port, pc=pc, print_ready_event=print_ready_event, + logger=logger, + connect_task=params.tunnel_threads[convo_num].get("connect_task", None), + kill_server_event=kill_server_event, target_host=target_host, + target_port=target_port) + else: + private_tunnel = TunnelEntrance(host=host, port=port, pc=pc, print_ready_event=print_ready_event, + logger=logger, + connect_task=params.tunnel_threads[convo_num].get("connect_task", None), + kill_server_event=kill_server_event, target_host=target_host, + target_port=target_port) t1 = asyncio.create_task(private_tunnel.start_server()) params.tunnel_threads[convo_num].update({"server": t1, "entrance": private_tunnel, - "kill_server_event": kill_server_event}) + "kill_server_event": kill_server_event}) logger.debug("--> START LISTENING FOR MESSAGES FROM GATEWAY --------") try: @@ -2225,10 +2795,11 @@ async def connect(self, params, record_uid, convo_num, gateway_uid, host, port, finally: logger.debug("--> STOP LISTENING FOR MESSAGES FROM GATEWAY --------") - def pre_connect(self, params, record_uid, convo_num, gateway_uid, host, port, - gateway_public_key_bytes, client_private_key): + def pre_connect(self, params, record_uid, convo_num, host, port, + seed, target_host, target_port, socks): tunnel_name = f"{convo_num}" - def custom_exception_handler(loop, context): + + def custom_exception_handler(_loop, context): # Check if the exception is present in the 
context if "exception" in context: exception = context["exception"] @@ -2254,12 +2825,13 @@ def custom_exception_handler(loop, context): params=params, record_uid=record_uid, convo_num=convo_num, - gateway_uid=gateway_uid, host=host, port=port, log_queue=output_queue, - gateway_public_key_bytes=gateway_public_key_bytes, - client_private_key=client_private_key + seed=seed, + target_host=target_host, + target_port=target_port, + socks=socks ) ) params.tunnel_threads[convo_num].update({"connect_task": connect_task}) @@ -2306,7 +2878,9 @@ def custom_exception_handler(loop, context): except Exception as e: logging.debug(f"{bcolors.WARNING}Exception while stopping event loop: {e}{bcolors.ENDC}") except Exception as e: - print(f"{bcolors.FAIL}An exception occurred in pre_connect for connection {tunnel_name}: {e}{bcolors.ENDC}") + print( + f"{bcolors.FAIL}An exception occurred in pre_connect for connection {tunnel_name}: {e}" + f"{bcolors.ENDC}") finally: clean_up_tunnel(params, convo_num) print(f"{bcolors.OKBLUE}Tunnel {tunnel_name} closed.{bcolors.ENDC}") @@ -2339,11 +2913,13 @@ def execute(self, params, **kwargs): port = find_open_port(tried_ports=[], preferred_port=port, host=host) except CommandError as e: print(f"{bcolors.FAIL}{e}{bcolors.ENDC}") + del params.tunnel_threads[convo_num] return else: port = find_open_port(tried_ports=[], host=host) if port is None: print(f"{bcolors.FAIL}Could not find open port to use for tunnel{bcolors.ENDC}") + del params.tunnel_threads[convo_num] return api.sync_down(params) @@ -2354,64 +2930,44 @@ def execute(self, params, **kwargs): pam_settings = record.get_typed_field('pamSettings') if not pam_settings: - print(f"{bcolors.FAIL}PAM Settings not enabled for record {record_uid}'.{bcolors.ENDC}") - print(f"{bcolors.WARNING}This is done by running 'pam tunnel enable {record_uid} " - f"--config [ConfigUID]' The ConfigUID can be found by running 'pam config list'{bcolors.ENDC}.") + print(f"{bcolors.FAIL}PAM Settings not configured for record {record_uid}'.{bcolors.ENDC}") + print(f"{bcolors.WARNING}This is done by running {bcolors.OKBLUE}'pam tunnel edit {record_uid} " + f"--enable-tunneling --config [ConfigUID]'" + f"{bcolors.WARNING} The ConfigUID can be found by running" + f"{bcolors.OKBLUE} 'pam config list'{bcolors.ENDC}.") return - try: - pam_info = pam_settings.value[0] - enabled_port_forward = pam_info.get("portForward", {}).get("enabled", False) - if not enabled_port_forward: - print(f"{bcolors.FAIL}PAM Settings not enabled for record {record_uid}. 
" - f"{bcolors.WARNING}This is done by running 'pam tunnel enable {record_uid}'.{bcolors.ENDC}") - return - except Exception as e: - print(f"{bcolors.FAIL}Error parsing PAM Settings for record {record_uid}: {e}{bcolors.ENDC}") + # SOCKS5 Proxy uses this to determine what connection to use for the tunnel + target = record.get_typed_field('pamHostname') + if not target: + print(f"{bcolors.FAIL}Hostname not found for record {record_uid}.{bcolors.ENDC}") return - - client_private_key = record.get_typed_field('trafficEncryptionKey') - if not client_private_key: - print(f"{bcolors.FAIL}Traffic Encryption Key not found for record {record_uid}.{bcolors.ENDC}") + target_host = target.get_default_value().get('hostName', None) + target_port = target.get_default_value().get('port', None) + if not target_host: + print(f"{bcolors.FAIL}Host not found for record {record_uid}.{bcolors.ENDC}") return - - client_private_key_value = client_private_key.get_default_value(str) - - configuration_uid = pam_info.get("configUid", None) - if not configuration_uid: - print(f"{bcolors.FAIL}Configuration UID not found for record {record_uid}.{bcolors.ENDC}") - return - configuration = vault.KeeperRecord.load(params, configuration_uid) - if not isinstance(configuration, vault.TypedRecord): - print(f"{bcolors.FAIL}Configuration {configuration_uid} not found.{bcolors.ENDC}") + if not target_port: + print(f"{bcolors.FAIL}Port not found for record {record_uid}.{bcolors.ENDC}") return - pam_resources = configuration.get_typed_field('pamResources') - if not pam_resources: - print(f"{bcolors.FAIL}PAM Resources not found for configuration {configuration_uid}.{bcolors.ENDC}") - return - if len(pam_resources.value) == 0: - print(f"{bcolors.FAIL}PAM Resources not found for configuration {configuration_uid}.{bcolors.ENDC}") - return - gateway_uid = '' - try: - gateway_uid = pam_resources.value[0].get("controllerUid", '') - except Exception as e: - print(f"{bcolors.FAIL}Error parsing PAM Resources for configuration {configuration_uid}: {e}{bcolors.ENDC}") - CommandError('Tunnel Start', f"{e}") - - if not gateway_uid: - print(f"{bcolors.FAIL}Gateway UID not found for configuration {configuration_uid}.{bcolors.ENDC}") - return + # IP or a CIDR subnet. 
+ allowed_hosts = record.get_typed_field('multiline', 'Allowed Hosts') - gateway_public_key_bytes = retrieve_gateway_public_key(gateway_uid, params, api, utils) + allowed_ports = record.get_typed_field('multiline', 'Allowed Ports') + socks = False + if allowed_hosts or allowed_ports: + socks = True - if not gateway_public_key_bytes: - print(f"{bcolors.FAIL}Could not retrieve public key for gateway {gateway_uid}{bcolors.ENDC}") + client_private_seed = record.get_typed_field('trafficEncryptionSeed') + if not client_private_seed: + print(f"{bcolors.FAIL}Traffic Encryption Seed not found for record {record_uid}.{bcolors.ENDC}") return + base64_seed = client_private_seed.get_default_value(str).encode('utf-8') + seed = base64_to_bytes(base64_seed) - t = threading.Thread(target=self.pre_connect, args=(params, record_uid, convo_num, gateway_uid, host, port, - gateway_public_key_bytes, client_private_key_value) + t = threading.Thread(target=self.pre_connect, args=(params, record_uid, convo_num, host, port, + seed, target_host, target_port, socks) ) # Setting the thread as a daemon thread @@ -2442,7 +2998,9 @@ def execute(self, params, **kwargs): def print_fail(con_num): con_name = '' - con_entrance = params.tunnel_threads[con_num].get("entrance", None) + con_entrance = None + if con_num in params.tunnel_threads: + con_entrance = params.tunnel_threads[con_num].get("entrance", None) fail_dynamic_length = len("| Endpoint ") + len(" failed to start..") if con_entrance: con_name = con_entrance.pc.endpoint_name diff --git a/keepercommander/commands/pam/config_helper.py b/keepercommander/commands/pam/config_helper.py index 4a888c45b..d2aded194 100644 --- a/keepercommander/commands/pam/config_helper.py +++ b/keepercommander/commands/pam/config_helper.py @@ -200,3 +200,16 @@ def record_rotation_get(params, record_uid_bytes): # type: (KeeperParams, bytes return rotation_info_rs +def configuration_controller_get(params, config_uid_bytes: bytes): + """ + Get the Controller UID that has access to the configuration UID + Retrieves a keeper.pam_controller record, from given configuration_uid provided in request. 
+    The returned PAMController identifies the Gateway controller that has access to the given configuration.
+    """
+    rq = pam_pb2.PAMGenericUidRequest()
+    rq.uid = config_uid_bytes
+
+    config_info_rs = api.communicate_rest(params, rq, 'pam/get_configuration_controller',
+                                          rs_type=pam_pb2.PAMController)
+
+    return config_info_rs
diff --git a/keepercommander/commands/pam/pam_dto.py b/keepercommander/commands/pam/pam_dto.py
index 1814f4e3e..892b034dc 100644
--- a/keepercommander/commands/pam/pam_dto.py
+++ b/keepercommander/commands/pam/pam_dto.py
@@ -17,21 +17,63 @@ def toJSON(self):
         return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4)
 
 
-# ACTION INPUTS
+# ACTION DISCOVER INPUTS
+
+class GatewayActionDiscoverJobStartInputs:
+
+    def __init__(self, configuration_uid, user_map, shared_folder_uid, resource_uid=None, language="en",
+                 # Settings
+                 include_machine_dir_users=False,
+                 include_azure_aadds=False,
+                 skip_rules=False,
+                 skip_machines=False,
+                 skip_databases=False,
+                 skip_directories=False,
+                 skip_cloud_users=False,
+                 credentials=None):
+        self.configurationUid = configuration_uid
+        self.resourceUid = resource_uid
+        self.userMap = user_map
+        self.sharedFolderUid = shared_folder_uid
+        self.language = language
+        self.includeMachineDirUsers = include_machine_dir_users
+        self.includeAzureAadds = include_azure_aadds
+        self.skipRules = skip_rules
+        self.skipMachines = skip_machines
+        self.skipDatabases = skip_databases
+        self.skipDirectories = skip_directories
+        self.skipCloudUsers = skip_cloud_users
+
+        if credentials is None:
+            credentials = []
+        self.credentials = credentials
 
+    def toJSON(self):
+        return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4)
 
-class GatewayActionDiscoverInputs:
 
-    def __init__(self, shared_folder_uid, provider_record_uid):
-        self.shared_folder_uid = shared_folder_uid
-        self.provider_record_uid = provider_record_uid
+class GatewayActionDiscoverJobRemoveInputs:
+
+    def __init__(self, configuration_uid, job_id):
+        self.configurationUid = configuration_uid
+        self.jobId = job_id
 
     def toJSON(self):
         return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4)
 
 
+class GatewayActionDiscoverRuleValidateInputs:
+
+    def __init__(self, configuration_uid, statement):
+        self.configurationUid = configuration_uid
+        self.statement = statement
+
+    def toJSON(self):
+        return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4)
+
 # ACTIONS
 
+
 class GatewayAction(metaclass=abc.ABCMeta):
 
     def __init__(self, action, is_scheduled, gateway_destination=None, inputs=None, conversation_id=None):
@@ -67,10 +109,29 @@ def toJSON(self):
         return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4)
 
 
-class GatewayActionDiscover(GatewayAction):
+class GatewayActionDiscoverJobStart(GatewayAction):
 
-    def __init__(self, inputs: GatewayActionDiscoverInputs, conversation_id=None):
-        super().__init__('discover', inputs=inputs, conversation_id=conversation_id, is_scheduled=True)
+    def __init__(self, inputs: GatewayActionDiscoverJobStartInputs, conversation_id=None):
+        super().__init__('discover-job-start', inputs=inputs, conversation_id=conversation_id, is_scheduled=True)
+
+    def toJSON(self):
+        return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4)
+
+
+class GatewayActionDiscoverJobRemove(GatewayAction):
+
+    def __init__(self, inputs: GatewayActionDiscoverJobRemoveInputs, conversation_id=None):
+        super().__init__('discover-job-remove', inputs=inputs,
conversation_id=conversation_id, is_scheduled=True) + + def toJSON(self): + return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) + + +class GatewayActionDiscoverRuleValidate(GatewayAction): + + def __init__(self, inputs: GatewayActionDiscoverRuleValidateInputs, conversation_id=None): + super().__init__('discover-rule-validate', inputs=inputs, conversation_id=conversation_id, + is_scheduled=True) def toJSON(self): return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) @@ -118,8 +179,8 @@ def toJSON(self): class GatewayActionRotate(GatewayAction): def __init__(self, inputs: GatewayActionRotateInputs, conversation_id=None, gateway_destination=None): - super().__init__('rotate', inputs=inputs, conversation_id=conversation_id, gateway_destination=gateway_destination, - is_scheduled=True) + super().__init__('rotate', inputs=inputs, conversation_id=conversation_id, + gateway_destination=gateway_destination, is_scheduled=True) def toJSON(self): return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) diff --git a/keepercommander/commands/pam/router_helper.py b/keepercommander/commands/pam/router_helper.py index da17f9d03..f49fda7c5 100644 --- a/keepercommander/commands/pam/router_helper.py +++ b/keepercommander/commands/pam/router_helper.py @@ -18,7 +18,7 @@ from ...params import KeeperParams from ...proto import pam_pb2, router_pb2, connect_pb2 -VERIFY_SSL = True +VERIFY_SSL = bool(os.environ.get("VERIFY_SSL", "TRUE") == "TRUE") def get_router_url(params: KeeperParams): @@ -29,6 +29,9 @@ def get_router_url(params: KeeperParams): else: base_server_url = params.rest_context.server_base base_server = urlparse(base_server_url).netloc + str_base_server = base_server + if isinstance(base_server, bytes): + base_server = base_server.decode('utf-8') if base_server.lower().startswith('govcloud.'): base_server = base_server[len('govcloud.'):] @@ -73,8 +76,11 @@ def router_get_connected_gateways(params): # type: (KeeperParams) -> pam_pb2.PA # return None -def router_set_record_rotation_information(params, proto_request): - rs = _post_request_to_router(params, 'set_record_rotation', proto_request) +def router_set_record_rotation_information(params, proto_request, transmission_key=None, + encrypted_transmission_key=None, encrypted_session_token=None): + rs = _post_request_to_router(params, 'set_record_rotation', proto_request, transmission_key=transmission_key, + encrypted_transmission_key=encrypted_transmission_key, + encrypted_session_token=encrypted_session_token) return rs @@ -90,17 +96,21 @@ def router_get_relay_access_creds(params, expire_sec=None): return _post_request_to_router(params, 'relay_access_creds', query_params=query_params, rs_type=connect_pb2.RelayAccessCreds) -def _post_request_to_router(params, path, rq_proto=None, rs_type=None, method='post', raw_without_status_check_response=False, query_params=None): +def _post_request_to_router(params, path, rq_proto=None, rs_type=None, method='post', + raw_without_status_check_response=False, query_params=None, transmission_key=None, + encrypted_transmission_key=None, encrypted_session_token=None): krouter_host = get_router_url(params) path = '/api/user/' + path - transmission_key = utils.generate_aes_key() - server_public_key = rest_api.SERVER_PUBLIC_KEYS[params.rest_context.server_key_id] + if not transmission_key: + transmission_key = utils.generate_aes_key() + if not encrypted_transmission_key: + server_public_key = 
rest_api.SERVER_PUBLIC_KEYS[params.rest_context.server_key_id] - if params.rest_context.server_key_id < 7: - encrypted_transmission_key = crypto.encrypt_rsa(transmission_key, server_public_key) - else: - encrypted_transmission_key = crypto.encrypt_ec(transmission_key, server_public_key) + if params.rest_context.server_key_id < 7: + encrypted_transmission_key = crypto.encrypt_rsa(transmission_key, server_public_key) + else: + encrypted_transmission_key = crypto.encrypt_ec(transmission_key, server_public_key) encrypted_payload = b'' @@ -110,7 +120,8 @@ def _post_request_to_router(params, path, rq_proto=None, rs_type=None, method='p logging.debug('>>> [GW RQ] %s: %s', path, js) encrypted_payload = crypto.encrypt_aes_v2(rq_proto.SerializeToString(), transmission_key) - encrypted_session_token = crypto.encrypt_aes_v2(utils.base64_url_decode(params.session_token), transmission_key) + if not encrypted_session_token: + encrypted_session_token = crypto.encrypt_aes_v2(utils.base64_url_decode(params.session_token), transmission_key) try: rs = requests.request(method, @@ -206,7 +217,9 @@ def request_cookie_jar_to_str(cookie_jar): return ';'.join(found) -def router_send_action_to_gateway(params, gateway_action: GatewayAction, message_type, is_streaming, destination_gateway_uid_str=None, gateway_timeout=15000): +def router_send_action_to_gateway(params, gateway_action: GatewayAction, message_type, is_streaming, + destination_gateway_uid_str=None, gateway_timeout=15000, transmission_key=None, + encrypted_transmission_key=None, encrypted_session_token=None): # Default time out how long the response from the Gateway should be krouter_host = get_router_url(params) @@ -251,24 +264,26 @@ def router_send_action_to_gateway(params, gateway_action: GatewayAction, message destination_gateway_uid_bytes = gateway_helper.find_connected_gateways(router_enterprise_controllers_connected, gateway_action.gateway_destination) destination_gateway_uid_str = utils.base64_url_encode(destination_gateway_uid_bytes) - msg_id = gateway_action.conversationId if gateway_action.conversationId else GatewayAction.generate_conversation_id() - msg_id_bytes = msg_id.encode('utf-8') + msg_id = gateway_action.conversationId if gateway_action.conversationId else GatewayAction.generate_conversation_id('true') rq = router_pb2.RouterControllerMessage() - rq.messageUid = msg_id_bytes + rq.messageUid = utils.base64_url_decode(msg_id) if isinstance(msg_id, str) else msg_id rq.controllerUid = destination_gateway_uid_bytes rq.messageType = message_type rq.streamResponse = is_streaming rq.payload = gateway_action.toJSON().encode('utf-8') rq.timeout = gateway_timeout - transmission_key = utils.generate_aes_key() + if not transmission_key: + transmission_key = utils.generate_aes_key() response = router_send_message_to_gateway( params=params, transmission_key=transmission_key, rq_proto=rq, - destination_gateway_uid_str=destination_gateway_uid_str) + destination_gateway_uid_str=destination_gateway_uid_str, + encrypted_transmission_key=encrypted_transmission_key, + encrypted_session_token=encrypted_session_token) rs_body = response.content @@ -317,22 +332,25 @@ def router_send_action_to_gateway(params, gateway_action: GatewayAction, message def router_send_message_to_gateway(params, transmission_key, rq_proto, destination_gateway_uid_str, - destination_gateway_cookies=None): + destination_gateway_cookies=None, encrypted_transmission_key=None, + encrypted_session_token=None): krouter_host = get_router_url(params) - server_public_key = 
rest_api.SERVER_PUBLIC_KEYS[params.rest_context.server_key_id]
+    if not encrypted_transmission_key:
+        server_public_key = rest_api.SERVER_PUBLIC_KEYS[params.rest_context.server_key_id]
 
-    if params.rest_context.server_key_id < 7:
-        encrypted_transmission_key = crypto.encrypt_rsa(transmission_key, server_public_key)
-    else:
-        encrypted_transmission_key = crypto.encrypt_ec(transmission_key, server_public_key)
+        if params.rest_context.server_key_id < 7:
+            encrypted_transmission_key = crypto.encrypt_rsa(transmission_key, server_public_key)
+        else:
+            encrypted_transmission_key = crypto.encrypt_ec(transmission_key, server_public_key)
 
     encrypted_payload = b''
 
     if rq_proto:
         encrypted_payload = crypto.encrypt_aes_v2(rq_proto.SerializeToString(), transmission_key)
-    encrypted_session_token = crypto.encrypt_aes_v2(utils.base64_url_decode(params.session_token), transmission_key)
+    if not encrypted_session_token:
+        encrypted_session_token = crypto.encrypt_aes_v2(utils.base64_url_decode(params.session_token), transmission_key)
 
     if not destination_gateway_cookies:
         destination_gateway_cookies = get_controller_cookie(params, destination_gateway_uid_str)
@@ -359,6 +377,55 @@ def router_send_message_to_gateway(params, transmission_key, rq_proto, destinati
     return rs
 
 
+def get_response_payload(router_response):
+
+    router_response_response = router_response.get('response')
+    router_response_response_payload_str = router_response_response.get('payload')
+    router_response_response_payload_dict = json.loads(router_response_response_payload_str)
+
+    return router_response_response_payload_dict
+
+
+def get_dag_leafs(params, encrypted_session_token, encrypted_transmission_key, record_id):
+    """
+    POST a stringified JSON object to /api/user/get_leafs on the KRouter
+    The object is:
+    {
+        vertex: string,
+        graphId: number
+    }
+    """
+    krouter_host = get_router_url(params)
+    path = '/api/user/get_leafs'
+
+    payload = {
+        'vertex': record_id,
+        'graphId': 0
+    }
+
+    try:
+        rs = requests.request('post',
+                              krouter_host + path,
+                              verify=VERIFY_SSL,
+                              headers={
+                                  'TransmissionKey': bytes_to_base64(encrypted_transmission_key),
+                                  'Authorization': f'KeeperUser {bytes_to_base64(encrypted_session_token)}'
+                              },
+                              data=json.dumps(payload).encode('utf-8')
+                              )
+    except ConnectionError as e:
+        raise KeeperApiError(-1, f"KRouter is not reachable on '{krouter_host}'. 
Error: {e}")
+    except Exception as ex:
+        raise ex
+
+    if rs.status_code == 200:
+        logging.debug("Found right host")
+        return rs.json()
+    else:
+        logging.warning("Looks like there is no such controller connected to the router.")
+        return None
+
+
 def print_router_response(router_response, response_type, original_conversation_id=None, is_verbose=False):
     if not router_response:
         return
diff --git a/keepercommander/commands/pam/user_facade.py b/keepercommander/commands/pam/user_facade.py
new file mode 100644
index 000000000..8fd80036e
--- /dev/null
+++ b/keepercommander/commands/pam/user_facade.py
@@ -0,0 +1,91 @@
+from ...record_facades import TypedRecordFacade, string_getter, string_setter, boolean_getter, boolean_setter
+from ...vault import TypedField
+from typing import Optional
+
+
+class PamUserRecordFacade(TypedRecordFacade):
+    _login_getter = string_getter('_login')
+    _login_setter = string_setter('_login')
+    _password_getter = string_getter('_password')
+    _password_setter = string_setter('_password')
+    _distinguishedName_getter = string_getter('_distinguishedName')
+    _distinguishedName_setter = string_setter('_distinguishedName')
+    _connectDatabase_getter = string_getter('_connectDatabase')
+    _connectDatabase_setter = string_setter('_connectDatabase')
+    _managed_getter = boolean_getter('_managed')
+    _managed_setter = boolean_setter('_managed')
+    _oneTimeCode_getter = string_getter('_oneTimeCode')
+    _oneTimeCode_setter = string_setter('_oneTimeCode')
+
+    def __init__(self):
+        super(PamUserRecordFacade, self).__init__()
+        self._login = None  # type: Optional[TypedField]
+        self._password = None  # type: Optional[TypedField]
+        self._distinguishedName = None  # type: Optional[TypedField]
+        self._connectDatabase = None  # type: Optional[TypedField]
+        self._managed = None  # type: Optional[TypedField]
+        self._oneTimeCode = None  # type: Optional[TypedField]
+
+    @property
+    def login(self):
+        return PamUserRecordFacade._login_getter(self)
+
+    @login.setter
+    def login(self, value):
+        PamUserRecordFacade._login_setter(self, value)
+
+    @property
+    def password(self):
+        return PamUserRecordFacade._password_getter(self)
+
+    @password.setter
+    def password(self, value):
+        PamUserRecordFacade._password_setter(self, value)
+
+    @property
+    def distinguishedName(self):
+        return PamUserRecordFacade._distinguishedName_getter(self)
+
+    @distinguishedName.setter
+    def distinguishedName(self, value):
+        PamUserRecordFacade._distinguishedName_setter(self, value)
+
+    @property
+    def connectDatabase(self):
+        return PamUserRecordFacade._connectDatabase_getter(self)
+
+    @connectDatabase.setter
+    def connectDatabase(self, value):
+        PamUserRecordFacade._connectDatabase_setter(self, value)
+
+    @property
+    def managed(self):
+        return PamUserRecordFacade._managed_getter(self)
+
+    @managed.setter
+    def managed(self, value):
+        PamUserRecordFacade._managed_setter(self, value)
+
+    @property
+    def oneTimeCode(self):
+        return PamUserRecordFacade._oneTimeCode_getter(self)
+
+    @oneTimeCode.setter
+    def oneTimeCode(self, value):
+        PamUserRecordFacade._oneTimeCode_setter(self, value)
+
+    def load_typed_fields(self):
+        if self.record:
+            self.record.type_name = 'pamUser'
+            for attr in ["login", "password", "distinguishedName", "connectDatabase", "managed", "oneTimeCode"]:
+                attr_prv = f"_{attr}"
+                value = next((x for x in self.record.fields if x.type == attr), None)
+                setattr(self, attr_prv, value)
+                if value is None:
+                    value = TypedField.new_field(attr, '')
+                    setattr(self, attr_prv, value)
+                    self.record.fields.append(value)
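+        # Note (comment added for clarity, not in the original patch): with no
+        # record attached, the branch below drops any cached TypedField references
+        # so the facade does not serve stale fields.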
+ else: + for attr in ["_login", "_password", "_distinguishedName", "_connectDatabase", "_managed", "_oneTimeCode"]: + setattr(self, attr, None) + super(PamUserRecordFacade, self).load_typed_fields() diff --git a/keepercommander/commands/pam_debug/__init__.py b/keepercommander/commands/pam_debug/__init__.py new file mode 100644 index 000000000..0f413c313 --- /dev/null +++ b/keepercommander/commands/pam_debug/__init__.py @@ -0,0 +1,17 @@ +from __future__ import annotations +from ...utils import value_to_boolean +import os +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ...params import KeeperParams + from keeper_dag.connection import ConnectionBase + + +def get_connection(params: KeeperParams) -> ConnectionBase: + if value_to_boolean(os.environ.get("USE_LOCAL_DAG", False)) is False: + from keeper_dag.connection.commander import Connection as CommanderConnection + return CommanderConnection(params=params) + else: + from keeper_dag.connection.local import Connection as LocalConnection + return LocalConnection() \ No newline at end of file diff --git a/keepercommander/commands/pam_debug/acl.py b/keepercommander/commands/pam_debug/acl.py new file mode 100644 index 000000000..a13335a8d --- /dev/null +++ b/keepercommander/commands/pam_debug/acl.py @@ -0,0 +1,144 @@ +from __future__ import annotations +import argparse +import logging +from ..discover import PAMGatewayActionDiscoverCommandBase, GatewayContext, PAM_USER +from ...display import bcolors +from ... import vault +from discovery_common.record_link import RecordLink +from discovery_common.types import UserAcl +from typing import Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from ...vault import TypedRecord + from ...params import KeeperParams + + +class PAMDebugACLCommand(PAMGatewayActionDiscoverCommandBase): + parser = argparse.ArgumentParser(prog='dr-pam-command-debug') + + # The record to base everything on. + parser.add_argument('--gateway', '-g', required=True, dest='gateway', action='store', + help='Gateway name or UID.') + + parser.add_argument('--user-uid', '-u', required=True, dest='user_uid', action='store', + help='User UID.') + parser.add_argument('--parent-uid', '-r', required=True, dest='parent_uid', action='store', + help='Resource or Configuration UID.') + parser.add_argument('--debug-gs-level', required=False, dest='debug_level', action='store', + help='GraphSync debug level. 
Default is 0', type=int, default=0)
+
+    def get_parser(self):
+        return PAMDebugACLCommand.parser
+
+    def execute(self, params: KeeperParams, **kwargs):
+
+        gateway = kwargs.get("gateway")
+        user_uid = kwargs.get("user_uid")
+        parent_uid = kwargs.get("parent_uid")
+        debug_level = int(kwargs.get("debug_level", 0))
+
+        print("")
+
+        gateway_context = GatewayContext.from_gateway(params, gateway)
+        if gateway_context is None:
+            print(f"{bcolors.FAIL}Could not find the gateway configuration for {gateway}.{bcolors.ENDC}")
+            return
+
+        record_link = RecordLink(record=gateway_context.configuration, params=params, logger=logging,
+                                 debug_level=debug_level)
+
+        user_record = vault.KeeperRecord.load(params, user_uid)  # type: Optional[TypedRecord]
+        if user_record is None:
+            print(f"{bcolors.FAIL}The user record does not exist.{bcolors.ENDC}")
+            return
+
+        print(f"{bcolors.BOLD}The user record is {user_record.title}{bcolors.ENDC}")
+
+        if user_record.record_type != PAM_USER:
+            print(f"{bcolors.FAIL}The user record is not a PAM User record.{bcolors.ENDC}")
+            return
+
+        parent_record = vault.KeeperRecord.load(params, parent_uid)  # type: Optional[TypedRecord]
+        if parent_record is None:
+            print(f"{bcolors.FAIL}The parent record does not exist.{bcolors.ENDC}")
+            return
+
+        print(f"{bcolors.BOLD}The parent record is {parent_record.title}{bcolors.ENDC}")
+
+        if parent_record.record_type.startswith("pam") is False:
+            print(f"{bcolors.FAIL}The parent record is not a PAM record.{bcolors.ENDC}")
+            return
+
+        if parent_record.record_type == PAM_USER:
+            print(f"{bcolors.FAIL}The parent record cannot be a PAM User record.{bcolors.ENDC}")
+            return
+
+        parent_is_config = parent_record.record_type.endswith("Configuration")
+
+        # Get the ACL between the user and the parent.
+        # It might not exist.
+        acl_exists = True
+        acl = record_link.get_acl(user_uid, parent_uid)
+        if acl is None:
+            print("No existing ACL, creating an ACL.")
+            acl = UserAcl()
+            acl_exists = False
+
+        # Make sure the ACL for a cloud user is set.
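+        # (Clarifying comment, not in the original patch: a PAM User whose direct
+        # parent is a PAM configuration is treated as a cloud/IAM user rather than
+        # a directory-managed user, which is why is_iam_user is forced on below.)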
+        if parent_is_config is True:
+            print("Is an IAM user.")
+            acl.is_iam_user = True
+
+        rl_parent_vertex = record_link.dag.get_vertex(parent_uid)
+        if rl_parent_vertex is None:
+            print("Parent record linking vertex does not exist, creating one.")
+            rl_parent_vertex = record_link.dag.add_vertex(parent_uid)
+
+        rl_user_vertex = record_link.dag.get_vertex(user_uid)
+        if rl_user_vertex is None:
+            print("User record linking vertex does not exist, creating one.")
+            rl_user_vertex = record_link.dag.add_vertex(user_uid)
+
+        has_admin_uid = record_link.get_admin_record_uid(parent_uid)
+        if has_admin_uid is not None:
+            print("Parent record already has an admin.")
+        else:
+            print("Parent record does not have an admin.")
+
+        belongs_to_vertex = record_link.acl_has_belong_to_record_uid(user_uid)
+        if belongs_to_vertex is None:
+            print("User record does not belong to any resource or provider.")
+        else:
+            if belongs_to_vertex.active is False:
+                print("User record belongs to an inactive parent.")
+            else:
+                print("User record belongs to another record.")
+
+        print("")
+
+        while True:
+            res = input(f"Does this user belong to {parent_record.title} Y/N >").lower()
+            if res == "y":
+                acl.belongs_to = True
+                break
+            elif res == "n":
+                acl.belongs_to = False
+                break
+
+        if has_admin_uid is None:
+            while True:
+                res = input(f"Is this user the admin of {parent_record.title} Y/N >").lower()
+                if res == "y":
+                    acl.is_admin = True
+                    break
+                elif res == "n":
+                    acl.is_admin = False
+                    break
+
+        try:
+            record_link.belongs_to(user_uid, parent_uid, acl=acl)
+            record_link.save()
+            print(f"{bcolors.OKGREEN}Updated/added ACL between {user_record.title} and "
+                  f"{parent_record.title}{bcolors.ENDC}")
+        except Exception as err:
+            print(f"{bcolors.FAIL}Could not update ACL: {err}{bcolors.ENDC}")
diff --git a/keepercommander/commands/pam_debug/alter.py b/keepercommander/commands/pam_debug/alter.py
new file mode 100644
index 000000000..d4e718e90
--- /dev/null
+++ b/keepercommander/commands/pam_debug/alter.py
@@ -0,0 +1,30 @@
+from __future__ import annotations
+import argparse
+import os
+from ..discover import PAMGatewayActionDiscoverCommandBase
+from ...display import bcolors
+from ... import vault
+from discovery_common.infrastructure import Infrastructure
+from discovery_common.record_link import RecordLink
+from discovery_common.types import UserAcl, DiscoveryObject
+from keeper_dag import EdgeType
+from importlib.metadata import version
+from typing import Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from ...vault import TypedRecord
+    from ...params import KeeperParams
+
+
+class PAMDebugAlterCommand(PAMGatewayActionDiscoverCommandBase):
+    parser = argparse.ArgumentParser(prog='dr-pam-command-debug')
+
+    # The record to base everything on.
+ parser.add_argument('--gateway', '-g', required=False, dest='gateway', action='store', + help='Gateway name or UID.') + + def get_parser(self): + return PAMDebugAlterCommand.parser + + def execute(self, params: KeeperParams, **kwargs): + pass \ No newline at end of file diff --git a/keepercommander/commands/pam_debug/gateway.py b/keepercommander/commands/pam_debug/gateway.py new file mode 100644 index 000000000..df4e78cd3 --- /dev/null +++ b/keepercommander/commands/pam_debug/gateway.py @@ -0,0 +1,92 @@ +from __future__ import annotations +import argparse +from ..discover import PAMGatewayActionDiscoverCommandBase, GatewayContext +from .graph import PAMDebugGraphCommand +from ...display import bcolors +from discovery_common.infrastructure import Infrastructure +from discovery_common.record_link import RecordLink +from discovery_common.user_service import UserService +from discovery_common.constants import PAM_USER, PAM_MACHINE, PAM_DATABASE, PAM_DIRECTORY +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ...params import KeeperParams + + +class PAMDebugGatewayCommand(PAMGatewayActionDiscoverCommandBase): + parser = argparse.ArgumentParser(prog='dr-pam-command-debug') + + type_name_map = { + PAM_USER: "PAM User", + PAM_MACHINE: "PAM Machine", + PAM_DATABASE: "PAM Database", + PAM_DIRECTORY: "PAM Directory", + } + + # The record to base everything on. + parser.add_argument('--gateway', '-g', required=True, dest='gateway', action='store', + help='Gateway name or UID') + + def get_parser(self): + return PAMDebugGatewayCommand.parser + + def execute(self, params: KeeperParams, **kwargs): + + gateway = kwargs.get("gateway") + debug_level = kwargs.get("debug_level", False) + + gateway_context = GatewayContext.from_gateway(params, gateway) + if gateway_context is None: + print(f"{bcolors.FAIL}Could not find the gateway configuration for {gateway}.") + return + + infra = Infrastructure(record=gateway_context.configuration, params=params, fail_on_corrupt=False) + infra.load() + + record_link = RecordLink(record=gateway_context.configuration, params=params, fail_on_corrupt=False) + user_service = UserService(record=gateway_context.configuration, params=params, fail_on_corrupt=False) + + if gateway_context is None: + print(f" {self._f('Cannot get gateway information. 
Gateway may not be up.')}")
+            return
+
+        print("")
+        print(self._h("Gateway Information"))
+        print(f"  {self._b('Gateway UID')}: {gateway_context.gateway_uid}")
+        print(f"  {self._b('Gateway Name')}: {gateway_context.gateway_name}")
+        if gateway_context.configuration is not None:
+            print(f"  {self._b('Configuration UID')}: {gateway_context.configuration_uid}")
+            print(f"  {self._b('Configuration Title')}: {gateway_context.configuration.title}")
+            print(f"  {self._b('Configuration Key Bytes Hex')}: {gateway_context.configuration.record_key.hex()}")
+        else:
+            print(f"  {self._f('The gateway appears to not have a configuration.')}")
+        print("")
+
+        graph = PAMDebugGraphCommand()
+
+        if infra.dag.has_graph is True:
+            print(self._h("Infrastructure Graph"))
+            graph.do_list(params=params, gateway_context=gateway_context, graph_type="infra", debug_level=debug_level,
+                          indent=1)
+        else:
+            print(f"{self._f('The gateway configuration does not have an infrastructure graph.')}")
+
+        print("")
+
+        if record_link.dag.has_graph is True:
+            print(self._h("Record Linking Graph"))
+            graph.do_list(params=params, gateway_context=gateway_context, graph_type="rl", debug_level=debug_level,
+                          indent=1)
+        else:
+            print(f"{self._f('The gateway configuration does not have a record linking graph.')}")
+
+        print("")
+
+        if user_service.dag.has_graph is True:
+            print(self._h("User to Service/Task Graph"))
+            graph.do_list(params=params, gateway_context=gateway_context, graph_type="service", debug_level=debug_level,
+                          indent=1)
+        else:
+            print(f"{self._f('The gateway configuration does not have a user to service/task graph.')}")
+
+        print("")
diff --git a/keepercommander/commands/pam_debug/graph.py b/keepercommander/commands/pam_debug/graph.py
new file mode 100644
index 000000000..f4b05c64d
--- /dev/null
+++ b/keepercommander/commands/pam_debug/graph.py
@@ -0,0 +1,647 @@
+from __future__ import annotations
+from . import get_connection
+import argparse
+import logging
+from ..discover import PAMGatewayActionDiscoverCommandBase, GatewayContext
+from ...display import bcolors
+from ... import vault
+from discovery_common.infrastructure import Infrastructure
+from discovery_common.record_link import RecordLink
+from discovery_common.user_service import UserService
+from discovery_common.jobs import Jobs
+from discovery_common.constants import (PAM_USER, PAM_DIRECTORY, PAM_MACHINE, PAM_DATABASE, VERTICES_SORT_MAP,
+                                        DIS_INFRA_GRAPH_ID, RECORD_LINK_GRAPH_ID, USER_SERVICE_GRAPH_ID,
+                                        DIS_JOBS_GRAPH_ID)
+from discovery_common.types import (DiscoveryObject, DiscoveryUser, DiscoveryDirectory, DiscoveryMachine,
+                                    DiscoveryDatabase, JobContent)
+from discovery_common.dag_sort import sort_infra_vertices
+from keeper_dag import DAG
+from keeper_dag.connection.commander import Connection as CommanderConnection
+from keeper_dag.connection.local import Connection as LocalConnection
+from keeper_dag.vertex import DAGVertex
+from keeper_dag.edge import DAGEdge
+from typing import Optional, Union, TYPE_CHECKING
+
+Connection = Union[CommanderConnection, LocalConnection]
+if TYPE_CHECKING:
+    from ...vault import TypedRecord
+    from ...params import KeeperParams
+
+
+class PAMDebugGraphCommand(PAMGatewayActionDiscoverCommandBase):
+    parser = argparse.ArgumentParser(prog='dr-pam-command-debug')
+
+    # The record to base everything on.
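+    # Illustrative usage (assuming the command is invoked under the parser's prog
+    # name; flags match the arguments defined below):
+    #     dr-pam-command-debug --gateway <name-or-uid> --type infra --list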
+ parser.add_argument('--gateway', '-g', required=True, dest='gateway', action='store', + help='Gateway name or UID.') + parser.add_argument('--type', '-t', required=True, choices=['infra', 'rl', 'service', 'jobs'], + dest='graph_type', action='store', help='Graph type', default='infra') + parser.add_argument('--raw', required=False, dest='raw', action='store_true', + help='Render raw graph. Will render corrupt graphs.') + + parser.add_argument('--list', required=False, dest='do_text_list', action='store_true', + help='List items in a list.') + + parser.add_argument('--render', required=False, dest='do_render', action='store_true', + help='Render a graph') + parser.add_argument('--file', '-f', required=False, dest='filepath', action='store', + default="keeper_graph", help='Base name for the graph file.') + parser.add_argument('--format', required=False, choices=['raw', 'dot', 'twopi', 'patchwork'], + dest='format', default="dot", action='store', help='The format of the graph.') + parser.add_argument('--debug-gs-level', required=False, dest='debug_level', action='store', + help='GraphSync debug level. Default is 0', type=int, default=0) + + mapping = { + PAM_USER: {"order": 1, "sort": "_sort_name", "item": DiscoveryUser, "key": "user"}, + PAM_DIRECTORY: {"order": 1, "sort": "_sort_name", "item": DiscoveryDirectory, "key": "host_port"}, + PAM_MACHINE: {"order": 2, "sort": "_sort_host", "item": DiscoveryMachine, "key": "host"}, + PAM_DATABASE: {"order": 3, "sort": "_sort_host", "item": DiscoveryDatabase, "key": "host_port"}, + } + + graph_id_map = { + "infra": DIS_INFRA_GRAPH_ID, + "rl": RECORD_LINK_GRAPH_ID, + "service": USER_SERVICE_GRAPH_ID, + "jobs": DIS_JOBS_GRAPH_ID + } + + def get_parser(self): + return PAMDebugGraphCommand.parser + + def _do_text_list_infra(self, params: KeeperParams, gateway_context: GatewayContext, debug_level: int = 0, + indent: int = 0): + + infra = Infrastructure(record=gateway_context.configuration, params=params, logger=logging, + debug_level=debug_level) + infra.load(sync_point=0) + + try: + configuration = infra.get_root.has_vertices()[0] + except (Exception,): + print(f"{bcolors.FAIL}Could not find the configuration in the infrastructure graph. " + f"Has discovery been run for this gateway?{bcolors.ENDC}") + + return + + line_start = { + 0: "", + 1: "* ", + 2: "- ", + } + + color_func = { + 0: self._h, + 1: self._gr, + 2: self._p, + 3: self._b + } + + def _handle(current_vertex: DAGVertex, indent: int = 0, last_record_type: Optional[str] = None): + + if current_vertex.active is False: + return + + pad = "" + if indent > 0: + pad = "".ljust(4 * indent, ' ') + + text = "" + ls = line_start.get(indent, " ") + cf = color_func.get(indent, self._p) + + if current_vertex.active is False: + text += f"{pad}{current_vertex.uid} " + self._f("(Inactive)") + elif current_vertex.corrupt is False: + current_content = DiscoveryObject.get_discovery_object(current_vertex) + if current_content.record_uid is None: + text += f"{pad}{ls}{current_vertex.uid}; {current_content.title} does not have a record." 
+                else:
+                    record = vault.KeeperRecord.load(params, current_content.record_uid)  # type: Optional[TypedRecord]
+                    if record is not None:
+                        text += f"{pad}{ls}" + cf(f"{current_vertex.uid}; {record.title}; {record.record_uid}")
+                    else:
+                        text += f"{pad}{ls}" + cf(f"{current_vertex.uid}; {current_content.title}; " +
+                                                  self._f("has a record UID, but the record does not exist; "
+                                                          "a sync may be needed."))
+            else:
+                text += f"{pad}{current_vertex.uid} " + self._f("(Corrupt)")
+
+            print(text)
+
+            record_type_to_vertices_map = sort_infra_vertices(current_vertex)
+
+            # Sort the record types by their 'order' value (an int) in VERTICES_SORT_MAP
+            # and process them in ascending order.
+            for record_type in sorted(record_type_to_vertices_map, key=lambda i: VERTICES_SORT_MAP[i]['order']):
+                for vertex in record_type_to_vertices_map[record_type]:
+                    if last_record_type is None or last_record_type != record_type:
+                        if indent == 0:
+                            print(f"{pad}  {self._b(self._n(record_type))}")
+                        last_record_type = record_type
+
+                    _handle(vertex, indent=indent+1)
+
+        print("")
+        _handle(configuration, indent=indent)
+        print("")
+
+    def _do_text_list_rl(self, params: KeeperParams, gateway_context: GatewayContext, debug_level: int = 0,
+                         indent: int = 0):
+
+        print("")
+
+        pad = ""
+        if indent > 0:
+            pad = "".ljust(4 * indent, ' ')
+
+        record_link = RecordLink(record=gateway_context.configuration, params=params, logger=logging,
+                                 debug_level=debug_level)
+        configuration = record_link.dag.get_root
+
+        record = vault.KeeperRecord.load(params, configuration.uid)  # type: Optional[TypedRecord]
+        if record is None:
+            print(self._f("Configuration record does not exist."))
+            return
+
+        print(self._h(f"{pad}{record.record_type}, {record.title}, {record.record_uid}"))
+
+        if configuration.has_data is True:
+            try:
+                data = configuration.content_as_dict
+                print(f"{pad}  . data")
+                for k, v in data.items():
+                    print(f"{pad}    + {k} = {v}")
+            except (Exception,):
+                print(f"{pad}  ! data not JSON")
+
+        def _group(configuration_vertex: DAGVertex) -> dict:
+
+            group = {
+                PAM_USER: [],
+                PAM_DIRECTORY: [],
+                PAM_MACHINE: [],
+                PAM_DATABASE: [],
+                "NO_RECORD": []
+            }
+
+            for vertex in configuration_vertex.has_vertices():
+                record = vault.KeeperRecord.load(params, vertex.uid)  # type: Optional[TypedRecord]
+                if record is None:
+                    group["NO_RECORD"].append({
+                        "v": vertex
+                    })
+                    continue
+                group[record.record_type].append({
+                    "v": vertex,
+                    "r": record
+                })
+
+            return group
+
+        group = _group(configuration)
+
+        for record_type in [PAM_USER, PAM_DIRECTORY, PAM_MACHINE, PAM_DATABASE]:
+            if len(group[record_type]) > 0:
+                print(f"{pad}  " + self._b(self._n(record_type)))
+                for item in group[record_type]:
+                    vertex = item.get("v")  # type: DAGVertex
+                    record = item.get("r")  # type: TypedRecord
+                    text = self._gr(f"{record.title}; {record.record_uid}")
+                    if vertex.active is False:
+                        text += " " + self._f("Inactive")
+                    print(f"{pad}   * {text}")
+
+                    # These are cloud users
+                    if record_type == PAM_USER:
+                        acl = record_link.get_acl(vertex.uid, configuration.uid)
+                        if acl is None:
+                            print(f"{pad}     {self._f('missing ACL')}")
+                        else:
+                            if acl.is_admin is True:
+                                print(f"{pad}     . is the {self._b('Admin')}")
+                            if acl.belongs_to is True:
+                                print(f"{pad}     . belongs to this resource")
+                            else:
+                                print(f"{pad}     . looks like directory user")
+                        continue
+
+                    if vertex.has_data is True:
+                        try:
+                            data = vertex.content_as_dict
+                            print(f"{pad}    . data")
+                            for k, v in data.items():
+                                print(f"{pad}      + {k} = {v}")
+                        except (Exception,):
+                            print(f"{pad}    ! data not JSON")
+
+                    children = vertex.has_vertices()
+                    if len(children) > 0:
+                        bad = []
+                        for child in children:
+                            child_record = vault.KeeperRecord.load(params, child.uid)  # type: Optional[TypedRecord]
+                            if child_record is None:
+                                if child.active is True:
+                                    bad.append(self._f(f"- Record UID {child.uid} does not exist."))
+                                continue
+                            else:
+                                print(f"{pad}    - {child_record.title}; {child_record.record_uid}")
+                                acl = record_link.get_acl(child.uid, vertex.uid)
+                                if acl is None:
+                                    print(f"{pad}      {self._f('missing ACL')}")
+                                else:
+                                    if acl.is_admin is True:
+                                        print(f"{pad}      . is the {self._b('Admin')}")
+                                    if acl.belongs_to is True:
+                                        print(f"{pad}      . belongs to this resource")
+                                    else:
+                                        print(f"{pad}      . looks like directory user")
+
+                                if child.has_data is True:
+                                    try:
+                                        data = child.content_as_dict
+                                        print(f"{pad}      . data")
+                                        for k, v in data.items():
+                                            print(f"{pad}        + {k} = {v}")
+                                    except (Exception,):
+                                        print(f"{pad}      ! data not JSON")
+                        for i in bad:
+                            print(f"{pad}    " + i)
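+
+    # A note on the ACL flags printed above (semantics assumed from
+    # discovery_common.types.UserAcl): belongs_to means the user account lives
+    # on the resource it links to; a user without belongs_to is treated as a
+    # directory user; is_admin marks the credential used for administration.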
data not JSON") + + children = vertex.has_vertices() + if len(children) > 0: + bad = [] + for child in children: + child_record = vault.KeeperRecord.load(params, child.uid) # type: Optional[TypedRecord] + if child_record is None: + if child.active is True: + bad.append(self._f(f"- Record UID {child.uid} does not exists.")) + continue + else: + print(f"{pad} - {child_record.title}; {child_record.record_uid}") + acl = record_link.get_acl(child.uid, vertex.uid) + if acl is None: + print(f"{pad} {self._f('missing ACL')}") + else: + if acl.is_admin is True: + print(f"{pad} . is the {self._b('Admin')}") + if acl.belongs_to is True: + print(f"{pad} . belongs to this resource") + else: + print(f"{pad} . looks like directory user") + + if child.has_data is True: + try: + data = child.content_as_dict + print(f"{pad} . data") + for k, v in data.items(): + print(f"{pad} + {k} = {v}") + except (Exception,): + print(f"{pad} ! data not JSON") + for i in bad: + print("{pad} " + i) + + def _do_text_list_service(self, params: KeeperParams, gateway_context: GatewayContext, debug_level: int = 0, + indent: int = 0): + + user_service = UserService(record=gateway_context.configuration, params=params, logger=logging, + debug_level=debug_level) + configuration = user_service.dag.get_root + + def _handle(current_vertex: DAGVertex, parent_vertex: Optional[DAGVertex] = None, indent: int = 0): + + pad = "" + if indent > 0: + pad = "".ljust(2 * indent, ' ') + "* " + + record = vault.KeeperRecord.load(params, current_vertex.uid) # type: Optional[TypedRecord] + if record is None: + if current_vertex.active is False: + print(f"{pad}Record {current_vertex.uid} does not exists, inactive in the graph.") + else: + print(f"{pad}Record {current_vertex.uid} does not exists, active in the graph.") + return + elif current_vertex.active is False: + print(f"{pad}{record.record_type}, {record.title}, {record.record_uid} exists, " + "inactive in the graph.") + return + + acl_text = "" + acl = user_service.get_acl(parent_vertex, current_vertex) + if acl is not None: + acl_text = self._f("None") + acl_parts = [] + if acl.is_service is True: + acl_parts.append(self._bl("Service")) + if acl.is_task is True: + acl_parts.append(self._bl("Task")) + if len(acl_parts) > 0: + acl_text = ", ".join(acl_parts) + acl_text = f"- {acl_text}" + + print(f"{pad}{record.record_type}, {record.title}, {record.record_uid}{acl_text}") + + for vertex in current_vertex.has_vertices(): + _handle(current_vertex=vertex, parent_vertex=current_vertex, indent=indent+1) + + _handle(current_vertex=configuration, parent_vertex=None, indent=indent) + + def _do_text_list_jobs(self, params: KeeperParams, gateway_context: GatewayContext, debug_level: int = 0, + indent: int = 0): + + infra = Infrastructure(record=gateway_context.configuration, params=params, logger=logging, + debug_level=debug_level, fail_on_corrupt=False) + infra.load(sync_point=0) + + pad = "" + if indent > 0: + pad = "".ljust(2 * indent, ' ') + "* " + + conn = get_connection(params) + graph_sync = DAG(conn=conn, record=gateway_context.configuration, logger=logging, debug_level=debug_level, + graph_id=DIS_JOBS_GRAPH_ID) + graph_sync.load(0) + configuration = graph_sync.get_root + vertices = configuration.has_vertices() + if len(vertices) == 0: + print(self._f(f"The jobs graph has not been initialized. 
+    def _do_text_list_jobs(self, params: KeeperParams, gateway_context: GatewayContext, debug_level: int = 0,
+                           indent: int = 0):
+
+        infra = Infrastructure(record=gateway_context.configuration, params=params, logger=logging,
+                               debug_level=debug_level, fail_on_corrupt=False)
+        infra.load(sync_point=0)
+
+        pad = ""
+        if indent > 0:
+            pad = "".ljust(2 * indent, ' ') + "* "
+
+        conn = get_connection(params)
+        graph_sync = DAG(conn=conn, record=gateway_context.configuration, logger=logging, debug_level=debug_level,
+                         graph_id=DIS_JOBS_GRAPH_ID)
+        graph_sync.load(0)
+        configuration = graph_sync.get_root
+        vertices = configuration.has_vertices()
+        if len(vertices) == 0:
+            print(self._f("The jobs graph has not been initialized; it only has a root vertex."))
+            return
+
+        vertex = vertices[0]
+        if vertex.has_data is False:
+            print(self._f("The job vertex does not contain any data."))
+            return
+
+        current_json = vertex.content_as_str
+        if current_json is None:
+            print(self._f("The current job vertex content is None."))
+            return
+
+        content = JobContent.model_validate_json(current_json)
+        print(f"{pad}{self._b('Active Job ID')}: {content.active_job_id}")
+        print("")
+        print(f"{pad}{self._h('History')}")
+        print("")
+        for job in content.job_history:
+            print(f"{pad} --------------------------------------")
+            print(f"{pad} Job Id: {job.job_id}")
+            print(f"{pad} Started: {job.start_ts_str}")
+            print(f"{pad} Ended: {job.end_ts_str}")
+            print(f"{pad} Duration: {job.duration_sec_str}")
+            print(f"{pad} Infra Sync Point: {job.sync_point}")
+            if job.success is True:
+                print(f"{pad} Status: {self._gr('Success')}")
+            else:
+                print(f"{pad} Status: {self._f('Fail')}")
+            if job.error is not None:
+                print(f"{pad} Error: {self._f(job.error)}")
+
+            print("")
+
+            if job.delta is None:
+                print(f"{pad}{self._f('The job is missing a delta; discovery never finished.')}")
+            else:
+                if len(job.delta.added) > 0:
+                    print(f"{pad} {self._h('Added')}")
+                    for added in job.delta.added:
+                        vertex = infra.dag.get_vertex(added.uid)
+                        if vertex is None:
+                            print(f"{pad}  * Vertex {added.uid} does not exist.")
+                        else:
+                            if vertex.active is False:
+                                print(f"{pad}  * Vertex {added.uid} is inactive.")
+                            elif vertex.corrupt is True:
+                                print(f"{pad}  * Vertex {added.uid} is corrupt.")
+                            else:
+                                content = DiscoveryObject.get_discovery_object(vertex)
+                                print(f"{pad}  * {content.description}; Record UID: {content.record_uid}")
+                    print("")
+
+                if len(job.delta.changed) > 0:
+                    print(f"{pad} {self._h('Changed')}")
+                    for changed in job.delta.changed:
+                        vertex = infra.dag.get_vertex(changed.uid)
+                        if vertex is None:
+                            print(f"{pad}  * Vertex {changed.uid} does not exist.")
+                        else:
+                            if vertex.active is False:
+                                print(f"{pad}  * Vertex {changed.uid} is inactive.")
+                            elif vertex.corrupt is True:
+                                print(f"{pad}  * Vertex {changed.uid} is corrupt.")
+                            else:
+                                content = DiscoveryObject.get_discovery_object(vertex)
+                                print(f"{pad}  * {content.description}; Record UID: {content.record_uid}")
+                                if changed.changes is not None:
+                                    for k, v in changed.changes.items():
+                                        print(f"{pad}      {k} = {v}")
+                    print("")
+
+                if len(job.delta.deleted) > 0:
+                    print(f"{pad} {self._h('Deleted')}")
+                    for deleted in job.delta.deleted:
+                        print(f"{pad}  * Removed vertex {deleted.uid}.")
+                    print("")
+
+    def _do_render_infra(self, params: KeeperParams, gateway_context: GatewayContext, filepath: str, graph_format: str,
+                         debug_level: int = 0):
+
+        infra = Infrastructure(record=gateway_context.configuration, params=params, logger=logging,
+                               debug_level=debug_level)
+        infra.load(sync_point=0)
+
+        print("")
+        dot_instance = infra.to_dot(
+            graph_type=graph_format if graph_format != "raw" else "dot",
+            show_only_active_vertices=False,
+            show_only_active_edges=False
+        )
+        if graph_format == "raw":
+            print(dot_instance)
+        else:
+            try:
+                dot_instance.render(filepath)
+                print(f"Infrastructure graph rendered to {self._gr(filepath)}")
+            except Exception as err:
+                print(self._f(f"Could not generate graph: {err}"))
+                raise err
+        print("")
+    def _do_render_rl(self, params: KeeperParams, gateway_context: GatewayContext, filepath: str, graph_format: str,
+                      debug_level: int = 0):
+
+        rl = RecordLink(record=gateway_context.configuration, params=params, logger=logging, debug_level=debug_level)
+
+        print("")
+        dot_instance = rl.to_dot(
+            graph_type=graph_format if graph_format != "raw" else "dot",
+            show_only_active_vertices=False,
+            show_only_active_edges=False
+        )
+        if graph_format == "raw":
+            print(dot_instance)
+        else:
+            try:
+                dot_instance.render(filepath)
+                print(f"Record linking graph rendered to {self._gr(filepath)}")
+            except Exception as err:
+                print(self._f(f"Could not generate graph: {err}"))
+                raise err
+        print("")
+
+    def _do_render_service(self, params: KeeperParams, gateway_context: GatewayContext, filepath: str,
+                           graph_format: str, debug_level: int = 0):
+
+        service = UserService(record=gateway_context.configuration, params=params, logger=logging,
+                              debug_level=debug_level)
+
+        print("")
+        dot_instance = service.to_dot(
+            graph_type=graph_format if graph_format != "raw" else "dot",
+            show_only_active_vertices=False,
+            show_only_active_edges=False
+        )
+        if graph_format == "raw":
+            print(dot_instance)
+        else:
+            try:
+                dot_instance.render(filepath)
+                print(f"User service/tasks graph rendered to {self._gr(filepath)}")
+            except Exception as err:
+                print(self._f(f"Could not generate graph: {err}"))
+                raise err
+        print("")
+
+    def _do_render_jobs(self, params: KeeperParams, gateway_context: GatewayContext, filepath: str,
+                        graph_format: str, debug_level: int = 0):
+
+        jobs = Jobs(record=gateway_context.configuration, params=params, logger=logging, debug_level=debug_level)
+
+        print("")
+        dot_instance = jobs.dag.to_dot()
+        if graph_format == "raw":
+            print(dot_instance)
+        else:
+            try:
+                dot_instance.render(filepath)
+                print(f"Job graph rendered to {self._gr(filepath)}")
+            except Exception as err:
+                print(self._f(f"Could not generate graph: {err}"))
+                raise err
+        print("")
+
+    def _do_raw_text_list(self, params: KeeperParams, gateway_context: GatewayContext, graph_id: int = 0,
+                          debug_level: int = 0):
+
+        logging.debug(f"loading graph id {graph_id}, for record uid {gateway_context.configuration.record_uid}")
+
+        conn = get_connection(params=params)
+        dag = DAG(conn=conn, record=gateway_context.configuration, graph_id=graph_id, fail_on_corrupt=False,
+                  logger=logging, debug_level=debug_level)
+        dag.load(sync_point=0)
+        print("")
+        if dag.is_corrupt is True:
+            print(f"{bcolors.FAIL}The graph is corrupt at Vertex UIDs: {', '.join(dag.corrupt_uids)}{bcolors.ENDC}")
+            print("")
+
+        logging.debug("DAG DOT -------------------------------")
+        logging.debug(str(dag.to_dot()))
+        logging.debug("DAG DOT -------------------------------")
+
+        line_start = {
+            0: "",
+            1: "* ",
+            2: "- ",
+            3: ". ",
+        }
", + } + + color_func = { + 0: self._h, + 1: self._gr, + 2: self._bl, + 3: self._p + } + + def _handle(current_vertex: DAGVertex, last_vertex: Optional[DAGVertex] = None, indent: int = 0): + + pad = "" + if indent > 0: + pad = "".ljust(4 * indent, ' ') + + ls = line_start.get(indent, " ") + cf = color_func.get(indent, self._p) + text = f"{pad}{ls}{cf(current_vertex.uid)}" + + edge_types = [] + if last_vertex is not None: + for edge in current_vertex.edges: # type: DAGEdge + if edge.active is False: + continue + if edge.head_uid == last_vertex.uid: + edge_types.append(edge.edge_type.value) + if len(edge_types) > 0: + text += f"; edges: {', '.join(edge_types)}" + + if current_vertex.active is False: + text += " " + self._f("Inactive") + if current_vertex.corrupt is True: + text += " " + self._f("Corrupt") + + print(text) + + if current_vertex.active is False: + logging.debug(f"vertex {current_vertex.uid} is not active, will not get children.") + return + + vertices = current_vertex.has_vertices() + if len(vertices) == 0: + logging.debug(f"vertex {current_vertex.uid} does not have any children.") + return + + for vertex in vertices: + _handle(vertex, current_vertex, indent=indent + 1) + + print("") + _handle(dag.get_root) + print("") + + def _do_raw_render_graph(self, params: KeeperParams, gateway_context: GatewayContext, filepath: str, + graph_format: str, graph_id: int = 0, debug_level: int = 0): + + conn = get_connection(params=params) + dag = DAG(conn=conn, record=gateway_context.configuration, graph_id=graph_id, fail_on_corrupt=False, + logger=logging, debug_level=debug_level) + dag.load(sync_point=0) + dot = dag.to_dot(graph_format=graph_format) + if graph_format == "raw": + print(dot) + else: + try: + dot.render(filepath) + print(f"Graph rendered to {self._gr(filepath)}") + except Exception as err: + print(self._f(f"Could not generate graph: {err}")) + raise err + + print("") + + def do_list(self, params: KeeperParams, gateway_context: GatewayContext, graph_type: str, debug_level: int = 0, + indent: int = 0): + list_func = getattr(self, f"_do_text_list_{graph_type}") + list_func(params=params, + gateway_context=gateway_context, + debug_level=debug_level, + indent=indent) + + def execute(self, params: KeeperParams, **kwargs): + + gateway = kwargs.get("gateway") + raw = kwargs.get("raw", False) + graph_type = kwargs.get("graph_type") + do_text_list = kwargs.get("do_text_list") + do_render = kwargs.get("do_render") + debug_level = int(kwargs.get("debug_level", 0)) + + gateway_context = GatewayContext.from_gateway(params, gateway) + if gateway_context is None: + print(f"{bcolors.FAIL}Could not find the gateway configuration for {gateway}.") + return + + if raw is True: + if do_text_list is True: + self._do_raw_text_list(params=params, + gateway_context=gateway_context, + graph_id=PAMDebugGraphCommand.graph_id_map.get(graph_type), + debug_level=debug_level) + if do_render is True: + filepath = kwargs.get("filepath") + graph_format = kwargs.get("format") + self._do_raw_render_graph(params=params, + gateway_context=gateway_context, + filepath=filepath, + graph_format=graph_format, + graph_id=PAMDebugGraphCommand.graph_id_map.get(graph_type), + debug_level=debug_level) + else: + if do_text_list is True: + self.do_list( + params=params, + gateway_context=gateway_context, + graph_type=graph_type, + debug_level=debug_level + ) + if do_render is True: + filepath = kwargs.get("filepath") + graph_format = kwargs.get("format") + render_func = getattr(self, f"_do_render_{graph_type}") + 
diff --git a/keepercommander/commands/pam_debug/info.py b/keepercommander/commands/pam_debug/info.py
new file mode 100644
index 000000000..118d3f2d7
--- /dev/null
+++ b/keepercommander/commands/pam_debug/info.py
@@ -0,0 +1,436 @@
+from __future__ import annotations
+import argparse
+from ..discover import PAMGatewayActionDiscoverCommandBase, GatewayContext
+from ...display import bcolors
+from ... import vault
+from discovery_common.infrastructure import Infrastructure
+from discovery_common.record_link import RecordLink
+from discovery_common.user_service import UserService
+from discovery_common.types import UserAcl, DiscoveryObject
+from discovery_common.constants import PAM_USER, PAM_MACHINE, PAM_DATABASE, PAM_DIRECTORY
+from keeper_dag import EdgeType
+import time
+import re
+from typing import Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from ...vault import TypedRecord
+    from ...params import KeeperParams
+
+
+class PAMDebugInfoCommand(PAMGatewayActionDiscoverCommandBase):
+    parser = argparse.ArgumentParser(prog='dr-pam-command-debug')
+
+    type_name_map = {
+        PAM_USER: "PAM User",
+        PAM_MACHINE: "PAM Machine",
+        PAM_DATABASE: "PAM Database",
+        PAM_DIRECTORY: "PAM Directory",
+    }
+
+    # The record to base everything on.
+    parser.add_argument('--record-uid', '-i', required=True, dest='record_uid', action='store',
+                        help='Keeper PAM record UID.')
+
+    def get_parser(self):
+        return PAMDebugInfoCommand.parser
+
+    def execute(self, params: KeeperParams, **kwargs):
+
+        record_uid = kwargs.get("record_uid")
+        record = vault.KeeperRecord.load(params, record_uid)  # type: Optional[TypedRecord]
+        if record is None:
+            print(f"{bcolors.FAIL}Record does not exist.{bcolors.ENDC}")
+            return
+
+        if record.record_type not in ["pamUser", "pamMachine", "pamDatabase", "pamDirectory"]:
+            if re.search(r'^pam.*Configuration$', record.record_type) is None:
+                print(f"{bcolors.FAIL}The record is a {record.record_type}. This is not a PAM record.{bcolors.ENDC}")
+                return
+
+        record_rotation = params.record_rotation_cache.get(record_uid)
+        if record_rotation is None:
+            print(f"{bcolors.FAIL}PAM record does not have rotation settings.{bcolors.ENDC}")
+            return
+
+        # TODO: Not sure if this is going away. If not, we are going to have to scan the graphs.
+        controller_uid = record_rotation.get("configuration_uid")
+        if controller_uid is None:
+            print(f"{bcolors.FAIL}Record does not have the PAM Configuration set.{bcolors.ENDC}")
+            return
+
+        configuration_record = vault.KeeperRecord.load(params, controller_uid)  # type: Optional[TypedRecord]
+
+        gateway_context = GatewayContext.from_configuration_uid(params, controller_uid)
+
+        infra = Infrastructure(record=configuration_record, params=params)
+        infra.load()
+        record_link = RecordLink(record=configuration_record, params=params)
+        user_service = UserService(record=configuration_record, params=params)
+
+        print("")
+        print(self._h("Record Information"))
+        print(f"  {self._b('Record UID')}: {record_uid}")
+        print(f"  {self._b('Record Title')}: {record.title}")
+        print(f"  {self._b('Record Type')}: {record.record_type}")
+        print(f"  {self._b('Configuration UID')}: {configuration_record.record_uid}")
+        print(f"  {self._b('Configuration Key Bytes Hex')}: {configuration_record.record_key.hex()}")
+        if gateway_context is not None:
+            print(f"  {self._b('Gateway Name')}: {gateway_context.gateway_name}")
+            print(f"  {self._b('Gateway UID')}: {gateway_context.gateway_uid}")
+        else:
+            print(f"  {self._f('Cannot get gateway information. Gateway may not be up.')}")
+        print("")
+
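+        # Note the two different lookups below: the infrastructure graph stores
+        # the record UID inside the vertex content (hence search_content), while
+        # the record-linking graph keys its vertices by the record UID directly.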
+        discovery_vertices = infra.dag.search_content({"record_uid": record.record_uid})
+        record_vertex = record_link.dag.get_vertex(record.record_uid)
+
+        if record_vertex is not None:
+            print(self._h("Record Linking"))
+            record_parent_vertices = record_vertex.belongs_to_vertices()
+            print(self._b("  Parent Records"))
+            if len(record_parent_vertices) > 0:
+                for record_parent_vertex in record_parent_vertices:
+
+                    parent_record = vault.KeeperRecord.load(params,
+                                                            record_parent_vertex.uid)  # type: Optional[TypedRecord]
+                    if parent_record is None:
+                        print(f"{bcolors.FAIL}    * Parent record {record_parent_vertex.uid} "
+                              f"does not exist.{bcolors.ENDC}")
+                        continue
+
+                    acl_edge = record_vertex.get_edge(record_parent_vertex, EdgeType.ACL)
+                    if acl_edge is not None:
+                        acl_content = acl_edge.content_as_object(UserAcl)
+                        print(f"    * ACL to {self._n(parent_record.record_type)}; {parent_record.title}; "
+                              f"{record_parent_vertex.uid}")
+                        if acl_content.is_admin is True:
+                            print(f"      . Is {self._gr('Admin')}")
+                        if acl_content.belongs_to is True:
+                            print(f"      . Belongs")
+                        else:
+                            print(f"      . Is {self._bl('Remote user')}")
+
+                    link_edge = record_vertex.get_edge(record_parent_vertex, EdgeType.LINK)
+                    if link_edge is not None:
+                        print(f"    * LINK to {self._n(parent_record.record_type)}; {parent_record.title}; "
+                              f"{record_parent_vertex.uid}")
+            else:
+                # This really should not happen
+                print(f"{bcolors.FAIL}    Record does not have a parent record.{bcolors.ENDC}")
+            print("")
+
+            record_child_vertices = record_vertex.has_vertices()
+            print(self._b("  Child Records"))
+            if len(record_child_vertices) > 0:
+                for record_child_vertex in record_child_vertices:
+                    child_record = vault.KeeperRecord.load(params,
+                                                           record_child_vertex.uid)  # type: Optional[TypedRecord]
+
+                    if child_record is None:
+                        print(f"{bcolors.FAIL}    * Child record {record_child_vertex.uid} "
+                              f"does not exist.{bcolors.ENDC}")
+                        continue
+
+                    acl_edge = record_child_vertex.get_edge(record_vertex, EdgeType.ACL)
+                    link_edge = record_child_vertex.get_edge(record_vertex, EdgeType.LINK)
+                    if acl_edge is not None:
+                        acl_content = acl_edge.content_as_object(UserAcl)
+                        print(f"    * ACL from {self._n(child_record.record_type)}; {child_record.title}; "
+                              f"{record_child_vertex.uid}")
+                        if acl_content.is_admin is True:
+                            print(f"      . Is {self._gr('Admin')}")
+                        if acl_content.belongs_to is True:
+                            print(f"      . Belongs")
+                        else:
+                            print(f"      . Is {self._bl('Remote user')}")
+                    elif link_edge is not None:
+                        print(f"    * LINK from {self._n(child_record.record_type)}; {child_record.title}; "
+                              f"{record_child_vertex.uid}")
+                    else:
+                        for edge in record_vertex.edges:  # List[DAGEdge]
+                            print(f"    * {self._f(edge.edge_type)}?")
+
+            else:
+                # This is OK
+                print("    Record does not have any children.")
+            print("")
+
+        else:
+            print(f"{bcolors.FAIL}Cannot find record in record linking.{bcolors.ENDC}")
+
+        # Only PAM User and PAM Machine records can have services and tasks.
+        # This is really only Windows machines.
+        if record.record_type == PAM_USER or record.record_type == PAM_MACHINE:
+
+            # Get the user to service/task vertex.
+            user_service_vertex = user_service.dag.get_vertex(record_uid)
+
+            if user_service_vertex is not None:
+
+                # If the record is a PAM User
+                if record.record_type == PAM_USER:
+
+                    user_results = {
+                        "is_task": [],
+                        "is_service": []
+                    }
+
+                    # Get a list of all the resources where this user is the username/password on a service/task.
+                    for us_machine_vertex in user_service.get_resource_vertices(record_uid):
+
+                        # Get the resource record
+                        us_machine_record = (
+                            vault.KeeperRecord.load(params, us_machine_vertex.uid))  # type: Optional[TypedRecord]
+
+                        acl = user_service.get_acl(us_machine_vertex.uid, user_service_vertex.uid)
+                        for attr in ["is_task", "is_service"]:
+                            value = getattr(acl, attr)
+                            if value is True:
+
+                                # If the resource record does not exist.
+                                if us_machine_record is None:
+
+                                    # Default the title to Unknown (in red).
+                                    # See if we have an infrastructure vertex with this record UID.
+                                    # If we do have it, use the title inside the first vertex's data content.
+                                    title = self._f("Unknown")
+                                    infra_resource_vertices = infra.dag.search_content(
+                                        {"record_uid": us_machine_vertex.uid})
+                                    if len(infra_resource_vertices) > 0:
+                                        infra_resource_vertex = infra_resource_vertices[0]
+                                        if infra_resource_vertex.has_data is True:
+                                            content = DiscoveryObject.get_discovery_object(infra_resource_vertex)
+                                            title = content.title
+
+                                    user_results[attr].append(f" * Record {us_machine_vertex.uid}, "
+                                                              f"{title} does not exist.")
+
+                                # Record exists; just use information from the record.
+                                else:
+                                    user_results[attr].append(f" * {us_machine_record.title}, "
+                                                              f"{us_machine_vertex.uid}")
+
+                    print(f"{bcolors.HEADER}Service on Machines{bcolors.ENDC}")
+                    if len(user_results["is_service"]) > 0:
+                        for service in user_results["is_service"]:
+                            print(service)
+                    else:
+                        print("  PAM User is not used for any services.")
+                    print("")
+
+                    print(f"{bcolors.HEADER}Scheduled Tasks on Machines{bcolors.ENDC}")
+                    if len(user_results["is_task"]) > 0:
+                        for task in user_results["is_task"]:
+                            print(task)
+                    else:
+                        print("  PAM User is not used for any scheduled tasks.")
+                    print("")
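+
+                # In both branches of this if/else, user_service.get_acl()
+                # appears to take the resource vertex UID first and the user
+                # vertex UID second; that ordering is assumed from the calls
+                # above and below, not from documented API.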
+                # If the record is a PAM Machine
+                else:
+                    user_results = {
+                        "is_task": [],
+                        "is_service": []
+                    }
+
+                    # Get the users that are used for tasks/services on this machine.
+                    for us_user_vertex in user_service.get_user_vertices(record_uid):
+
+                        us_user_record = vault.KeeperRecord.load(params,
+                                                                 us_user_vertex.uid)  # type: Optional[TypedRecord]
+                        acl = user_service.get_acl(user_service_vertex.uid, us_user_vertex.uid)
+                        for attr in ["is_task", "is_service"]:
+                            value = getattr(acl, attr)
+                            if value is True:
+
+                                # If the user record does not exist.
+                                if us_user_record is None:
+
+                                    # Default the title to Unknown (in red).
+                                    # See if we have an infrastructure vertex with this record UID.
+                                    # If we do have it, use the title inside the first vertex's data content.
+                                    title = self._f("Unknown")
+                                    infra_resource_vertices = infra.dag.search_content(
+                                        {"record_uid": us_user_vertex.uid})
+                                    if len(infra_resource_vertices) > 0:
+                                        infra_resource_vertex = infra_resource_vertices[0]
+                                        if infra_resource_vertex.has_data is True:
+                                            content = DiscoveryObject.get_discovery_object(infra_resource_vertex)
+                                            title = content.title
+
+                                    user_results[attr].append(f" * Record {us_user_vertex.uid}, "
+                                                              f"{title} does not exist.")
+
+                                # Record exists; just use information from the record.
+                                else:
+                                    user_results[attr].append(f" * {us_user_record.title}, "
+                                                              f"{us_user_vertex.uid}")
+
+                    print(f"{bcolors.HEADER}Users that are used for Services{bcolors.ENDC}")
+                    if len(user_results["is_service"]) > 0:
+                        for service in user_results["is_service"]:
+                            print(service)
+                    else:
+                        print("  Machine does not use any non-builtin users for services.")
+                    print("")
+
+                    print(f"{bcolors.HEADER}Users that are used for Scheduled Tasks{bcolors.ENDC}")
+                    if len(user_results["is_task"]) > 0:
+                        for task in user_results["is_task"]:
+                            print(task)
+                    else:
+                        print("  Machine does not use any non-builtin users for scheduled tasks.")
+                    print("")
+            else:
+                print(self._f("There are no services or scheduled tasks associated with this record."))
+                print("")
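+        # The infrastructure lookup below may legitimately come up empty if
+        # discovery has not yet been run for this gateway; more than one match
+        # for a single record UID would indicate a duplicated discovery vertex.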
+        try:
+            if len(discovery_vertices) == 0:
+                print(f"{bcolors.FAIL}Could not find any discovery infrastructure vertices for "
+                      f"{record.record_uid}{bcolors.ENDC}")
+            elif len(discovery_vertices) > 0:
+
+                if len(discovery_vertices) > 1:
+                    print(f"{bcolors.FAIL}Found multiple vertices with the record UID of "
+                          f"{record.record_uid}{bcolors.ENDC}")
+                    for vertex in discovery_vertices:
+                        print(f" * Infrastructure Vertex UID: {vertex.uid}")
+                    print("")
+
+                discovery_vertex = discovery_vertices[0]
+                content = DiscoveryObject.get_discovery_object(discovery_vertex)
+
+                missing_since = "NA"
+                if content.missing_since_ts is not None:
+                    missing_since = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(content.missing_since_ts))
+
+                print(self._h("Discovery Object Information"))
+                print(f"  {self._b('Vertex UID')}: {content.uid}")
+                print(f"  {self._b('Object ID')}: {content.id}")
+                print(f"  {self._b('Record UID')}: {content.record_uid}")
+                print(f"  {self._b('Parent Record UID')}: {content.parent_record_uid}")
+                print(f"  {self._b('Shared Folder UID')}: {content.shared_folder_uid}")
+                print(f"  {self._b('Record Type')}: {content.record_type}")
+                print(f"  {self._b('Object Type')}: {content.object_type_value}")
+                print(f"  {self._b('Ignore Object')}: {content.ignore_object}")
+                print(f"  {self._b('Rule Engine Result')}: {content.action_rules_result}")
+                print(f"  {self._b('Name')}: {content.name}")
+                print(f"  {self._b('Generated Title')}: {content.title}")
+                print(f"  {self._b('Generated Description')}: {content.description}")
+                print(f"  {self._b('Missing Since')}: {missing_since}")
+                print(f"  {self._b('Discovery Notes')}:")
+                for note in content.notes:
+                    print(f" * {note}")
+                if content.error is not None:
+                    print(f"{bcolors.FAIL}  Error: {content.error}{bcolors.ENDC}")
+                if content.stacktrace is not None:
+                    print(f"{bcolors.FAIL}  Stack Trace:{bcolors.ENDC}")
+                    print(f"{bcolors.FAIL}{content.stacktrace}{bcolors.ENDC}")
+                print("")
+                print(f"{bcolors.HEADER}Record Type Specifics{bcolors.ENDC}")
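+                # content.item is typed per record type; judging from the
+                # imports in graph.py, these are the discovery_common.types
+                # models (DiscoveryUser, DiscoveryMachine, DiscoveryDatabase,
+                # DiscoveryDirectory), which is why the fields differ below.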
+                if record.record_type == PAM_USER:
+                    print(f"  {self._b('User')}: {content.item.user}")
+                    print(f"  {self._b('DN')}: {content.item.dn}")
+                    print(f"  {self._b('Database')}: {content.item.database}")
+                    print(f"  {self._b('Active')}: {content.item.active}")
+                    print(f"  {self._b('Expired')}: {content.item.expired}")
+                    print(f"  {self._b('Source')}: {content.item.source}")
+                elif record.record_type == PAM_MACHINE:
+                    print(f"  {self._b('Host')}: {content.item.host}")
+                    print(f"  {self._b('IP')}: {content.item.ip}")
+                    print(f"  {self._b('Port')}: {content.item.port}")
+                    print(f"  {self._b('Operating System')}: {content.item.os}")
+                    print(f"  {self._b('Provider Region')}: {content.item.provider_region}")
+                    print(f"  {self._b('Provider Group')}: {content.item.provider_group}")
+                    print(f"  {self._b('Is the Gateway')}: {content.item.is_gateway}")
+                    print(f"  {self._b('Allows Admin')}: {content.item.allows_admin}")
+                    print(f"  {self._b('Admin Reason')}: {content.item.admin_reason}")
+                    print("")
+                    # If the facts are not set, discovery inside the machine may not have been performed.
+                    if content.item.facts.id is not None and content.item.facts.name is not None:
+                        print(f"  {self._b('Machine Name')}: {content.item.facts.name}")
+                        print(f"  {self._b('Machine ID')}: {content.item.facts.id.machine_id}")
+                        print(f"  {self._b('Product ID')}: {content.item.facts.id.product_id}")
+                        print(f"  {self._b('Board Serial')}: {content.item.facts.id.board_serial}")
+                        print(f"  {self._b('Directories')}:")
+                        if content.item.facts.directories is not None and len(content.item.facts.directories) > 0:
+                            for directory in content.item.facts.directories:
+                                print(f"   * Directory Domain: {directory.domain}")
+                                print(f"     Software: {directory.software}")
+                                print(f"     Login Format: {directory.login_format}")
+                        else:
+                            print("    Machine is not using any directories.")
+
+                        print("")
+                        print(f"  {self._b('Services')} (Non Builtin Users):")
+                        if len(content.item.facts.services) > 0:
+                            for service in content.item.facts.services:
+                                print(f"   * {service.name} = {service.user}")
+                        else:
+                            print("    Machine has no services that are using non-builtin users.")
+
+                        print(f"  {self._b('Scheduled Tasks')} (Non Builtin Users)")
+                        if len(content.item.facts.tasks) > 0:
+                            for task in content.item.facts.tasks:
+                                print(f"   * {task.name} = {task.user}")
+                        else:
+                            print("    Machine has no scheduled tasks that are using non-builtin users.")
+                    else:
+                        print(f"{bcolors.FAIL}  Machine facts are not set. Discovery inside the machine may not "
+                              f"have been performed.{bcolors.ENDC}")
+                elif record.record_type == PAM_DATABASE:
+                    print(f"  {self._b('Host')}: {content.item.host}")
+                    print(f"  {self._b('IP')}: {content.item.ip}")
+                    print(f"  {self._b('Port')}: {content.item.port}")
+                    print(f"  {self._b('Database Type')}: {content.item.type}")
+                    print(f"  {self._b('Database')}: {content.item.database}")
+                    print(f"  {self._b('Use SSL')}: {content.item.use_ssl}")
+                    print(f"  {self._b('Provider Region')}: {content.item.provider_region}")
+                    print(f"  {self._b('Provider Group')}: {content.item.provider_group}")
+                    print(f"  {self._b('Allows Admin')}: {content.item.allows_admin}")
+                    print(f"  {self._b('Admin Reason')}: {content.item.admin_reason}")
+                elif record.record_type == PAM_DIRECTORY:
+                    print(f"  {self._b('Host')}: {content.item.host}")
+                    print(f"  {self._b('IP')}: {content.item.ip}")
+                    print(f"  {self._b('Port')}: {content.item.port}")
+                    print(f"  {self._b('Directory Type')}: {content.item.type}")
+                    print(f"  {self._b('Use SSL')}: {content.item.use_ssl}")
+                    print(f"  {self._b('Provider Region')}: {content.item.provider_region}")
+                    print(f"  {self._b('Provider Group')}: {content.item.provider_group}")
+                    print(f"  {self._b('Allows Admin')}: {content.item.allows_admin}")
+                    print(f"  {self._b('Admin Reason')}: {content.item.admin_reason}")
+
+                print("")
+                print(self._h("Belongs To Vertices (Parents)"))
+                vertices = discovery_vertex.belongs_to_vertices()
+                for vertex in vertices:
+                    content = DiscoveryObject.get_discovery_object(vertex)
+                    print(f" * {content.description} ({vertex.uid})")
+                    for edge_type in [EdgeType.LINK, EdgeType.ACL, EdgeType.KEY, EdgeType.DELETION]:
+                        edge = discovery_vertex.get_edge(vertex, edge_type=edge_type)
+                        if edge is not None:
+                            print(f"    . 
{edge_type}, active: {edge.active}") + + if len(vertices) == 0: + print(f"{bcolors.FAIL} Does not belong to anyone{bcolors.ENDC}") + + print("") + print(f"{bcolors.HEADER}Vertices Belonging To (Children){bcolors.ENDC}") + vertices = discovery_vertex.has_vertices() + for vertex in vertices: + content = DiscoveryObject.get_discovery_object(vertex) + print(f" * {content.description} ({vertex.uid})") + for edge_type in [EdgeType.LINK, EdgeType.ACL, EdgeType.KEY, EdgeType.DELETION]: + edge = vertex.get_edge(discovery_vertex, edge_type=edge_type) + if edge is not None: + print(f" . {edge_type}, active: {edge.active}") + if len(vertices) == 0: + print(f" Does not have any children.") + + print("") + else: + print(f"{bcolors.FAIL}Could not find infrastructure vertex.{bcolors.ENDC}") + except Exception as err: + print(f"{bcolors.FAIL}Could not get information on infrastructure: {err}{bcolors.ENDC}") diff --git a/keepercommander/commands/pam_debug/verify.py b/keepercommander/commands/pam_debug/verify.py new file mode 100644 index 000000000..85645a461 --- /dev/null +++ b/keepercommander/commands/pam_debug/verify.py @@ -0,0 +1,56 @@ +from __future__ import annotations +from . import get_connection +import logging +import argparse +from ..discover import PAMGatewayActionDiscoverCommandBase, GatewayContext +from ...display import bcolors +from ...vault import TypedRecord +from discovery_common.verify import Verify +import sys +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ...params import KeeperParams + from ...vault import KeeperRecord + + +class PAMDebugVerifyCommand(PAMGatewayActionDiscoverCommandBase): + parser = argparse.ArgumentParser(prog='dr-pam-command-debug') + + # The record to base everything on. + parser.add_argument('--gateway', '-g', required=True, dest='gateway', action='store', + help='Gateway name or UID.') + parser.add_argument('--fix', required=False, dest='fix', action='store_true', + help='Fix all problems.') + parser.add_argument('--debug-gs-level', required=False, dest='debug_level', action='store', + help='GraphSync debug level. 
Default is 0', type=int, default=0)
+
+    def get_parser(self):
+        return PAMDebugVerifyCommand.parser
+
+    def execute(self, params: KeeperParams, **kwargs):
+
+        gateway = kwargs.get("gateway")
+        fix = kwargs.get("fix", False)
+        debug_level = kwargs.get("debug_level", 0)
+
+        gateway_context = GatewayContext.from_gateway(params, gateway)
+        if gateway_context is None:
+            print(f"{bcolors.FAIL}Could not find the gateway configuration for {gateway}.{bcolors.ENDC}")
+            return
+
+        def _record_lookup(record_uid: str) -> KeeperRecord:
+            return TypedRecord.load(params, record_uid)
+
+        colors = {
+            Verify.OK: bcolors.OKGREEN,
+            Verify.FAIL: bcolors.FAIL,
+            Verify.UNK: bcolors.OKBLUE,
+            Verify.TITLE: bcolors.BOLD,
+            Verify.COLOR_RESET: bcolors.ENDC
+        }
+
+        verify = Verify(record=gateway_context.configuration, logger=logging, debug_level=debug_level,
+                        output=sys.stdout, params=params, colors=colors)
+        verify.run(fix=fix,
+                   lookup_record_func=_record_lookup)
diff --git a/keepercommander/commands/pam_debug/version.py b/keepercommander/commands/pam_debug/version.py
new file mode 100644
index 000000000..cfbb879da
--- /dev/null
+++ b/keepercommander/commands/pam_debug/version.py
@@ -0,0 +1,20 @@
+from __future__ import annotations
+import argparse
+from ..discover import PAMGatewayActionDiscoverCommandBase
+from ...display import bcolors
+from importlib.metadata import version
+from typing import Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from ...params import KeeperParams
+
+
+class PAMDebugVersionCommand(PAMGatewayActionDiscoverCommandBase):
+    parser = argparse.ArgumentParser(prog='dr-pam-command-debug')
+
+    def execute(self, params: KeeperParams, **kwargs):
+
+        print("")
+        print(f"{bcolors.BOLD}keeper-dag version:{bcolors.ENDC} {version('keeper-dag')}")
+        print(f"{bcolors.BOLD}discovery-common version:{bcolors.ENDC} {version('discovery-common')}")
+        print("")
\ No newline at end of file
diff --git a/keepercommander/commands/tunnel/port_forward/endpoint.py b/keepercommander/commands/tunnel/port_forward/endpoint.py
index e24c7762f..08513ec3b 100644
--- a/keepercommander/commands/tunnel/port_forward/endpoint.py
+++ b/keepercommander/commands/tunnel/port_forward/endpoint.py
@@ -6,31 +6,35 @@ import secrets
 import socket
 import string
+import struct
 import time
 from datetime import datetime
 from typing import Optional, Dict
 
 from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration, RTCIceServer
-from cryptography.hazmat.primitives import hashes
-from cryptography.hazmat.primitives.asymmetric import ec
 from cryptography.hazmat.primitives.ciphers.aead import AESGCM
-from cryptography.hazmat.primitives.kdf.hkdf import HKDF
-from cryptography.utils import int_to_bytes
-from keeper_secrets_manager_core.utils import bytes_to_base64, base64_to_bytes, bytes_to_string, string_to_bytes, \
-    bytes_to_int
+from keeper_dag import DAG, EdgeType
+from keeper_dag.connection.commander import Connection
+from keeper_dag.types import RefType
+from keeper_dag.vertex import DAGVertex
+from keeper_secrets_manager_core.utils import bytes_to_base64, base64_to_bytes, bytes_to_string, string_to_bytes
+from keepercommander import crypto, utils, rest_api
 from keepercommander.commands.pam.pam_dto import GatewayActionWebRTCSession
-from keepercommander.commands.pam.router_helper import router_get_relay_access_creds, router_send_action_to_gateway
+from keepercommander.commands.pam.router_helper import router_get_relay_access_creds, router_send_action_to_gateway, \
+    get_dag_leafs
 from keepercommander.display import bcolors
 from keepercommander.error import CommandError
 from keepercommander.params import KeeperParams
 from keepercommander.proto import pam_pb2
+from keepercommander.vault import PasswordRecord
 
 logging.getLogger('aiortc').setLevel(logging.WARNING)
 logging.getLogger('aioice').setLevel(logging.WARNING)
 
 READ_TIMEOUT = 10
 NONCE_LENGTH = 12
+MAIN_NONCE_LENGTH = 16
 SYMMETRIC_KEY_LENGTH = RANDOM_LENGTH = 32
 MESSAGE_MAX = 5
@@ -38,13 +42,13 @@
 CONTROL_MESSAGE_NO_LENGTH = 2
 CLOSE_CONNECTION_REASON_LENGTH = 1
 TIME_STAMP_LENGTH = 8
-CONNECTION_NO_LENGTH = DATA_LENGTH = 4
+CONNECTION_NO_LENGTH = DATA_LENGTH = PORT_LENGTH = 4
 TERMINATOR = b';'
 PROTOCOL_LENGTH = CONNECTION_NO_LENGTH + TIME_STAMP_LENGTH + DATA_LENGTH + CONTROL_MESSAGE_NO_LENGTH + len(TERMINATOR)
-KRELAY_URL = 'KRELAY_URL'
-ALT_KRELAY_URL = 'KRELAY_SERVER'
+KRELAY_URL = 'KRELAY_SERVER'
+GATEWAY_TIMEOUT = int(os.getenv('GATEWAY_TIMEOUT')) if os.getenv('GATEWAY_TIMEOUT') else 30000
 
-# WebRTC constants
+# WebRTC constant values
 # 16 MiB max https://viblast.com/blog/2015/2/25/webrtc-bufferedamount/, so we will use 14.4 MiB or 90% of the max,
 # because in some cases if the max is reached the channel will close
 BUFFER_THRESHOLD = 134217728 * .90
@@ -65,6 +69,19 @@ class CloseConnectionReasons(enum.IntEnum):
     ConnectionLost = 9
     ConnectionFailed = 10
     TunnelClosed = 11
+    AdminClosed = 12
+
+
+class ConversationType(enum.Enum):
+    TUNNEL = "tunnel"
+    SSH = "ssh"
+    RDP = "rdp"
+    VNC = "vnc"
+    HTTP = "http"
+    KUBERNETES = "kubernetes"
+    TELNET = "telnet"
+    MYSQL = "mysql"
+    SQLSERVER = "sql-server"
+    POSTGRESQL = "postgresql"
 
 
 class ConnectionNotFoundException(Exception):
@@ -151,41 +168,441 @@ def is_port_open(host: str, port: int) -> bool:
         return False
 
 
-def establish_symmetric_key(private_key, client_public_key):
-    # Perform ECDH key agreement
-    shared_secret = private_key.exchange(ec.ECDH(), client_public_key)
-
-    # Derive a symmetric key using HKDF
-    symmetric_key = HKDF(
-        algorithm=hashes.SHA256(),
-        length=SYMMETRIC_KEY_LENGTH,
-        salt=None,
-        info=b'encrypt network traffic',
-    ).derive(shared_secret)
-    return AESGCM(symmetric_key)
-
-
 def tunnel_encrypt(symmetric_key: AESGCM, data: bytes):
     """ Encrypts data using the symmetric key """
+    # Prepend the nonce so the receiver can split it off before decrypting.
     nonce = os.urandom(NONCE_LENGTH)  # 12-byte nonce for AES-GCM
-    d = nonce + symmetric_key.encrypt(nonce, data, None)
-    return bytes_to_base64(d)
+    encrypted_data = symmetric_key.encrypt(nonce, data, None)
+    return bytes_to_base64(nonce + encrypted_data)
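+
+# Round-trip sketch (illustrative only; AESGCM comes from the cryptography
+# package and the key must be 16/24/32 random bytes):
+#   key = AESGCM(os.urandom(SYMMETRIC_KEY_LENGTH))
+#   token = tunnel_encrypt(key, b'hello')
+#   assert tunnel_decrypt(key, token) == b'hello'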
 
 
 def tunnel_decrypt(symmetric_key: AESGCM, encrypted_data: str):
     """ Decrypts data using the symmetric key """
-    data_bytes = base64_to_bytes(encrypted_data)
-    if len(data_bytes) <= NONCE_LENGTH:
+
+    mixed_data = base64_to_bytes(encrypted_data)
+    # The payload is base64-encoded: a 12-byte nonce followed by the AES-GCM ciphertext.
+
+    if len(mixed_data) <= NONCE_LENGTH:
         return None
-    nonce = data_bytes[:NONCE_LENGTH]
-    data = data_bytes[NONCE_LENGTH:]
+
+    nonce = mixed_data[:NONCE_LENGTH]
+    encrypted_data = mixed_data[NONCE_LENGTH:]
+
     try:
-        return symmetric_key.decrypt(nonce, data, None)
+        return symmetric_key.decrypt(nonce, encrypted_data, None)
     except Exception as e:
         logging.error(f'Error decrypting data: {e}')
         return None
+
+
+def get_config_uid(params, encrypted_session_token, encrypted_transmission_key, record_uid):
+    # Try to get the configuration UID from the DAG.
+    try:
+        rs = get_dag_leafs(params, encrypted_session_token, encrypted_transmission_key, record_uid)
+        # response: "[{\"type\":\"rec\",\"value\":\"Jagbt2dxrft_91FovB5dwg\",\"name\":null}]"
+        if not rs:
+            return None
+        else:
+            return rs[0].get('value', '')
+
except Exception as e: + print(f"{bcolors.FAIL}Error getting configuration: {e}{bcolors.ENDC}") + return None + + +def get_keeper_tokens(params): + transmission_key = generate_random_bytes(32) + server_public_key = rest_api.SERVER_PUBLIC_KEYS[params.rest_context.server_key_id] + + if params.rest_context.server_key_id < 7: + encrypted_transmission_key = crypto.encrypt_rsa(transmission_key, server_public_key) + else: + encrypted_transmission_key = crypto.encrypt_ec(transmission_key, server_public_key) + encrypted_session_token = crypto.encrypt_aes_v2( + utils.base64_url_decode(params.session_token), transmission_key) + + return encrypted_session_token, encrypted_transmission_key, transmission_key + + +class TunnelDAG: + def __init__(self, params, encrypted_session_token, encrypted_transmission_key, record_uid: str, is_config=False): + config_uid = None + if not is_config: + config_uid = get_config_uid(params, encrypted_session_token, encrypted_transmission_key, record_uid) + if not config_uid: + config_uid = record_uid + self.record = PasswordRecord() + self.record.record_uid = config_uid + self.record.record_key = generate_random_bytes(32) + self.encrypted_session_token = encrypted_session_token + self.encrypted_transmission_key = encrypted_transmission_key + self.conn = Connection(params=params, encrypted_transmission_key=self.encrypted_transmission_key, + encrypted_session_token=self.encrypted_session_token + ) + self.linking_dag = DAG(conn=self.conn, record=self.record, graph_id=0) + try: + self.linking_dag.load() + except Exception as e: + logging.debug(f"Error loading config: {e}") + + def get_vertex_content(self, vertex): + return_content = None + if vertex is None: + return return_content + try: + return_content = vertex.content_as_dict + except Exception as e: + logging.debug(f"Error getting vertex content: {e}") + return_content = None + return return_content + + def resource_belongs_to_config(self, resource_uid): + if not self.linking_dag.has_graph: + return False + resource_vertex = self.linking_dag.get_vertex(resource_uid) + config_vertex = self.linking_dag.get_vertex(self.record.record_uid) + return resource_vertex and config_vertex.has(resource_vertex, EdgeType.LINK) + + def user_belongs_to_config(self, user_uid): + if not self.linking_dag.has_graph: + return False + user_vertex = self.linking_dag.get_vertex(user_uid) + config_vertex = self.linking_dag.get_vertex(self.record.record_uid) + res_content = False + if user_vertex and config_vertex and config_vertex.has(user_vertex, EdgeType.ACL): + acl_edge = user_vertex.get_edge(config_vertex, EdgeType.ACL) + _content = acl_edge.content_as_dict + res_content = _content.get('belongs_to', False) if _content else False + return res_content + + def check_tunneling_enabled_config(self, enable_connections=None, enable_tunneling=None, + enable_rotation=None, enable_session_recording=None, + enable_typescript_recording=None, remote_browser_isolation=None): + if not self.linking_dag.has_graph: + return False + config_vertex = self.linking_dag.get_vertex(self.record.record_uid) + content = self.get_vertex_content(config_vertex) + if content is None or not content.get('allowedSettings'): + return False + + allowed_settings = content['allowedSettings'] + if enable_connections and not allowed_settings.get("connections"): + return False + if enable_tunneling and not allowed_settings.get("portForwards"): + return False + if enable_rotation and not allowed_settings.get("rotation"): + return False + if allowed_settings.get("connections") and 
allowed_settings["connections"]: + if enable_session_recording and not allowed_settings.get("sessionRecording"): + return False + if enable_typescript_recording and not allowed_settings.get("typescriptRecording"): + return False + if remote_browser_isolation and not allowed_settings.get("remoteBrowserIsolation"): + return False + return True + + def edit_tunneling_config(self, connections=None, tunneling=None, rotation=None, session_recording=None, + typescript_recording=None, remote_browser_isolation=None): + config_vertex = self.linking_dag.get_vertex(self.record.record_uid) + if config_vertex is None: + config_vertex = self.linking_dag.add_vertex(uid=self.record.record_uid, vertex_type=RefType.PAM_NETWORK) + + if config_vertex.vertex_type != RefType.PAM_NETWORK: + config_vertex.vertex_type = RefType.PAM_NETWORK + content = self.get_vertex_content(config_vertex) + if content and content.get('allowedSettings'): + allowed_settings = dict(content['allowedSettings']) + del content['allowedSettings'] + content = {'allowedSettings': allowed_settings} + + if content is None: + content = {'allowedSettings': {}} + if 'allowedSettings' not in content: + content['allowedSettings'] = {} + + allowed_settings = content['allowedSettings'] + dirty = False + + if connections is not None and connections != allowed_settings.get("connections", False): + allowed_settings["connections"] = connections + dirty = True + if tunneling is not None and tunneling != allowed_settings.get("portForwards", False): + allowed_settings["portForwards"] = tunneling + dirty = True + # We default rotation to True + if rotation is not None and rotation != allowed_settings.get("rotation", True): + allowed_settings["rotation"] = rotation + dirty = True + if session_recording is not None and session_recording != allowed_settings.get("sessionRecording", False): + allowed_settings["sessionRecording"] = session_recording + dirty = True + if (typescript_recording is not None and + typescript_recording != allowed_settings.get("typescriptRecording", False)): + allowed_settings["typescriptRecording"] = typescript_recording + dirty = True + + if remote_browser_isolation is not None and remote_browser_isolation != allowed_settings.get("remoteBrowserIsolation", False): + allowed_settings["remoteBrowserIsolation"] = remote_browser_isolation + dirty = True + + if dirty: + config_vertex.add_data(content=content, path='meta', needs_encryption=False) + self.linking_dag.save() + + def get_all_owners(self, uid): + owners = [] + if self.linking_dag.has_graph: + vertex = self.linking_dag.get_vertex(uid) + if vertex: + owners = [owner.uid for owner in vertex.belongs_to_vertices()] + return owners + + def user_belongs_to_resource(self, user_uid, resource_uid): + user_vertex = self.linking_dag.get_vertex(user_uid) + resource_vertex = self.linking_dag.get_vertex(resource_uid) + res_content = False + if user_vertex and resource_vertex and resource_vertex.has(user_vertex, EdgeType.ACL): + acl_edge = user_vertex.get_edge(resource_vertex, EdgeType.ACL) + _content = acl_edge.content_as_dict + res_content = _content.get('belongs_to', False) if _content else False + return res_content + + def get_resource_uid(self, user_uid): + if not self.linking_dag.has_graph: + return None + resources = self.get_all_owners(user_uid) + if len(resources) > 0: + for resource in resources: + if self.user_belongs_to_resource(user_uid, resource): + return resource + return None + + def link_resource_to_config(self, resource_uid): + config_vertex = 
self.linking_dag.get_vertex(self.record.record_uid) + if config_vertex is None: + config_vertex = self.linking_dag.add_vertex(uid=self.record.record_uid) + + resource_vertex = self.linking_dag.get_vertex(resource_uid) + if resource_vertex is None: + resource_vertex = self.linking_dag.add_vertex(uid=resource_uid) + + if not config_vertex.has(resource_vertex, EdgeType.LINK): + resource_vertex.belongs_to(config_vertex, EdgeType.LINK) + self.linking_dag.save() + + def link_user_to_config(self, user_uid): + config_vertex = self.linking_dag.get_vertex(self.record.record_uid) + if config_vertex is None: + config_vertex = self.linking_dag.add_vertex(uid=self.record.record_uid) + self.link_user(user_uid, config_vertex, belongs_to=True, is_iam_user=True) + + def link_user_to_resource(self, user_uid, resource_uid, is_admin=None, belongs_to=None): + resource_vertex = self.linking_dag.get_vertex(resource_uid) + if resource_vertex is None or not self.resource_belongs_to_config(resource_uid): + print(f"{bcolors.FAIL}Resource {resource_uid} does not belong to the configuration{bcolors.ENDC}") + return False + self.link_user(user_uid, resource_vertex, is_admin, belongs_to) + + def link_user(self, user_uid, source_vertex: DAGVertex, is_admin=None, belongs_to=None, is_iam_user=None): + + user_vertex = self.linking_dag.get_vertex(user_uid) + if user_vertex is None: + user_vertex = self.linking_dag.add_vertex(uid=user_uid, vertex_type=RefType.PAM_USER) + + content = {} + dirty = False + if belongs_to is not None: + content["belongs_to"] = bool(belongs_to) + if is_admin is not None: + content["is_admin"] = bool(is_admin) + if is_iam_user is not None: + content["is_iam_user"] = bool(is_iam_user) + + if user_vertex.vertex_type != RefType.PAM_USER: + user_vertex.vertex_type = RefType.PAM_USER + + if source_vertex.has(user_vertex, EdgeType.ACL): + acl_edge = user_vertex.get_edge(source_vertex, EdgeType.ACL) + existing_content = acl_edge.content_as_dict + for key in existing_content: + if key not in content: + content[key] = existing_content[key] + if content != existing_content: + dirty = True + + if dirty: + user_vertex.belongs_to(source_vertex, EdgeType.ACL, content=content) + # user_vertex.add_data(content=content, needs_encryption=False) + self.linking_dag.save() + else: + user_vertex.belongs_to(source_vertex, EdgeType.ACL, content=content) + self.linking_dag.save() + + def get_all_admins(self): + if not self.linking_dag.has_graph: + return [] + config_vertex = self.linking_dag.get_vertex(self.record.record_uid) + if config_vertex is None: + return [] + admins = [] + for user_vertex in config_vertex.has_vertices(EdgeType.ACL): + acl_edge = user_vertex.get_edge(config_vertex, EdgeType.ACL) + if acl_edge: + content = acl_edge.content_as_dict + if content.get('is_admin'): + admins.append(user_vertex.uid) + return admins + + def check_if_resource_has_admin(self, resource_uid): + resource_vertex = self.linking_dag.get_vertex(resource_uid) + if resource_vertex is None: + return False + for user_vertex in resource_vertex.has_vertices(EdgeType.ACL): + acl_edge = user_vertex.get_edge(resource_vertex, EdgeType.ACL) + if acl_edge: + content = acl_edge.content_as_dict + if content.get('is_admin'): + return user_vertex.uid + return False + + def check_if_resource_allowed(self, resource_uid, setting): + resource_vertex = self.linking_dag.get_vertex(resource_uid) + content = self.get_vertex_content(resource_vertex) + return content.get('allowedSettings', {}).get(setting, False) if content else False + + def 
set_resource_allowed(self, resource_uid, tunneling=None, connections=None, rotation=None,
+                             session_recording=None, typescript_recording=None,
+                             allowed_settings_name='allowedSettings', is_config=False,
+                             v_type: RefType = str(RefType.PAM_MACHINE)):
+        v_type = RefType(v_type)
+        allowed_ref_types = [RefType.PAM_MACHINE, RefType.PAM_DATABASE, RefType.PAM_DIRECTORY, RefType.PAM_BROWSER]
+        if v_type not in allowed_ref_types:
+            # default to machine
+            v_type = RefType.PAM_MACHINE
+
+        resource_vertex = self.linking_dag.get_vertex(resource_uid)
+        if resource_vertex is None:
+            resource_vertex = self.linking_dag.add_vertex(uid=resource_uid, vertex_type=v_type)
+
+        if resource_vertex.vertex_type not in allowed_ref_types:
+            resource_vertex.vertex_type = v_type
+        if is_config:
+            resource_vertex.vertex_type = RefType.PAM_NETWORK
+        dirty = False
+        content = self.get_vertex_content(resource_vertex)
+        if content is None:
+            content = {allowed_settings_name: {}}
+            dirty = True
+        if allowed_settings_name not in content:
+            content[allowed_settings_name] = {}
+            dirty = True
+
+        settings = content[allowed_settings_name]
+        if tunneling is not None and tunneling != settings.get("portForwards", False):
+            settings["portForwards"] = tunneling
+            dirty = True
+        if connections is not None and connections != settings.get("connections", False):
+            settings["connections"] = connections
+            dirty = True
+        # We default rotation to True
+        if rotation is not None and rotation != settings.get("rotation", True):
+            settings["rotation"] = rotation
+            dirty = True
+        if session_recording is not None and session_recording != settings.get("sessionRecording", False):
+            settings["sessionRecording"] = session_recording
+            dirty = True
+        if typescript_recording is not None and typescript_recording != settings.get("typescriptRecording", False):
+            settings["typescriptRecording"] = typescript_recording
+            dirty = True
+
+        if dirty:
+            resource_vertex.add_data(content=content, path='meta', needs_encryption=False)
+            self.linking_dag.save()
+
+    def is_tunneling_config_set_up(self, resource_uid):
+        if not self.linking_dag.has_graph:
+            return False
+        resource_vertex = self.linking_dag.get_vertex(resource_uid)
+        config_vertex = self.linking_dag.get_vertex(self.record.record_uid)
+        return resource_vertex and config_vertex and config_vertex in resource_vertex.belongs_to_vertices()
+
+    def remove_from_dag(self, uid):
+        if not self.linking_dag.has_graph:
+            return True
+
+        vertex = self.linking_dag.get_vertex(uid)
+        if vertex is None:
+            return True
+
+        vertex.delete()
+        self.linking_dag.save(confirm=True)
+
+    def print_tunneling_config(self, record_uid, pam_settings=None, config_uid=None):
+        if not pam_settings and not config_uid:
+            return
+        self.linking_dag.load()
+        vertex = self.linking_dag.get_vertex(record_uid)
+        content = self.get_vertex_content(vertex)
+        config_id = config_uid if config_uid else pam_settings.value[0].get('configUid')
+        if content and content.get('allowedSettings'):
+            allowed_settings = content['allowedSettings']
+            print(f"{bcolors.OKGREEN}Settings configured for {record_uid}{bcolors.ENDC}")
+            # connections = f"{bcolors.OKBLUE}Enabled" if allowed_settings.get('connections') else \
+            #     f"{bcolors.WARNING}Disabled"
+            port_forwarding = f"{bcolors.OKBLUE}Enabled" if allowed_settings.get('portForwards') else \
+                f"{bcolors.WARNING}Disabled"
+            # Rotation defaults to enabled; show it as disabled only when it is explicitly stored as False.
+            rotation = f"{bcolors.WARNING}Disabled" if not allowed_settings.get('rotation', True) else \
+                f"{bcolors.OKBLUE}Enabled"
+            print(f"{bcolors.OKGREEN}\tRotation: {rotation}{bcolors.ENDC}")
+            # print(f"{bcolors.OKGREEN}\tConnections: {connections}{bcolors.ENDC}")
+            # if config_id == record_uid:
+            #     rbi = f"{bcolors.OKBLUE}Enabled" if allowed_settings.get('remoteBrowserIsolation') else \
+            #         f"{bcolors.WARNING}Disabled"
+            #     print(f"{bcolors.OKGREEN}\tRemote Browser Isolation: {rbi}{bcolors.ENDC}")
+            print(f"{bcolors.OKGREEN}\tTunneling: {port_forwarding}{bcolors.ENDC}")
+            # if allowed_settings.get('connections'):
+            #     if allowed_settings.get('sessionRecording'):
+            #         print(f"{bcolors.OKGREEN}\tSession Recording: {bcolors.OKBLUE}Enabled{bcolors.ENDC}")
+            #     else:
+            #         print(f"{bcolors.OKGREEN}\tSession Recording: {bcolors.WARNING}Disabled{bcolors.ENDC}")
+            #     if allowed_settings.get('typescriptRecording'):
+            #         print(f"{bcolors.OKGREEN}\tTypescript Recording: {bcolors.OKBLUE}Enabled{bcolors.ENDC}")
+            #     else:
+            #         print(f"{bcolors.OKGREEN}\tTypescript Recording: {bcolors.WARNING}Disabled{bcolors.ENDC}")
+            # admin_uid = self.check_if_resource_has_admin(record_uid)
+            # if admin_uid:
+            #     print(f"{bcolors.OKGREEN}\tAdmin: {bcolors.OKBLUE}{admin_uid}{bcolors.ENDC}")
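+
+            # The same default applies below: the configuration-level rotation
+            # flag is treated as enabled unless it was explicitly stored as
+            # False, matching the "we default rotation to True" convention used
+            # when the settings are written in set_resource_allowed().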
def print_tunneling_config(self, record_uid, pam_settings=None, config_uid=None): + if not pam_settings and not config_uid: + return + self.linking_dag.load() + vertex = self.linking_dag.get_vertex(record_uid) + content = self.get_vertex_content(vertex) + config_id = config_uid if config_uid else pam_settings.value[0].get('configUid') + if content and content.get('allowedSettings'): + allowed_settings = content['allowedSettings'] + print(f"{bcolors.OKGREEN}Settings configured for {record_uid}{bcolors.ENDC}") + # connections = f"{bcolors.OKBLUE}Enabled" if allowed_settings.get('connections') else \ + # f"{bcolors.WARNING}Disabled" + port_forwarding = f"{bcolors.OKBLUE}Enabled" if allowed_settings.get('portForwards') else \ + f"{bcolors.WARNING}Disabled" + # Rotation defaults to enabled; show Disabled only when it was explicitly turned off + rotation = f"{bcolors.OKBLUE}Enabled" if allowed_settings.get('rotation', True) else \ + f"{bcolors.WARNING}Disabled" + print(f"{bcolors.OKGREEN}\tRotation: {rotation}{bcolors.ENDC}") + # print(f"{bcolors.OKGREEN}\tConnections: {connections}{bcolors.ENDC}") + # if config_id == record_uid: + # rbi = f"{bcolors.OKBLUE}Enabled" if allowed_settings.get('remoteBrowserIsolation') else \ + # f"{bcolors.WARNING}Disabled" + # print(f"{bcolors.OKGREEN}\tRemote Browser Isolation: {rbi}{bcolors.ENDC}") + print(f"{bcolors.OKGREEN}\tTunneling: {port_forwarding}{bcolors.ENDC}") + # if allowed_settings.get('connections'): + # if allowed_settings.get('sessionRecording'): + # print(f"{bcolors.OKGREEN}\tSession Recording: {bcolors.OKBLUE}Enabled{bcolors.ENDC}") + # else: + # print(f"{bcolors.OKGREEN}\tSession Recording: {bcolors.WARNING}Disabled{bcolors.ENDC}") + # if allowed_settings.get('typescriptRecording'): + # print(f"{bcolors.OKGREEN}\tTypescript Recording: {bcolors.OKBLUE}Enabled{bcolors.ENDC}") + # else: + # print(f"{bcolors.OKGREEN}\tTypescript Recording: {bcolors.WARNING}Disabled{bcolors.ENDC}") + # admin_uid = self.check_if_resource_has_admin(record_uid) + # if admin_uid: + # print(f"{bcolors.OKGREEN}\tAdmin: {bcolors.OKBLUE}{admin_uid}{bcolors.ENDC}") + + print(f"{bcolors.OKGREEN}Configuration: {config_id} {bcolors.ENDC}") + if config_id is not None: + config_vertex = self.linking_dag.get_vertex(self.record.record_uid) + config_content = self.get_vertex_content(config_vertex) + if config_content and config_content.get('allowedSettings'): + config_allowed_settings = config_content['allowedSettings'] + # config_connections = f"{bcolors.OKBLUE}Enabled" if config_allowed_settings.get('connections') else \ + # f"{bcolors.WARNING}Disabled" + # + # config_rbi = f"{bcolors.OKBLUE}Enabled" if config_allowed_settings.get('remoteBrowserIsolation') else \ + # f"{bcolors.WARNING}Disabled" + config_port_forwarding = f"{bcolors.OKBLUE}Enabled" if ( + config_allowed_settings.get('portForwards')) else \ + f"{bcolors.WARNING}Disabled" + config_rotation = f"{bcolors.OKBLUE}Enabled" if config_allowed_settings.get('rotation', True) else \ + f"{bcolors.WARNING}Disabled" + print(f"{bcolors.OKGREEN}\tRotation: {config_rotation}{bcolors.ENDC}") + # print(f"{bcolors.OKGREEN}\tConnections: {config_connections}{bcolors.ENDC}") + # print(f"{bcolors.OKGREEN}\tRemote Browser Isolation: {config_rbi}{bcolors.ENDC}") + print(f"{bcolors.OKGREEN}\tTunneling: {config_port_forwarding}{bcolors.ENDC}") + # + # if config_allowed_settings.get('connections') and config_allowed_settings['connections']: + # if config_allowed_settings.get('sessionRecording'): + # print(f"{bcolors.OKGREEN}\tSession Recording: {bcolors.OKBLUE}Enabled{bcolors.ENDC}") + # else: + # print(f"{bcolors.OKGREEN}\tSession Recording: {bcolors.WARNING}Disabled{bcolors.ENDC}") + # if config_allowed_settings.get('typescriptRecording'): + # print(f"{bcolors.OKGREEN}\tTypescript Recording: {bcolors.OKBLUE}Enabled{bcolors.ENDC}") + # else: + # print(f"{bcolors.OKGREEN}\tTypescript Recording: {bcolors.WARNING}Disabled{bcolors.ENDC}")
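Before the `WebRTCConnection` diff continues, here is a rough sketch of how the signaling payload described in `signal_channel`'s docstring could be assembled. `encrypt_aes_v2` is the helper used elsewhere in this changeset; the nonce length and the exact nesting of the encrypted `data` field are assumptions made for illustration:

```python
import base64
import json
import os

from keepercommander.crypto import encrypt_aes_v2

MAIN_NONCE_LENGTH = 16  # assumed; the real constant lives in the tunnel module


def build_signal_inputs(record_uid, sdp_offer, symmetric_key):
    # type: (str, bytes, bytes) -> dict
    # Base64 the SDP offer, wrap it in the documented {'offer': ...} envelope,
    # then encrypt the envelope with the shared symmetric key.
    offer_b64 = base64.b64encode(sdp_offer).decode()
    envelope = json.dumps({'offer': offer_b64}).encode()
    encrypted_data = base64.b64encode(encrypt_aes_v2(envelope, symmetric_key)).decode()
    return {
        'recordUid': record_uid,
        'kind': 'start',  # or 'disconnect' / 'reconnect'
        'conversationType': 'tunnel',
        # Random nonce to prevent replay attacks
        'base64Nonce': base64.b64encode(os.urandom(MAIN_NONCE_LENGTH)).decode(),
        'data': encrypted_data,
    }
```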
class WebRTCConnection: def __init__(self, params: KeeperParams, record_uid, gateway_uid, symmetric_key, print_ready_event: asyncio.Event, kill_server_event: asyncio.Event, @@ -209,10 +626,6 @@ def __init__(self, params: KeeperParams, record_uid, gateway_uid, symmetric_key, krelay_url = os.getenv(KRELAY_URL) if krelay_url: self.relay_url = krelay_url - else: - alt_krelay_url = os.getenv(ALT_KRELAY_URL) - if alt_krelay_url: - self.relay_url = alt_krelay_url self.logger.debug(f'Using relay server: {self.relay_url}') try: self.peer_ice_config() @@ -221,7 +634,24 @@ def __init__(self, params: KeeperParams, record_uid, gateway_uid, symmetric_key, except Exception as e: raise Exception(f'Error setting up WebRTC connection: {e}') - async def signal_channel(self, kind: str): + async def attempt_reconnect(self): + # Exponential backoff retry logic + if self.retry_count < self.max_retries: + await asyncio.sleep(self.retry_delay) # Wait before retrying + await self.ice_restart() + self.retry_count += 1 + self.retry_delay *= 2 # Double the delay for the next retry if needed + else: + self.logger.error('Maximum reconnection attempts reached, stopping retries.') + await self.close_webrtc_connection() + + async def ice_restart(self): + self.peer_ice_config(ice_restart=True) + self.setup_data_channel() + self.setup_event_handlers() + await self.signal_channel('reconnect', base64_nonce=bytes_to_base64(generate_random_bytes(MAIN_NONCE_LENGTH))) + + async def signal_channel(self, kind: str, base64_nonce: str): # make webRTC sdp offer try: @@ -253,35 +683,27 @@ async def signal_channel(self, kind: str): 'mysql', 'postgresql', 'sql-server', 'http'] <-- What type of conversation this is (REQUIRED) 'kind': ['start', 'disconnect'], <-- What command to run (REQUIRED) 'conversations': [List of conversations to disconnect], <-- (Only for kind = disconnect) + 'base64Nonce': base64Nonce, <-- Random nonce to prevent replay attacks (REQUIRED) 'data': { <-- All data is encrypted with symmetric key (REQUIRED) 'offer': encrypted_WebRTC_sdp_offer, <-- WebRTC SDP offer, base64 encoded } } ''' - # TODO: remove when reporting is deployed to krouter prod!!! - dev_router = os.getenv("USE_REPORTING_COMPATABILITY_ROUTER") - if dev_router: - gateway_message_type = pam_pb2.CMT_CONNECT - self.logger.warning("#" * 30 + f"Sending CMT_CONNECT message type. Sergey, this is good..." 
+ "#" * 30) - else: - gateway_message_type = pam_pb2.CMT_GENERAL - - # TODO create objects for WebRTC inputs router_response = router_send_action_to_gateway( params=self.params, gateway_action=GatewayActionWebRTCSession( inputs={ "recordUid": self.record_uid, 'kind': kind, + 'base64Nonce': base64_nonce, 'conversationType': 'tunnel', - "data": encrypted_data + "data": encrypted_data, } ), - message_type=gateway_message_type, + message_type=pam_pb2.CMT_CONNECT, is_streaming=False, - destination_gateway_uid_str=self.gateway_uid, - gateway_timeout=30000 + gateway_timeout=GATEWAY_TIMEOUT ) if not router_response: self.kill_server_event.set() @@ -310,8 +732,9 @@ async def signal_channel(self, kind: str): except Exception as e: raise Exception(f'Error decrypting WebRTC answer from data: {data}\nError: {e}') try: - str_data = bytes_to_string(data).replace("'", '"') - data = json.loads(str_data) + if isinstance(data, bytes): + data = bytes_to_string(data).replace("'", '"') + data = json.loads(data) except Exception as e: raise Exception(f'Error loading WebRTC answer from data: {data}\nError: {e}') if not data.get('answer'): @@ -324,21 +747,24 @@ async def signal_channel(self, kind: str): self.logger.debug("starting private tunnel") - def peer_ice_config(self): + def peer_ice_config(self, ice_restart=False): + if ice_restart and self._pc: + asyncio.create_task(self._pc.close()) + self._pc = None response = router_get_relay_access_creds(params=self.params, expire_sec=60000000) if response is None: raise Exception("Error getting relay access credentials") - if hasattr(response, "time"): - self.time_diff = datetime.now() - datetime.fromtimestamp(response.time) + if hasattr(response, "serverTime"): + self.time_diff = datetime.now() - datetime.fromtimestamp(response.serverTime/1000) stun_url = f"stun:{self.relay_url}:3478" # Create an RTCIceServer instance for the STUN server stun_server = RTCIceServer(urls=stun_url) # Define the TURN server URL and credentials - turn_url = f"turn:{self.relay_url}" + turn_url_udp = f"turn:{self.relay_url}:3478" # Create an RTCIceServer instance for the TURN server with credentials - turn_server = RTCIceServer(urls=turn_url, username=response.username, credential=response.password) + turn_server_udp = RTCIceServer(urls=turn_url_udp, username=response.username, credential=response.password) # Create a new RTCConfiguration with both STUN and TURN servers - config = RTCConfiguration(iceServers=[stun_server, turn_server]) + config = RTCConfiguration(iceServers=[stun_server, turn_server_udp]) self._pc = RTCPeerConnection(config) @@ -380,7 +806,9 @@ def on_connection_state_change(self): if self._pc.connectionState == "connected": # Connection is established, you can now send/receive data pass - elif self._pc.connectionState in ["disconnected", "failed", "closed"]: + elif self._pc.connectionState in "disconnected failed".split(): + asyncio.create_task(self.attempt_reconnect()) + elif self._pc.connectionState == "closed": # Handle disconnection or failure here asyncio.get_event_loop().create_task(self.close_webrtc_connection()) pass @@ -520,12 +948,16 @@ def __init__(self, print_ready_event, # type: asyncio.Event logger=None, # type: logging.Logger connect_task=None, # type: asyncio.Task - kill_server_event=None # type: asyncio.Event + kill_server_event=None, # type: asyncio.Event + target_host=None, # type: Optional[str] + target_port=None # type: Optional[int] ): # type: (...) 
-> None self.closing = False self.to_local_task = None self._ping_attempt = 0 self.host = host + self.target_host = target_host + self.target_port = target_port self.server = None self.connection_no = 1 self.connections: Dict[int, ConnectionInfo] = {0: ConnectionInfo(None, None, 0, None, None, datetime.now())} @@ -545,23 +977,24 @@ def port(self): async def send_to_web_rtc(self, data): if self.pc.is_data_channel_open(): + sleep_count = 0 + while (self.pc.data_channel is not None and + self.pc.data_channel.bufferedAmount >= BUFFER_THRESHOLD and + not self.kill_server_event.is_set() and + self.pc.is_data_channel_open()): + self.logger.debug(f"{bcolors.WARNING}Buffered amount is too high ({sleep_count * 100}) " + f"{self.pc.data_channel.bufferedAmount}{bcolors.ENDC}") + await asyncio.sleep(sleep_count) + sleep_count += .01 + try: - sleep_count = 0 - while (self.pc.data_channel is not None and - self.pc.data_channel.bufferedAmount >= BUFFER_THRESHOLD and - not self.kill_server_event.is_set() and - self.pc.is_data_channel_open()): - self.logger.debug(f"{bcolors.WARNING}Buffered amount is too high ({sleep_count * 100}) " - f"{self.pc.data_channel.bufferedAmount}{bcolors.ENDC}") - await asyncio.sleep(sleep_count) - sleep_count += .01 self.pc.send_message(data) - # Yield control back to the event loop for other tasks to execute - await asyncio.sleep(0) - except Exception as e: self.logger.error(f'Endpoint {self.pc.endpoint_name}: Error sending message: {e}') await asyncio.sleep(0.1) + # Yield control back to the event loop for other tasks to execute + await asyncio.sleep(0) + else: if self.print_ready_event.is_set(): self.logger.error(f'Endpoint {self.pc.endpoint_name}: Data channel is not open. Data not sent.') @@ -576,8 +1009,8 @@ async def send_control_message(self, message_no, data=None): # type: (ControlMe """ buffer = make_control_message(message_no, data) try: - self.logger.debug(f'Endpoint {self.pc.endpoint_name}: Sending Control command {message_no} len: {len(buffer)}' - f' to tunnel.') + self.logger.debug(f'Endpoint {self.pc.endpoint_name}: Sending Control command {message_no} ' + f'len: {len(buffer)} to tunnel.') self.connections[0].transfer_size += len(buffer) await self.send_to_web_rtc(buffer) except Exception as e: @@ -595,7 +1028,7 @@ def update_stats(self, connection_no, data_size, timestamp): if c: dt = datetime.fromtimestamp(timestamp/1000) c.receive_size += data_size - td = datetime.now() + self.pc.time_diff - dt + td = datetime.now() - self.pc.time_diff - dt # Convert timedelta to total milliseconds td_milliseconds = (td.days * 24 * 60 * 60 + td.seconds) * 1000 + td.microseconds / 1000 c.receive_latency_sum += td_milliseconds @@ -624,19 +1057,18 @@ def report_stats(self, connection_no: int): async def process_control_message(self, message_no, data): # type: (ControlMessage, Optional[bytes]) -> None if message_no == ControlMessage.CloseConnection: self.logger.debug(f'Endpoint {self.pc.endpoint_name}: Received close connection request') - if data and len(data) > 0: - + if data and len(data) >= CONNECTION_NO_LENGTH: target_connection_no = int.from_bytes(data[:CONNECTION_NO_LENGTH], byteorder='big') reason = CloseConnectionReasons.Unknown - if len(data) >= CONNECTION_NO_LENGTH: + if len(data) > CONNECTION_NO_LENGTH: reason_int = int.from_bytes(data[CONNECTION_NO_LENGTH:], byteorder='big') reason = CloseConnectionReasons(reason_int) self.logger.debug(f'Endpoint {self.pc.endpoint_name}: Closing Connection {target_connection_no}' + (f'Reason: {reason}' if reason else '')) else: 
self.logger.debug(f'Endpoint {self.pc.endpoint_name}: Closing Connection {target_connection_no}') - if target_connection_no == 0: + self.logger.info(f'Endpoint {self.pc.endpoint_name}: Received close Tunnel connection request.') self.kill_server_event.set() else: self.logger.debug(f'Endpoint {self.pc.endpoint_name}: Closing connection ' @@ -645,8 +1077,8 @@ async def process_control_message(self, message_no, data): # type: (ControlMess elif message_no == ControlMessage.Pong: self._ping_attempt = 0 self.is_connected = True - if len(data) >= 0: - con_no = bytes_to_int(data) + if len(data) >= CONNECTION_NO_LENGTH: + con_no = int.from_bytes(data[:CONNECTION_NO_LENGTH], byteorder='big') if con_no in self.connections: self.connections[con_no].message_counter = 0 self.logger.debug(f'Endpoint {self.pc.endpoint_name}: Received pong request') @@ -675,38 +1107,49 @@ async def process_control_message(self, message_no, data): # type: (ControlMess self.connections[0].ping_time = None elif message_no == ControlMessage.Ping: - if len(data) >= 0: - con_no = bytes_to_int(data) + if len(data) >= CONNECTION_NO_LENGTH: + con_no = int.from_bytes(data[:CONNECTION_NO_LENGTH], byteorder='big') if con_no in self.connections: - await self.send_control_message(ControlMessage.Pong, int_to_bytes(con_no)) - if con_no == 0: - self.logger.debug(f'Endpoint {self.pc.endpoint_name}: Received ping request') - else: - if self.logger.level == logging.DEBUG: - self.report_stats(0) - self.logger.debug(f'Endpoint {self.pc.endpoint_name}: Received Ping for {con_no}') + await self.send_control_message(ControlMessage.Pong, int.to_bytes(con_no, CONNECTION_NO_LENGTH, + byteorder='big')) + # A ping may carry the sender's average receive latency after the connection number + if len(data) >= CONNECTION_NO_LENGTH + TIME_STAMP_LENGTH: + self.connections[con_no].transfer_latency_sum += int.from_bytes(data[CONNECTION_NO_LENGTH: + CONNECTION_NO_LENGTH + + TIME_STAMP_LENGTH], + byteorder='big') + self.connections[con_no].transfer_latency_count += 1 + + self.logger.debug(f'Endpoint {self.pc.endpoint_name}: Received Ping for {con_no}') + if con_no != 0 and self.logger.level == logging.DEBUG: + self.report_stats(0) if self.logger.level == logging.DEBUG: # print the stats self.report_stats(con_no) else: - self.logger.debug(f'Endpoint {self.pc.endpoint_name}: Connection {con_no} not found') + self.logger.debug(f'Endpoint {self.pc.endpoint_name}: Connection {con_no} for Ping not found') else: - self.logger.debug(f'Endpoint {self.pc.endpoint_name}: Connection not found') + self.logger.debug(f'Endpoint {self.pc.endpoint_name}: Received Ping request') + await self.send_control_message(ControlMessage.Pong, int.to_bytes(0, CONNECTION_NO_LENGTH, + byteorder='big')) elif message_no == ControlMessage.ConnectionOpened: if len(data) >= CONNECTION_NO_LENGTH: - if len(data) > CONNECTION_NO_LENGTH: - self.logger.debug(f"Endpoint {self.pc.endpoint_name}: Received invalid open connection message" - f" ({len(data)} bytes)") connection_no = int.from_bytes(data[:CONNECTION_NO_LENGTH], byteorder='big') self.logger.debug(f"Endpoint {self.pc.endpoint_name}: Starting reader for connection " f"{connection_no}") + # If this is a SOCKS connection, signal the client that the connection is open + if isinstance(self, SOCKS5Server): + self.logger.debug(f'Endpoint {self.pc.endpoint_name}: Socks Connection {connection_no} opened') + # Send a success response back to the client + response = b'\x05\x00\x00\x01\x00\x00\x00\x00\x00\x00' + self.connections[connection_no].writer.write(response) + await self.connections[connection_no].writer.drain() try: 
self.connections[connection_no].to_tunnel_task = asyncio.create_task( self.forward_data_to_tunnel(connection_no)) # From current connection to WebRTC connection self.logger.debug( f"Endpoint {self.pc.endpoint_name}: Started reader for connection {connection_no}") except ConnectionNotFoundException as e: - self.logger.debug(f"Endpoint {self.vendpoint_name}: Connection {connection_no} not found: {e}") + self.logger.debug(f"Endpoint {self.pc.endpoint_name}: Connection {connection_no} not found: {e}") except Exception as e: self.logger.error(f"Endpoint {self.pc.endpoint_name}: Error in forwarding data task: {e}") else: @@ -729,115 +1172,114 @@ async def forward_data_to_local(self): Control Packets [CONNECTION_NO_LENGTH + TIME_STAMP_LENGTH + DATA_LENGTH + CONTROL_MESSAGE_NO_LENGTH + DATA] Data Packets [CONNECTION_NO_LENGTH + TIME_STAMP_LENGTH + DATA_LENGTH + DATA] """ - try: - self.logger.debug(f"Endpoint {self.pc.endpoint_name}: Forwarding data to local...") - buff = b'' - while not self.kill_server_event.is_set(): - if self.pc.closed: - self.kill_server_event.set() - break - while len(buff) >= CONNECTION_NO_LENGTH + TIME_STAMP_LENGTH + DATA_LENGTH: - connection_no = int.from_bytes(buff[:CONNECTION_NO_LENGTH], byteorder='big') - time_stamp = int.from_bytes( - buff[CONNECTION_NO_LENGTH:CONNECTION_NO_LENGTH + TIME_STAMP_LENGTH], byteorder='big') - length = int.from_bytes( - buff[CONNECTION_NO_LENGTH + TIME_STAMP_LENGTH: - CONNECTION_NO_LENGTH + TIME_STAMP_LENGTH + DATA_LENGTH], byteorder='big') - if len(buff) >= CONNECTION_NO_LENGTH + TIME_STAMP_LENGTH + DATA_LENGTH + length + len(TERMINATOR): - if (buff[(CONNECTION_NO_LENGTH + TIME_STAMP_LENGTH + DATA_LENGTH + length): - (CONNECTION_NO_LENGTH + TIME_STAMP_LENGTH + DATA_LENGTH + length + len(TERMINATOR))] != - TERMINATOR): - self.logger.warning(f'Endpoint {self.pc.endpoint_name}: Invalid terminator') - # if we don't have a valid terminator then we don't know where the message ends or begins - self.kill_server_event.set() - break - self.logger.debug(f'Endpoint {self.pc.endpoint_name}: Buffer data received data') - send_data = (buff[CONNECTION_NO_LENGTH + TIME_STAMP_LENGTH + DATA_LENGTH: - CONNECTION_NO_LENGTH + TIME_STAMP_LENGTH + DATA_LENGTH + length]) - self.update_stats(connection_no, len(send_data) + CONNECTION_NO_LENGTH + TIME_STAMP_LENGTH + - DATA_LENGTH, time_stamp) - buff = buff[CONNECTION_NO_LENGTH + TIME_STAMP_LENGTH + DATA_LENGTH + length + len(TERMINATOR):] - - if connection_no == 0: - # This is a control message - control_m = ControlMessage(int.from_bytes(send_data[:CONTROL_MESSAGE_NO_LENGTH], - byteorder='big')) - - send_data = send_data[CONTROL_MESSAGE_NO_LENGTH:] - - await self.process_control_message(control_m, send_data) - else: - if connection_no not in self.connections: - self.logger.debug(f"Endpoint {self.pc.endpoint_name}: Connection not found: " - f"{connection_no}") - continue - - try: - self.logger.debug(f"Endpoint {self.pc.endpoint_name}: Forwarding data to " - f"local for connection {connection_no} ({len(send_data)})") - self.connections[connection_no].writer.write(send_data) - await self.connections[connection_no].writer.drain() - # Yield control back to the event loop for other tasks to execute - await asyncio.sleep(0) - except Exception as ex: - self.logger.error(f"Endpoint {self.pc.endpoint_name}: Error while forwarding " - f"data to local: {ex}") - - # Yield control back to the event loop for other tasks to execute - await asyncio.sleep(0) - else: - self.logger.debug( - f"Endpoint {self.pc.endpoint_name}: Buffer is 
too short {len(buff)} need " - f"{CONNECTION_NO_LENGTH + TIME_STAMP_LENGTH + DATA_LENGTH + length + len(TERMINATOR)}") - # Yield control back to the event loop for other tasks to execute - await asyncio.sleep(0) + self.logger.debug(f"Endpoint {self.pc.endpoint_name}: Forwarding data to local...") + buff = b'' + while not self.kill_server_event.is_set(): + if self.pc.closed: + self.kill_server_event.set() + break + while len(buff) >= CONNECTION_NO_LENGTH + TIME_STAMP_LENGTH + DATA_LENGTH: + connection_no = int.from_bytes(buff[:CONNECTION_NO_LENGTH], byteorder='big') + time_stamp = int.from_bytes( + buff[CONNECTION_NO_LENGTH:CONNECTION_NO_LENGTH + TIME_STAMP_LENGTH], byteorder='big') + length = int.from_bytes( + buff[CONNECTION_NO_LENGTH + TIME_STAMP_LENGTH: + CONNECTION_NO_LENGTH + TIME_STAMP_LENGTH + DATA_LENGTH], byteorder='big') + if len(buff) >= CONNECTION_NO_LENGTH + TIME_STAMP_LENGTH + DATA_LENGTH + length + len(TERMINATOR): + if (buff[(CONNECTION_NO_LENGTH + TIME_STAMP_LENGTH + DATA_LENGTH + length): + (CONNECTION_NO_LENGTH + TIME_STAMP_LENGTH + DATA_LENGTH + length + len(TERMINATOR))] != + TERMINATOR): + self.logger.warning(f'Endpoint {self.pc.endpoint_name}: Invalid terminator') + # if we don't have a valid terminator then we don't know where the message ends or begins + self.kill_server_event.set() break - if self.kill_server_event.is_set(): - break - try: - data = await asyncio.wait_for(self.pc.web_rtc_queue.get(), READ_TIMEOUT) - except asyncio.TimeoutError as et: - if self._ping_attempt > 3: - if self.is_connected: - self.kill_server_event.set() - raise et - self.logger.debug(f'Endpoint {self.pc.endpoint_name}: Tunnel reader timed out') - if self.is_connected and self.pc.is_data_channel_open(): - self.logger.debug(f'Endpoint {self.pc.endpoint_name}: Send ping request') - await self.send_control_message(ControlMessage.Ping, int_to_bytes(0)) - - if self.logger.level == logging.DEBUG: - # print the stats - for c in self.connections.keys(): - self.report_stats(c) - self._ping_attempt += 1 - continue - self.pc.web_rtc_queue.task_done() - if not data or not self.is_connected: - self.logger.info(f"Endpoint {self.pc.endpoint_name}: Exiting forward data to local") - break - elif len(data) == 0: - # Yield control back to the event loop for other tasks to execute - await asyncio.sleep(0) - continue - elif isinstance(data, bytes): - self.logger.debug(f"Endpoint {self.pc.endpoint_name}: Got data from WebRTC connection " - f"{len(data)} bytes") - buff += data + self.logger.debug(f'Endpoint {self.pc.endpoint_name}: Buffer data received data') + send_data = (buff[CONNECTION_NO_LENGTH + TIME_STAMP_LENGTH + DATA_LENGTH: + CONNECTION_NO_LENGTH + TIME_STAMP_LENGTH + DATA_LENGTH + length]) + self.update_stats(connection_no, len(send_data) + CONNECTION_NO_LENGTH + TIME_STAMP_LENGTH + + DATA_LENGTH, time_stamp) + buff = buff[CONNECTION_NO_LENGTH + TIME_STAMP_LENGTH + DATA_LENGTH + length + len(TERMINATOR):] + + if connection_no == 0: + # This is a control message + control_m = ControlMessage(int.from_bytes(send_data[:CONTROL_MESSAGE_NO_LENGTH], + byteorder='big')) + + send_data = send_data[CONTROL_MESSAGE_NO_LENGTH:] + + await self.process_control_message(control_m, send_data) + else: + if connection_no not in self.connections: + self.logger.debug(f"Endpoint {self.pc.endpoint_name}: Connection not found: " + f"{connection_no}") + continue + + try: + self.logger.debug(f"Endpoint {self.pc.endpoint_name}: Forwarding data to " + f"local for connection {connection_no} ({len(send_data)})") + 
self.connections[connection_no].writer.write(send_data) + await self.connections[connection_no].writer.drain() + # Yield control back to the event loop for other tasks to execute + await asyncio.sleep(0) + except Exception as ex: + self.logger.error(f"Endpoint {self.pc.endpoint_name}: Error while forwarding " + f"data to local: {ex}") + + # Yield control back to the event loop for other tasks to execute + await asyncio.sleep(0) else: + self.logger.debug( + f"Endpoint {self.pc.endpoint_name}: Buffer is too short {len(buff)} need " + f"{CONNECTION_NO_LENGTH + TIME_STAMP_LENGTH + DATA_LENGTH + length + len(TERMINATOR)}") # Yield control back to the event loop for other tasks to execute await asyncio.sleep(0) + break + if self.kill_server_event.is_set(): + break + try: + data = await asyncio.wait_for(self.pc.web_rtc_queue.get(), READ_TIMEOUT) + except asyncio.TimeoutError as et: + if self._ping_attempt > 3: + if self.is_connected: + self.kill_server_event.set() + raise et + self.logger.debug(f'Endpoint {self.pc.endpoint_name}: Tunnel reader timed out') + if self.is_connected and self.pc.is_data_channel_open(): + self.logger.debug(f'Endpoint {self.pc.endpoint_name}: Send ping request') + + buffer = int.to_bytes(0, CONNECTION_NO_LENGTH, byteorder='big') + if self.connections[0].receive_latency_count > 0: + receive_latency_average = int(self.connections[0].receive_latency_sum / + self.connections[0].receive_latency_count) + buffer += int.to_bytes(receive_latency_average, TIME_STAMP_LENGTH, byteorder='big') + await self.send_control_message(ControlMessage.Ping, buffer) - self.logger.debug(f"Endpoint {self.pc.endpoint_name}: Exiting forward data successfully.") - except asyncio.CancelledError: - pass + if self.logger.level == logging.DEBUG: + # print the stats + for c in self.connections.keys(): + self.report_stats(c) + self._ping_attempt += 1 + continue + self.pc.web_rtc_queue.task_done() + if not data or not self.is_connected: + self.logger.info(f"Endpoint {self.pc.endpoint_name}: Exiting forward data to local") + break + elif len(data) == 0: + # Yield control back to the event loop for other tasks to execute + await asyncio.sleep(0) + continue + elif isinstance(data, bytes): + self.logger.debug(f"Endpoint {self.pc.endpoint_name}: Got data from WebRTC connection " + f"{len(data)} bytes") + buff += data + else: + # Yield control back to the event loop for other tasks to execute + await asyncio.sleep(0) - except Exception as ex: - self.logger.error(f"Endpoint {self.pc.endpoint_name}: Error while forwarding data: {ex}") + self.logger.debug(f"Endpoint {self.pc.endpoint_name}: Exiting forward data successfully.") - finally: - self.logger.debug(f"Endpoint {self.pc.endpoint_name}: Closing tunnel") - await self.stop_server(CloseConnectionReasons.Normal) + self.logger.debug(f"Endpoint {self.pc.endpoint_name}: Closing tunnel") + await self.stop_server(CloseConnectionReasons.Normal) async def start_reader(self): # type: () -> None """ @@ -849,7 +1291,7 @@ async def start_reader(self): # type: () -> None self.to_local_task = asyncio.create_task(self.forward_data_to_local()) # Send hello world open connection message - await self.send_control_message(ControlMessage.Ping, int_to_bytes(0)) + await self.send_control_message(ControlMessage.Ping, int.to_bytes(0, CONNECTION_NO_LENGTH, byteorder='big')) self.logger.debug(f"Endpoint {self.pc.endpoint_name}: Sent ping message to WebRTC connection") except asyncio.CancelledError: pass @@ -866,71 +1308,75 @@ async def forward_data_to_tunnel(self, con_no): """ Forward 
data from the given connection to the WebRTC connection """ - try: - while not self.kill_server_event.is_set(): - c = self.connections.get(con_no) - if c is None or not self.is_connected: - break - try: - data = await c.reader.read(BUFFER_TRUNCATION_THRESHOLD) - self.logger.debug(f"Endpoint {self.pc.endpoint_name}: Forwarding {len(data)} " - f"bytes to tunnel for connection {con_no}") - if not data: - self.logger.debug(f"Endpoint {self.pc.endpoint_name}: Connection {con_no} no data") - break - if isinstance(data, bytes): - if c.reader.at_eof() and len(data) == 0: - if not self.eof_sent: - await self.send_control_message(ControlMessage.SendEOF, - int_to_bytes(con_no, CONNECTION_NO_LENGTH)) - self.eof_sent = True - # Yield control back to the event loop for other tasks to execute - await asyncio.sleep(0) - continue - else: - self.eof_sent = False - buffer = int.to_bytes(con_no, CONNECTION_NO_LENGTH, byteorder='big') - # Add timestamp - timestamp_ms = int(datetime.now().timestamp() * 1000) - buffer += int.to_bytes(timestamp_ms, TIME_STAMP_LENGTH, byteorder='big') - buffer += int.to_bytes(len(data), DATA_LENGTH, byteorder='big') + data + TERMINATOR - self.connections[con_no].transfer_size += len(buffer) - await self.send_to_web_rtc(buffer) - - self.logger.debug( - f'Endpoint {self.pc.endpoint_name}: buffer size: {self.pc.data_channel.bufferedAmount}' + - f', time since start: {datetime.now() - c.start_time}') - - c.message_counter += 1 - if (c.message_counter >= MESSAGE_MAX and - self.pc.data_channel.bufferedAmount > BUFFER_TRUNCATION_THRESHOLD): - c.ping_time = time.perf_counter() - await self.send_control_message(ControlMessage.Ping, int_to_bytes(con_no)) - self._ping_attempt += 1 - wait_count = 0 - while c.message_counter >= MESSAGE_MAX: - await asyncio.sleep(wait_count) - wait_count += .1 - elif (c.message_counter >= MESSAGE_MAX and - self.pc.data_channel.bufferedAmount <= BUFFER_TRUNCATION_THRESHOLD): - c.message_counter = 0 + while not self.kill_server_event.is_set(): + c = self.connections.get(con_no) + if c is None or not self.is_connected: + break + try: + data = await c.reader.read(BUFFER_TRUNCATION_THRESHOLD) + except Exception as e: + self.logger.debug(f"Endpoint {self.pc.endpoint_name}: Connection '{con_no}' read failed: {e}") + break + self.logger.debug(f"Endpoint {self.pc.endpoint_name}: Forwarding {len(data)} " + f"bytes to tunnel for connection {con_no}") + if not data: + self.logger.debug(f"Endpoint {self.pc.endpoint_name}: Connection {con_no} no data") + break + if isinstance(data, bytes): + if c.reader.at_eof() and len(data) == 0: + if not self.eof_sent: + await self.send_control_message(ControlMessage.SendEOF, + int.to_bytes(con_no, CONNECTION_NO_LENGTH, + byteorder='big')) + self.eof_sent = True + # Yield control back to the event loop for other tasks to execute + await asyncio.sleep(0) + continue + else: + self.eof_sent = False + buffer = int.to_bytes(con_no, CONNECTION_NO_LENGTH, byteorder='big') + # Add timestamp + timestamp_ms = int(datetime.now().timestamp() * 1000) + buffer += int.to_bytes(timestamp_ms, TIME_STAMP_LENGTH, byteorder='big') + buffer += int.to_bytes(len(data), DATA_LENGTH, byteorder='big') + data + TERMINATOR + self.connections[con_no].transfer_size += len(buffer) + await self.send_to_web_rtc(buffer) - else: - # Yield control back to the event loop for other tasks to execute - await asyncio.sleep(0) - except Exception as e: - self.logger.debug(f"Endpoint {self.pc.endpoint_name}: Connection '{con_no}' read failed: {e}") - break - except Exception as e: 
- self.logger.error(f"Endpoint {self.pc.endpoint_name}: Error while forwarding data in connection " - f"{con_no}: {e}") + self.logger.debug( + f'Endpoint {self.pc.endpoint_name}: buffer size: {self.pc.data_channel.bufferedAmount}' + f', time since start: {datetime.now() - c.start_time}') + + c.message_counter += 1 + if (c.message_counter >= MESSAGE_MAX and + self.pc.data_channel.bufferedAmount > BUFFER_TRUNCATION_THRESHOLD): + c.ping_time = time.perf_counter() + + ping_buffer = int.to_bytes(con_no, CONNECTION_NO_LENGTH, byteorder='big') + if self.connections[0].receive_latency_count > 0: + receive_latency_average = int(self.connections[0].receive_latency_sum / + self.connections[0].receive_latency_count) + ping_buffer += int.to_bytes(receive_latency_average, TIME_STAMP_LENGTH, + byteorder='big') + await self.send_control_message(ControlMessage.Ping, ping_buffer) + self._ping_attempt += 1 + wait_count = 0 + while c.message_counter >= MESSAGE_MAX: + await asyncio.sleep(wait_count) + wait_count += .1 + elif (c.message_counter >= MESSAGE_MAX and + self.pc.data_channel.bufferedAmount <= BUFFER_TRUNCATION_THRESHOLD): + c.message_counter = 0 + + else: + # Yield control back to the event loop for other tasks to execute + await asyncio.sleep(0) if con_no not in self.connections: raise ConnectionNotFoundException(f"Connection {con_no} not found") # Send close connection message with con_no buff = int.to_bytes(con_no, CONNECTION_NO_LENGTH, byteorder='big') - buff += int_to_bytes(CloseConnectionReasons.Normal.value) + buff += int.to_bytes(CloseConnectionReasons.Normal.value, CLOSE_CONNECTION_REASON_LENGTH, byteorder='big') await self.send_control_message(ControlMessage.CloseConnection, buff) await self.close_connection(con_no, CloseConnectionReasons.Normal) @@ -1027,7 +1473,7 @@ async def close_connection(self, connection_no, reason: CloseConnectionReasons): self.report_stats(connection_no) try: buffer = int.to_bytes(connection_no, CONNECTION_NO_LENGTH, byteorder='big') - buffer += int_to_bytes(reason.value) + buffer += int.to_bytes(reason.value, CLOSE_CONNECTION_REASON_LENGTH, byteorder='big') await self.send_control_message(ControlMessage.CloseConnection, buffer) except Exception as ex: self.logger.warning(f'Endpoint {self.pc.endpoint_name}: hit exception sending Close connection {ex}') @@ -1071,3 +1517,188 @@ async def close_connection(self, connection_no, reason: CloseConnectionReasons): self.logger.warning(f'Endpoint {self.pc.endpoint_name}: hit exception deleting connection {ex}') else: self.logger.debug(f"Endpoint {self.pc.endpoint_name}: Connection {connection_no} not found") + + +class SOCKS5Server(TunnelEntrance): + def __init__(self, + host, # type: str + port, # type: int + pc, # type: WebRTCConnection + print_ready_event, # type: asyncio.Event + logger=None, # type: logging.Logger + connect_task=None, # type: asyncio.Task + kill_server_event=None, # type: asyncio.Event + target_host=None, # type: Optional[str] + target_port=None # type: Optional[int] + ): # type: (...) -> None + super().__init__(host, port, pc, print_ready_event, logger, connect_task, kill_server_event, target_host, + target_port) + # Credentials for authentication + # self.valid_username = os.getenv('SOCKS5_USERNAME', 'defaultuser') + # self.valid_password = os.getenv('SOCKS5_PASSWORD', 'defaultpass') + # if self.valid_username == 'defaultuser' or self.valid_password == 'defaultpass': + # self.logger.warning("Default SOCKS5 credentials are being used. 
" + # "Please set SOCKS5_USERNAME and SOCKS5_PASSWORD environment variables.") + + # async def username_password_authenticate(self, reader, writer): + # # Username/Password Authentication (RFC 1929) + # try: + # auth_version_bytes = await reader.readexactly(1) + # auth_version = ord(auth_version_bytes) + # if auth_version != 1: # Should be 0x01 for username/password auth + # return False + # + # username_length_bytes = await reader.readexactly(1) + # username_length = ord(username_length_bytes) + # username = await reader.readexactly(username_length) + # username = username.decode() + # + # password_length_bytes = await reader.readexactly(1) + # password_length = ord(password_length_bytes) + # password = await reader.readexactly(password_length) + # password = password.decode() + # except asyncio.IncompleteReadError: + # # Handle the case where the client disconnects or sends incomplete data + # return False + # + # # Verify username and password + # if username == self.valid_username and password == self.valid_password: + # writer.write(b'\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00') # Authentication succeeded + # await writer.drain() + # return True + # else: + # writer.write(b'\x01\x01\x00\x01\x00\x00\x00\x00\x00\x00') # Authentication failed + # await writer.drain() + # return False + + async def handle_connection(self, reader, writer): # type: (asyncio.StreamReader, asyncio.StreamWriter) -> None + """ + This is called when a client connects to the local port starting a new session. + This extends the base handle_connection method to handle SOCKS5 connections. + """ + async def quick_close(reason: bytes): + writer.write(reason) # Network unreachable + writer.close() + await writer.wait_closed() + + connection_no = self.connection_no + self.connection_no += 1 + self.connections[connection_no] = ConnectionInfo(reader, writer, 0, None, None, datetime.now()) + + # Only allow connections from localhost + client_host, client_port = writer.get_extra_info('peername') + if client_host != '127.0.0.1': + self.logger.warning(f"Connection from {client_host}:{client_port} rejected") + await quick_close(b'\x05\x02\x00\x01\x00\x00\x00\x00\x00\x00') # Connection not allowed + return + + # Initial greeting and authentication method negotiation + # SOCKS5, 2 authentication methods, No Auth and Username/Password + # supported_methods = [0x00, 0x02] + supported_methods = [0x00] + # Wait for the client's authentication method request + client_greeting = await reader.read(2) + socks_version, n_methods = client_greeting + + if socks_version not in [0x05, 0x04]: # SOCKS5 or SOCKS4 + self.logger.error("Invalid SOCKS version") + await quick_close(b'\x05\x01\x00\x01\x00\x00\x00\x00\x00\x00') # Unsupported version + return + + method_ids = await reader.readexactly(n_methods) + # Decide which method to use (prefer No Auth if available) + if 0x00 in method_ids and 0x00 in supported_methods: + selected_method = 0x00 # No Authentication Required + # elif 0x02 in method_ids and 0x02 in supported_methods: + # selected_method = 0x02 # Username/Password + else: + selected_method = 0xff # No acceptable methods + + # Send the selected method back to the client + writer.write(struct.pack("!BB", socks_version, selected_method)) + await writer.drain() + + # Proceed based on the selected method + if selected_method == 0x00: + # No further authentication needed + pass + # elif selected_method == 0x02: + # # Perform username/password authentication + # auth_success = await self.username_password_authenticate(reader, writer) + # if 
not auth_success: + # self.logger.error("Authentication failed") + # writer.close() + # await writer.wait_closed() + # return + else: + # No acceptable method found, close the connection + self.logger.error("No acceptable authentication method found") + await quick_close(b'\x05\xFF\x00\x01\x00\x00\x00\x00\x00\x00') # No acceptable methods + return + + # Read the connection request + data = await reader.read(4) + if len(data) != 4: + self.logger.error("Invalid connection request") + await quick_close(b'\x05\x02\x00\x01\x00\x00\x00\x00\x00\x00') # Connection not allowed + return + + version, cmd, reserved, address_type = struct.unpack('!BBBB', data) + + if cmd != 1: # 1 for CONNECT + self.logger.error("Unsupported command") + await quick_close(b'\x05\x07\x00\x01\x00\x00\x00\x00\x00\x00') # Command not supported + return + + # # Pseudo-code for handling a BIND command + # if cmd == 2: # BIND + # # bind_address and bind_port are from the client request + # external_socket = await bind_and_listen(bind_address, bind_port) + # server_reply_address, server_reply_port = external_socket.getsockname() + # # Send server's chosen address and port back to the client + # send_bind_reply_to_client(server_reply_address, server_reply_port) + # # Wait for an external connection + # external_conn, external_addr = await external_socket.accept() + # # Notify client of the external connection details + # notify_client_of_external_connection(external_addr) + # # Proceed to relay data between client and external connection + + # # Pseudo-code for handling a UDP ASSOCIATE command + # if cmd == 3: # UDP ASSOCIATE + # # client's address and port are what? + # udp_socket = await allocate_udp_port() + # socks_server_udp_address, socks_server_udp_port = udp_socket.getsockname() + # # Send SOCKS server's UDP address and port back to the client + # send_udp_associate_reply_to_client(socks_server_udp_address, socks_server_udp_port) + # # Listen for UDP datagrams from the client and relay them accordingly + # while True: + # data, addr = await udp_socket.recvfrom() + # if is_datagram_for_client(data): + # relay_datagram_to_final_destination(data, addr) + # else: + # relay_datagram_to_client(data, addr) + + if address_type == 1: # IPv4 + address = await reader.readexactly(4) + tunnel_host = '.'.join(str(byte) for byte in address) + elif address_type == 3: # Domain name + domain_length = ord(await reader.readexactly(1)) + tunnel_host = await reader.readexactly(domain_length) + tunnel_host = tunnel_host.decode() + elif address_type == 4: # IPv6 + address = await reader.readexactly(16) + # Render the 16 bytes as eight 16-bit hex groups (uncompressed IPv6 notation) + tunnel_host = ':'.join(address.hex()[i:i + 4] for i in range(0, 32, 4)) + else: + self.logger.error("Unsupported address type") + await quick_close(b'\x05\x08\x00\x01\x00\x00\x00\x00\x00\x00') # Address type not supported + return + + tunnel_port = int.from_bytes(await reader.readexactly(2), 'big') + + # Send the OpenConnection control message with con_no; this is required to start the connection + data = int.to_bytes(connection_no, CONNECTION_NO_LENGTH, byteorder='big') + tunnel_host_bytes = tunnel_host.encode() + data += int.to_bytes(len(tunnel_host_bytes), CONNECTION_NO_LENGTH, byteorder='big') + data += tunnel_host_bytes + data += int.to_bytes(tunnel_port, PORT_LENGTH, byteorder='big') + await self.send_control_message(ControlMessage.OpenConnection, data)
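With `handle_connection` complete, it may help to see the tunnel's wire framing in one place before the diff moves on to the generated protobuf files. The sketch below reconstructs the `OpenConnection` payload built above and the data-packet layout described in `forward_data_to_local`'s docstring; the `*_LENGTH` constants and the `TERMINATOR` value are assumptions standing in for the module's real definitions:

```python
from datetime import datetime

# Assumed sizes; the real constants are defined in the tunnel module.
CONNECTION_NO_LENGTH = 4
TIME_STAMP_LENGTH = 8
DATA_LENGTH = 4
PORT_LENGTH = 2
TERMINATOR = b';'  # placeholder for the module's terminator bytes


def open_connection_payload(con_no, host, port):
    # type: (int, str, int) -> bytes
    # [con_no][host_len][host][port]; host_len reuses CONNECTION_NO_LENGTH
    host_bytes = host.encode()
    return (con_no.to_bytes(CONNECTION_NO_LENGTH, 'big')
            + len(host_bytes).to_bytes(CONNECTION_NO_LENGTH, 'big')
            + host_bytes
            + port.to_bytes(PORT_LENGTH, 'big'))


def data_packet(con_no, data):
    # type: (int, bytes) -> bytes
    # [con_no][timestamp_ms][len][data][terminator], all big-endian
    timestamp_ms = int(datetime.now().timestamp() * 1000)
    return (con_no.to_bytes(CONNECTION_NO_LENGTH, 'big')
            + timestamp_ms.to_bytes(TIME_STAMP_LENGTH, 'big')
            + len(data).to_bytes(DATA_LENGTH, 'big')
            + data + TERMINATOR)
```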
diff --git a/keepercommander/proto/README.md b/keepercommander/proto/README.md new file mode 100644 index 000000000..9261ec482 --- /dev/null +++ b/keepercommander/proto/README.md @@ -0,0 +1,42 @@ +# Protoc + +These files were generated with protoc 3.19.4. + +https://github.com/protocolbuffers/protobuf/releases/tag/v3.19.4 + +On macOS, you will need to approve running `protoc`. +Run `protoc` once and dismiss the dialog box, then open **Settings** and go + to the **Privacy & Security** tab. +In the **Security** section, allow `protoc` to run. +The next time you run `protoc`, you'll still get a popup, but it will allow + the application to run. + +## Generate a Python file + +Change into the repo directory that contains the `.proto` files. +You need to be in that directory because a `.proto` file + may import other `.proto` files. + +```shell +/path/to/protoc-3.19.4-osx-x86_64/bin/protoc --python_out=.. FOO.proto +``` +`FOO_pb2.py` will be created in the parent directory. +Move the file into the `keepercommander/proto` directory. + +## Edit the file + +Ignore `# Generated by the protocol buffer compiler. DO NOT EDIT!` :) + +You'll need to change the imports to relative-style imports. For example, change .... + +```python +import enterprise_pb2 as enterprise_pb2 +import record_pb2 as record_pb2 +``` + +to + +```python +from . import enterprise_pb2 as enterprise_pb2 +from . import record_pb2 as record_pb2 +``` \ No newline at end of file diff --git a/keepercommander/proto/connect_pb2.py b/keepercommander/proto/connect_pb2.py index 0bd02f050..c5ee3a3af 100644 --- a/keepercommander/proto/connect_pb2.py +++ b/keepercommander/proto/connect_pb2.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
-# source: connect.proto +# source: protobuf/connect.proto """Generated protocol buffer code.""" from google.protobuf.internal import enum_type_wrapper from google.protobuf import descriptor as _descriptor @@ -15,7 +15,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rconnect.proto\x12\x0eRouterInternal\"J\n\x10RelayAccessCreds\x12\x10\n\x08username\x18\x01 \x01(\t\x12\x10\n\x08password\x18\x02 \x01(\t\x12\x12\n\nserverTime\x18\x03 \x01(\x03\"\xb8\x01\n\x0cPAMRecording\x12\x15\n\rconnectionUid\x18\x01 \x01(\x0c\x12\x37\n\rrecordingType\x18\x02 \x01(\x0e\x32 .RouterInternal.PAMRecordingType\x12\x11\n\trecordUid\x18\x03 \x01(\x0c\x12\x10\n\x08userName\x18\x04 \x01(\t\x12\x11\n\tstartedOn\x18\x05 \x01(\x03\x12\x0e\n\x06length\x18\x06 \x01(\x05\x12\x10\n\x08\x66ileSize\x18\x07 \x01(\x03\"I\n\x15PAMRecordingsResponse\x12\x30\n\nrecordings\x18\x01 \x03(\x0b\x32\x1c.RouterInternal.PAMRecording\"S\n\x14\x43ontrollerLogRequest\x12\x15\n\rcontrollerUid\x18\x01 \x01(\x0c\x12\x10\n\x08\x66romTime\x18\x02 \x01(\x03\x12\x12\n\nmaxEntries\x18\x03 \x01(\x05\"_\n\x12\x43ontrollerLogEntry\x12\x0c\n\x04time\x18\x01 \x01(\x03\x12*\n\x08logLevel\x18\x02 \x01(\x0e\x32\x18.RouterInternal.LogLevel\x12\x0f\n\x07message\x18\x03 \x01(\x0c\"Z\n\x15\x43ontrollerLogResponse\x12\x30\n\x04logs\x18\x01 \x03(\x0b\x32\".RouterInternal.ControllerLogEntry\x12\x0f\n\x07hasMore\x18\x02 \x01(\x08*E\n\x10PAMRecordingType\x12\x0f\n\x0bPRT_SESSION\x10\x00\x12\x12\n\x0ePRT_TYPESCRIPT\x10\x01\x12\x0c\n\x08PRT_TIME\x10\x02*\\\n\x08LogLevel\x12\x0c\n\x08LL_TRACE\x10\x00\x12\x0c\n\x08LL_DEBUG\x10\x01\x12\x0b\n\x07LL_INFO\x10\x02\x12\x0b\n\x07LL_WARN\x10\x03\x12\x0c\n\x08LL_ERROR\x10\x04\x12\x0c\n\x08LL_FATAL\x10\x05\x42-\n\"com.keepersecurity.pamRouter.protoB\x07\x43onnectb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16protobuf/connect.proto\x12\x0eRouterInternal\"J\n\x10RelayAccessCreds\x12\x10\n\x08username\x18\x01 \x01(\t\x12\x10\n\x08password\x18\x02 \x01(\t\x12\x12\n\nserverTime\x18\x03 \x01(\x03\"\xb8\x01\n\x0cPAMRecording\x12\x15\n\rconnectionUid\x18\x01 \x01(\x0c\x12\x37\n\rrecordingType\x18\x02 \x01(\x0e\x32 .RouterInternal.PAMRecordingType\x12\x11\n\trecordUid\x18\x03 \x01(\x0c\x12\x10\n\x08userName\x18\x04 \x01(\t\x12\x11\n\tstartedOn\x18\x05 \x01(\x03\x12\x0e\n\x06length\x18\x06 \x01(\x05\x12\x10\n\x08\x66ileSize\x18\x07 \x01(\x03\"I\n\x15PAMRecordingsResponse\x12\x30\n\nrecordings\x18\x01 \x03(\x0b\x32\x1c.RouterInternal.PAMRecording\"S\n\x14\x43ontrollerLogRequest\x12\x15\n\rcontrollerUid\x18\x01 \x01(\x0c\x12\x10\n\x08\x66romTime\x18\x02 \x01(\x03\x12\x12\n\nmaxEntries\x18\x03 \x01(\x05\"_\n\x12\x43ontrollerLogEntry\x12\x0c\n\x04time\x18\x01 \x01(\x03\x12*\n\x08logLevel\x18\x02 \x01(\x0e\x32\x18.RouterInternal.LogLevel\x12\x0f\n\x07message\x18\x03 \x01(\x0c\"Z\n\x15\x43ontrollerLogResponse\x12\x30\n\x04logs\x18\x01 \x03(\x0b\x32\".RouterInternal.ControllerLogEntry\x12\x0f\n\x07hasMore\x18\x02 \x01(\x08*E\n\x10PAMRecordingType\x12\x0f\n\x0bPRT_SESSION\x10\x00\x12\x12\n\x0ePRT_TYPESCRIPT\x10\x01\x12\x0c\n\x08PRT_TIME\x10\x02*\\\n\x08LogLevel\x12\x0c\n\x08LL_TRACE\x10\x00\x12\x0c\n\x08LL_DEBUG\x10\x01\x12\x0b\n\x07LL_INFO\x10\x02\x12\x0b\n\x07LL_WARN\x10\x03\x12\x0c\n\x08LL_ERROR\x10\x04\x12\x0c\n\x08LL_FATAL\x10\x05\x42-\n\"com.keepersecurity.pamRouter.protoB\x07\x43onnectb\x06proto3') _PAMRECORDINGTYPE = DESCRIPTOR.enum_types_by_name['PAMRecordingType'] PAMRecordingType = enum_type_wrapper.EnumTypeWrapper(_PAMRECORDINGTYPE) @@ -40,42 +40,42 @@ _CONTROLLERLOGRESPONSE = 
DESCRIPTOR.message_types_by_name['ControllerLogResponse'] RelayAccessCreds = _reflection.GeneratedProtocolMessageType('RelayAccessCreds', (_message.Message,), { 'DESCRIPTOR' : _RELAYACCESSCREDS, - '__module__' : 'connect_pb2' + '__module__' : 'protobuf.connect_pb2' # @@protoc_insertion_point(class_scope:RouterInternal.RelayAccessCreds) }) _sym_db.RegisterMessage(RelayAccessCreds) PAMRecording = _reflection.GeneratedProtocolMessageType('PAMRecording', (_message.Message,), { 'DESCRIPTOR' : _PAMRECORDING, - '__module__' : 'connect_pb2' + '__module__' : 'protobuf.connect_pb2' # @@protoc_insertion_point(class_scope:RouterInternal.PAMRecording) }) _sym_db.RegisterMessage(PAMRecording) PAMRecordingsResponse = _reflection.GeneratedProtocolMessageType('PAMRecordingsResponse', (_message.Message,), { 'DESCRIPTOR' : _PAMRECORDINGSRESPONSE, - '__module__' : 'connect_pb2' + '__module__' : 'protobuf.connect_pb2' # @@protoc_insertion_point(class_scope:RouterInternal.PAMRecordingsResponse) }) _sym_db.RegisterMessage(PAMRecordingsResponse) ControllerLogRequest = _reflection.GeneratedProtocolMessageType('ControllerLogRequest', (_message.Message,), { 'DESCRIPTOR' : _CONTROLLERLOGREQUEST, - '__module__' : 'connect_pb2' + '__module__' : 'protobuf.connect_pb2' # @@protoc_insertion_point(class_scope:RouterInternal.ControllerLogRequest) }) _sym_db.RegisterMessage(ControllerLogRequest) ControllerLogEntry = _reflection.GeneratedProtocolMessageType('ControllerLogEntry', (_message.Message,), { 'DESCRIPTOR' : _CONTROLLERLOGENTRY, - '__module__' : 'connect_pb2' + '__module__' : 'protobuf.connect_pb2' # @@protoc_insertion_point(class_scope:RouterInternal.ControllerLogEntry) }) _sym_db.RegisterMessage(ControllerLogEntry) ControllerLogResponse = _reflection.GeneratedProtocolMessageType('ControllerLogResponse', (_message.Message,), { 'DESCRIPTOR' : _CONTROLLERLOGRESPONSE, - '__module__' : 'connect_pb2' + '__module__' : 'protobuf.connect_pb2' # @@protoc_insertion_point(class_scope:RouterInternal.ControllerLogResponse) }) _sym_db.RegisterMessage(ControllerLogResponse) @@ -84,20 +84,20 @@ DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b'\n\"com.keepersecurity.pamRouter.protoB\007Connect' - _PAMRECORDINGTYPE._serialized_start=645 - _PAMRECORDINGTYPE._serialized_end=714 - _LOGLEVEL._serialized_start=716 - _LOGLEVEL._serialized_end=808 - _RELAYACCESSCREDS._serialized_start=33 - _RELAYACCESSCREDS._serialized_end=107 - _PAMRECORDING._serialized_start=110 - _PAMRECORDING._serialized_end=294 - _PAMRECORDINGSRESPONSE._serialized_start=296 - _PAMRECORDINGSRESPONSE._serialized_end=369 - _CONTROLLERLOGREQUEST._serialized_start=371 - _CONTROLLERLOGREQUEST._serialized_end=454 - _CONTROLLERLOGENTRY._serialized_start=456 - _CONTROLLERLOGENTRY._serialized_end=551 - _CONTROLLERLOGRESPONSE._serialized_start=553 - _CONTROLLERLOGRESPONSE._serialized_end=643 + _PAMRECORDINGTYPE._serialized_start=654 + _PAMRECORDINGTYPE._serialized_end=723 + _LOGLEVEL._serialized_start=725 + _LOGLEVEL._serialized_end=817 + _RELAYACCESSCREDS._serialized_start=42 + _RELAYACCESSCREDS._serialized_end=116 + _PAMRECORDING._serialized_start=119 + _PAMRECORDING._serialized_end=303 + _PAMRECORDINGSRESPONSE._serialized_start=305 + _PAMRECORDINGSRESPONSE._serialized_end=378 + _CONTROLLERLOGREQUEST._serialized_start=380 + _CONTROLLERLOGREQUEST._serialized_end=463 + _CONTROLLERLOGENTRY._serialized_start=465 + _CONTROLLERLOGENTRY._serialized_end=560 + _CONTROLLERLOGRESPONSE._serialized_start=562 + _CONTROLLERLOGRESPONSE._serialized_end=652 # 
@@protoc_insertion_point(module_scope) diff --git a/keepercommander/proto/pam_pb2.py b/keepercommander/proto/pam_pb2.py index c661072ac..584b4798d 100644 --- a/keepercommander/proto/pam_pb2.py +++ b/keepercommander/proto/pam_pb2.py @@ -17,14 +17,26 @@ from . import record_pb2 as record__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tpam.proto\x12\x03PAM\x1a\x10\x65nterprise.proto\x1a\x0crecord.proto\"\x83\x01\n\x13PAMRotationSchedule\x12\x11\n\trecordUid\x18\x01 \x01(\x0c\x12\x18\n\x10\x63onfigurationUid\x18\x02 \x01(\x0c\x12\x15\n\rcontrollerUid\x18\x03 \x01(\x0c\x12\x14\n\x0cscheduleData\x18\x04 \x01(\t\x12\x12\n\nnoSchedule\x18\x05 \x01(\x08\"K\n\x1cPAMRotationSchedulesResponse\x12+\n\tschedules\x18\x01 \x03(\x0b\x32\x18.PAM.PAMRotationSchedule\"e\n\x13PAMOnlineController\x12\x15\n\rcontrollerUid\x18\x01 \x01(\x0c\x12\x13\n\x0b\x63onnectedOn\x18\x02 \x01(\x03\x12\x11\n\tipAddress\x18\x03 \x01(\t\x12\x0f\n\x07version\x18\x04 \x01(\t\"Y\n\x14PAMOnlineControllers\x12\x12\n\ndeprecated\x18\x01 \x03(\x0c\x12-\n\x0b\x63ontrollers\x18\x02 \x03(\x0b\x32\x18.PAM.PAMOnlineController\"9\n\x10PAMRotateRequest\x12\x12\n\nrequestUid\x18\x01 \x01(\x0c\x12\x11\n\trecordUid\x18\x02 \x01(\x0c\"A\n\x16PAMControllersResponse\x12\'\n\x0b\x63ontrollers\x18\x01 \x03(\x0b\x32\x12.PAM.PAMController\"=\n\x13PAMRemoveController\x12\x15\n\rcontrollerUid\x18\x01 \x01(\x0c\x12\x0f\n\x07message\x18\x02 \x01(\t\"L\n\x1bPAMRemoveControllerResponse\x12-\n\x0b\x63ontrollers\x18\x01 \x03(\x0b\x32\x18.PAM.PAMRemoveController\"=\n\x10PAMModifyRequest\x12)\n\noperations\x18\x01 \x03(\x0b\x32\x15.PAM.PAMDataOperation\"\x98\x01\n\x10PAMDataOperation\x12,\n\roperationType\x18\x01 \x01(\x0e\x32\x15.PAM.PAMOperationType\x12\x30\n\rconfiguration\x18\x02 \x01(\x0b\x32\x19.PAM.PAMConfigurationData\x12$\n\x07\x65lement\x18\x03 \x01(\x0b\x32\x13.PAM.PAMElementData\"e\n\x14PAMConfigurationData\x12\x18\n\x10\x63onfigurationUid\x18\x01 \x01(\x0c\x12\x0e\n\x06nodeId\x18\x02 \x01(\x03\x12\x15\n\rcontrollerUid\x18\x03 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\"E\n\x0ePAMElementData\x12\x12\n\nelementUid\x18\x01 \x01(\x0c\x12\x11\n\tparentUid\x18\x02 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"m\n\x19PAMElementOperationResult\x12\x12\n\nelementUid\x18\x01 \x01(\x0c\x12+\n\x06result\x18\x02 \x01(\x0e\x32\x1b.PAM.PAMOperationResultType\x12\x0f\n\x07message\x18\x03 \x01(\t\"B\n\x0fPAMModifyResult\x12/\n\x07results\x18\x01 \x03(\x0b\x32\x1e.PAM.PAMElementOperationResult\"x\n\nPAMElement\x12\x12\n\nelementUid\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\x12\x0f\n\x07\x63reated\x18\x03 \x01(\x03\x12\x14\n\x0clastModified\x18\x04 \x01(\x03\x12!\n\x08\x63hildren\x18\x05 \x03(\x0b\x32\x0f.PAM.PAMElement\"#\n\x14PAMGenericUidRequest\x12\x0b\n\x03uid\x18\x01 \x01(\x0c\"%\n\x15PAMGenericUidsRequest\x12\x0c\n\x04uids\x18\x01 \x03(\x0c\"\xab\x01\n\x10PAMConfiguration\x12\x18\n\x10\x63onfigurationUid\x18\x01 \x01(\x0c\x12\x0e\n\x06nodeId\x18\x02 \x01(\x03\x12\x15\n\rcontrollerUid\x18\x03 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\x12\x0f\n\x07\x63reated\x18\x05 \x01(\x03\x12\x14\n\x0clastModified\x18\x06 \x01(\x03\x12!\n\x08\x63hildren\x18\x07 \x03(\x0b\x32\x0f.PAM.PAMElement\"B\n\x11PAMConfigurations\x12-\n\x0e\x63onfigurations\x18\x01 \x03(\x0b\x32\x15.PAM.PAMConfiguration\"\xff\x01\n\rPAMController\x12\x15\n\rcontrollerUid\x18\x01 \x01(\x0c\x12\x16\n\x0e\x63ontrollerName\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65viceToken\x18\x03 \x01(\t\x12\x12\n\ndeviceName\x18\x04 
\x01(\t\x12\x0e\n\x06nodeId\x18\x05 \x01(\x03\x12\x0f\n\x07\x63reated\x18\x06 \x01(\x03\x12\x14\n\x0clastModified\x18\x07 \x01(\x03\x12\x16\n\x0e\x61pplicationUid\x18\x08 \x01(\x0c\x12\x30\n\rappClientType\x18\t \x01(\x0e\x32\x19.Enterprise.AppClientType\x12\x15\n\risInitialized\x18\n \x01(\x08\"%\n\x12\x43ontrollerResponse\x12\x0f\n\x07payload\x18\x01 \x01(\t\"M\n\x1aPAMConfigurationController\x12\x18\n\x10\x63onfigurationUid\x18\x01 \x01(\x0c\x12\x15\n\rcontrollerUid\x18\x02 \x01(\x0c\"\xa3\x01\n\x17\x43onfigurationAddRequest\x12\x18\n\x10\x63onfigurationUid\x18\x01 \x01(\x0c\x12\x11\n\trecordKey\x18\x02 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x12(\n\x0brecordLinks\x18\x04 \x03(\x0b\x32\x13.Records.RecordLink\x12#\n\x05\x61udit\x18\x05 \x01(\x0b\x32\x14.Records.RecordAudit\"6\n\x10RelayAccessCreds\x12\x10\n\x08username\x18\x01 \x01(\t\x12\x10\n\x08password\x18\x02 \x01(\t*@\n\x10PAMOperationType\x12\x07\n\x03\x41\x44\x44\x10\x00\x12\n\n\x06UPDATE\x10\x01\x12\x0b\n\x07REPLACE\x10\x02\x12\n\n\x06\x44\x45LETE\x10\x03*p\n\x16PAMOperationResultType\x12\x0f\n\x0bPOT_SUCCESS\x10\x00\x12\x15\n\x11POT_UNKNOWN_ERROR\x10\x01\x12\x16\n\x12POT_ALREADY_EXISTS\x10\x02\x12\x16\n\x12POT_DOES_NOT_EXIST\x10\x03*H\n\x15\x43ontrollerMessageType\x12\x0f\n\x0b\x43MT_GENERAL\x10\x00\x12\x0e\n\nCMT_ROTATE\x10\x01\x12\x0e\n\nCMT_STREAM\x10\x02\x42\x1f\n\x18\x63om.keepersecurity.protoB\x03PAMb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tpam.proto\x12\x03PAM\x1a\x10\x65nterprise.proto\x1a\x0crecord.proto\"\x83\x01\n\x13PAMRotationSchedule\x12\x11\n\trecordUid\x18\x01 \x01(\x0c\x12\x18\n\x10\x63onfigurationUid\x18\x02 \x01(\x0c\x12\x15\n\rcontrollerUid\x18\x03 \x01(\x0c\x12\x14\n\x0cscheduleData\x18\x04 \x01(\t\x12\x12\n\nnoSchedule\x18\x05 \x01(\x08\"K\n\x1cPAMRotationSchedulesResponse\x12+\n\tschedules\x18\x01 \x03(\x0b\x32\x18.PAM.PAMRotationSchedule\"\x94\x01\n\x13PAMOnlineController\x12\x15\n\rcontrollerUid\x18\x01 \x01(\x0c\x12\x13\n\x0b\x63onnectedOn\x18\x02 \x01(\x03\x12\x11\n\tipAddress\x18\x03 \x01(\t\x12\x0f\n\x07version\x18\x04 \x01(\t\x12-\n\x0b\x63onnections\x18\x05 \x03(\x0b\x32\x18.PAM.PAMWebRtcConnection\"\xa7\x01\n\x13PAMWebRtcConnection\x12\x15\n\rconnectionUid\x18\x01 \x01(\x0c\x12\'\n\x04type\x18\x02 \x01(\x0e\x32\x19.PAM.WebRtcConnectionType\x12\x11\n\trecordUid\x18\x03 \x01(\x0c\x12\x10\n\x08userName\x18\x04 \x01(\t\x12\x11\n\tstartedOn\x18\x05 \x01(\x03\x12\x18\n\x10\x63onfigurationUid\x18\x06 \x01(\x0c\"Y\n\x14PAMOnlineControllers\x12\x12\n\ndeprecated\x18\x01 \x03(\x0c\x12-\n\x0b\x63ontrollers\x18\x02 \x03(\x0b\x32\x18.PAM.PAMOnlineController\"9\n\x10PAMRotateRequest\x12\x12\n\nrequestUid\x18\x01 \x01(\x0c\x12\x11\n\trecordUid\x18\x02 \x01(\x0c\"A\n\x16PAMControllersResponse\x12\'\n\x0b\x63ontrollers\x18\x01 \x03(\x0b\x32\x12.PAM.PAMController\"=\n\x13PAMRemoveController\x12\x15\n\rcontrollerUid\x18\x01 \x01(\x0c\x12\x0f\n\x07message\x18\x02 \x01(\t\"L\n\x1bPAMRemoveControllerResponse\x12-\n\x0b\x63ontrollers\x18\x01 \x03(\x0b\x32\x18.PAM.PAMRemoveController\"=\n\x10PAMModifyRequest\x12)\n\noperations\x18\x01 \x03(\x0b\x32\x15.PAM.PAMDataOperation\"\x98\x01\n\x10PAMDataOperation\x12,\n\roperationType\x18\x01 \x01(\x0e\x32\x15.PAM.PAMOperationType\x12\x30\n\rconfiguration\x18\x02 \x01(\x0b\x32\x19.PAM.PAMConfigurationData\x12$\n\x07\x65lement\x18\x03 \x01(\x0b\x32\x13.PAM.PAMElementData\"e\n\x14PAMConfigurationData\x12\x18\n\x10\x63onfigurationUid\x18\x01 \x01(\x0c\x12\x0e\n\x06nodeId\x18\x02 \x01(\x03\x12\x15\n\rcontrollerUid\x18\x03 
\x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\"E\n\x0ePAMElementData\x12\x12\n\nelementUid\x18\x01 \x01(\x0c\x12\x11\n\tparentUid\x18\x02 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"m\n\x19PAMElementOperationResult\x12\x12\n\nelementUid\x18\x01 \x01(\x0c\x12+\n\x06result\x18\x02 \x01(\x0e\x32\x1b.PAM.PAMOperationResultType\x12\x0f\n\x07message\x18\x03 \x01(\t\"B\n\x0fPAMModifyResult\x12/\n\x07results\x18\x01 \x03(\x0b\x32\x1e.PAM.PAMElementOperationResult\"x\n\nPAMElement\x12\x12\n\nelementUid\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\x12\x0f\n\x07\x63reated\x18\x03 \x01(\x03\x12\x14\n\x0clastModified\x18\x04 \x01(\x03\x12!\n\x08\x63hildren\x18\x05 \x03(\x0b\x32\x0f.PAM.PAMElement\"#\n\x14PAMGenericUidRequest\x12\x0b\n\x03uid\x18\x01 \x01(\x0c\"%\n\x15PAMGenericUidsRequest\x12\x0c\n\x04uids\x18\x01 \x03(\x0c\"\xab\x01\n\x10PAMConfiguration\x12\x18\n\x10\x63onfigurationUid\x18\x01 \x01(\x0c\x12\x0e\n\x06nodeId\x18\x02 \x01(\x03\x12\x15\n\rcontrollerUid\x18\x03 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\x12\x0f\n\x07\x63reated\x18\x05 \x01(\x03\x12\x14\n\x0clastModified\x18\x06 \x01(\x03\x12!\n\x08\x63hildren\x18\x07 \x03(\x0b\x32\x0f.PAM.PAMElement\"B\n\x11PAMConfigurations\x12-\n\x0e\x63onfigurations\x18\x01 \x03(\x0b\x32\x15.PAM.PAMConfiguration\"\xff\x01\n\rPAMController\x12\x15\n\rcontrollerUid\x18\x01 \x01(\x0c\x12\x16\n\x0e\x63ontrollerName\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65viceToken\x18\x03 \x01(\t\x12\x12\n\ndeviceName\x18\x04 \x01(\t\x12\x0e\n\x06nodeId\x18\x05 \x01(\x03\x12\x0f\n\x07\x63reated\x18\x06 \x01(\x03\x12\x14\n\x0clastModified\x18\x07 \x01(\x03\x12\x16\n\x0e\x61pplicationUid\x18\x08 \x01(\x0c\x12\x30\n\rappClientType\x18\t \x01(\x0e\x32\x19.Enterprise.AppClientType\x12\x15\n\risInitialized\x18\n \x01(\x08\"%\n\x12\x43ontrollerResponse\x12\x0f\n\x07payload\x18\x01 \x01(\t\"M\n\x1aPAMConfigurationController\x12\x18\n\x10\x63onfigurationUid\x18\x01 \x01(\x0c\x12\x15\n\rcontrollerUid\x18\x02 \x01(\x0c\"\xa3\x01\n\x17\x43onfigurationAddRequest\x12\x18\n\x10\x63onfigurationUid\x18\x01 \x01(\x0c\x12\x11\n\trecordKey\x18\x02 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x12(\n\x0brecordLinks\x18\x04 \x03(\x0b\x32\x13.Records.RecordLink\x12#\n\x05\x61udit\x18\x05 \x01(\x0b\x32\x14.Records.RecordAudit*\x8e\x01\n\x14WebRtcConnectionType\x12\x0e\n\nCONNECTION\x10\x00\x12\n\n\x06TUNNEL\x10\x01\x12\x07\n\x03SSH\x10\x02\x12\x07\n\x03RDP\x10\x03\x12\x08\n\x04HTTP\x10\x04\x12\x07\n\x03VNC\x10\x05\x12\n\n\x06TELNET\x10\x06\x12\t\n\x05MYSQL\x10\x07\x12\x0e\n\nSQL_SERVER\x10\x08\x12\x0e\n\nPOSTGRESQL\x10\t*@\n\x10PAMOperationType\x12\x07\n\x03\x41\x44\x44\x10\x00\x12\n\n\x06UPDATE\x10\x01\x12\x0b\n\x07REPLACE\x10\x02\x12\n\n\x06\x44\x45LETE\x10\x03*p\n\x16PAMOperationResultType\x12\x0f\n\x0bPOT_SUCCESS\x10\x00\x12\x15\n\x11POT_UNKNOWN_ERROR\x10\x01\x12\x16\n\x12POT_ALREADY_EXISTS\x10\x02\x12\x16\n\x12POT_DOES_NOT_EXIST\x10\x03*\\\n\x15\x43ontrollerMessageType\x12\x0f\n\x0b\x43MT_GENERAL\x10\x00\x12\x0e\n\nCMT_ROTATE\x10\x01\x12\x11\n\rCMT_DISCOVERY\x10\x02\x12\x0f\n\x0b\x43MT_CONNECT\x10\x03\x42\x1f\n\x18\x63om.keepersecurity.protoB\x03PAMb\x06proto3') +_WEBRTCCONNECTIONTYPE = DESCRIPTOR.enum_types_by_name['WebRtcConnectionType'] +WebRtcConnectionType = enum_type_wrapper.EnumTypeWrapper(_WEBRTCCONNECTIONTYPE) _PAMOPERATIONTYPE = DESCRIPTOR.enum_types_by_name['PAMOperationType'] PAMOperationType = enum_type_wrapper.EnumTypeWrapper(_PAMOPERATIONTYPE) _PAMOPERATIONRESULTTYPE = DESCRIPTOR.enum_types_by_name['PAMOperationResultType'] 
PAMOperationResultType = enum_type_wrapper.EnumTypeWrapper(_PAMOPERATIONRESULTTYPE)
_CONTROLLERMESSAGETYPE = DESCRIPTOR.enum_types_by_name['ControllerMessageType']
ControllerMessageType = enum_type_wrapper.EnumTypeWrapper(_CONTROLLERMESSAGETYPE)
+CONNECTION = 0
+TUNNEL = 1
+SSH = 2
+RDP = 3
+HTTP = 4
+VNC = 5
+TELNET = 6
+MYSQL = 7
+SQL_SERVER = 8
+POSTGRESQL = 9
ADD = 0
UPDATE = 1
REPLACE = 2
@@ -35,12 +47,14 @@
 POT_DOES_NOT_EXIST = 3
 CMT_GENERAL = 0
 CMT_ROTATE = 1
-CMT_STREAM = 2
+CMT_DISCOVERY = 2
+CMT_CONNECT = 3
 _PAMROTATIONSCHEDULE = DESCRIPTOR.message_types_by_name['PAMRotationSchedule']
 _PAMROTATIONSCHEDULESRESPONSE = DESCRIPTOR.message_types_by_name['PAMRotationSchedulesResponse']
 _PAMONLINECONTROLLER = DESCRIPTOR.message_types_by_name['PAMOnlineController']
+_PAMWEBRTCCONNECTION = DESCRIPTOR.message_types_by_name['PAMWebRtcConnection']
 _PAMONLINECONTROLLERS = DESCRIPTOR.message_types_by_name['PAMOnlineControllers']
 _PAMROTATEREQUEST = DESCRIPTOR.message_types_by_name['PAMRotateRequest']
 _PAMCONTROLLERSRESPONSE = DESCRIPTOR.message_types_by_name['PAMControllersResponse']
@@ -61,7 +75,6 @@
 _CONTROLLERRESPONSE = DESCRIPTOR.message_types_by_name['ControllerResponse']
 _PAMCONFIGURATIONCONTROLLER = DESCRIPTOR.message_types_by_name['PAMConfigurationController']
 _CONFIGURATIONADDREQUEST = DESCRIPTOR.message_types_by_name['ConfigurationAddRequest']
-_RELAYACCESSCREDS = DESCRIPTOR.message_types_by_name['RelayAccessCreds']
 PAMRotationSchedule = _reflection.GeneratedProtocolMessageType('PAMRotationSchedule', (_message.Message,), {
   'DESCRIPTOR' : _PAMROTATIONSCHEDULE,
   '__module__' : 'pam_pb2'
@@ -83,6 +96,13 @@
   })
 _sym_db.RegisterMessage(PAMOnlineController)
 
+PAMWebRtcConnection = _reflection.GeneratedProtocolMessageType('PAMWebRtcConnection', (_message.Message,), {
+  'DESCRIPTOR' : _PAMWEBRTCCONNECTION,
+  '__module__' : 'pam_pb2'
+  # @@protoc_insertion_point(class_scope:PAM.PAMWebRtcConnection)
+  })
+_sym_db.RegisterMessage(PAMWebRtcConnection)
+
 PAMOnlineControllers = _reflection.GeneratedProtocolMessageType('PAMOnlineControllers', (_message.Message,), {
   'DESCRIPTOR' : _PAMONLINECONTROLLERS,
   '__module__' : 'pam_pb2'
@@ -223,69 +243,64 @@
   })
 _sym_db.RegisterMessage(ConfigurationAddRequest)
 
-RelayAccessCreds = _reflection.GeneratedProtocolMessageType('RelayAccessCreds', (_message.Message,), {
-  'DESCRIPTOR' : _RELAYACCESSCREDS,
-  '__module__' : 'pam_pb2'
-  # @@protoc_insertion_point(class_scope:PAM.RelayAccessCreds)
-  })
-_sym_db.RegisterMessage(RelayAccessCreds)
-
 if _descriptor._USE_C_DESCRIPTORS == False:
   DESCRIPTOR._options = None
   DESCRIPTOR._serialized_options = b'\n\030com.keepersecurity.protoB\003PAM'
-  _PAMOPERATIONTYPE._serialized_start=2331
-  _PAMOPERATIONTYPE._serialized_end=2395
-  _PAMOPERATIONRESULTTYPE._serialized_start=2397
-  _PAMOPERATIONRESULTTYPE._serialized_end=2509
-  _CONTROLLERMESSAGETYPE._serialized_start=2511
-  _CONTROLLERMESSAGETYPE._serialized_end=2583
+  _WEBRTCCONNECTIONTYPE._serialized_start=2494
+  _WEBRTCCONNECTIONTYPE._serialized_end=2636
+  _PAMOPERATIONTYPE._serialized_start=2638
+  _PAMOPERATIONTYPE._serialized_end=2702
+  _PAMOPERATIONRESULTTYPE._serialized_start=2704
+  _PAMOPERATIONRESULTTYPE._serialized_end=2816
+  _CONTROLLERMESSAGETYPE._serialized_start=2818
+  _CONTROLLERMESSAGETYPE._serialized_end=2910
   _PAMROTATIONSCHEDULE._serialized_start=51
   _PAMROTATIONSCHEDULE._serialized_end=182
   _PAMROTATIONSCHEDULESRESPONSE._serialized_start=184
   _PAMROTATIONSCHEDULESRESPONSE._serialized_end=259
-  _PAMONLINECONTROLLER._serialized_start=261
-  _PAMONLINECONTROLLER._serialized_end=362
-  _PAMONLINECONTROLLERS._serialized_start=364
-  _PAMONLINECONTROLLERS._serialized_end=453
-  _PAMROTATEREQUEST._serialized_start=455
-  _PAMROTATEREQUEST._serialized_end=512
-  _PAMCONTROLLERSRESPONSE._serialized_start=514
-  _PAMCONTROLLERSRESPONSE._serialized_end=579
-  _PAMREMOVECONTROLLER._serialized_start=581
-  _PAMREMOVECONTROLLER._serialized_end=642
-  _PAMREMOVECONTROLLERRESPONSE._serialized_start=644
-  _PAMREMOVECONTROLLERRESPONSE._serialized_end=720
-  _PAMMODIFYREQUEST._serialized_start=722
-  _PAMMODIFYREQUEST._serialized_end=783
-  _PAMDATAOPERATION._serialized_start=786
-  _PAMDATAOPERATION._serialized_end=938
-  _PAMCONFIGURATIONDATA._serialized_start=940
-  _PAMCONFIGURATIONDATA._serialized_end=1041
-  _PAMELEMENTDATA._serialized_start=1043
-  _PAMELEMENTDATA._serialized_end=1112
-  _PAMELEMENTOPERATIONRESULT._serialized_start=1114
-  _PAMELEMENTOPERATIONRESULT._serialized_end=1223
-  _PAMMODIFYRESULT._serialized_start=1225
-  _PAMMODIFYRESULT._serialized_end=1291
-  _PAMELEMENT._serialized_start=1293
-  _PAMELEMENT._serialized_end=1413
-  _PAMGENERICUIDREQUEST._serialized_start=1415
-  _PAMGENERICUIDREQUEST._serialized_end=1450
-  _PAMGENERICUIDSREQUEST._serialized_start=1452
-  _PAMGENERICUIDSREQUEST._serialized_end=1489
-  _PAMCONFIGURATION._serialized_start=1492
-  _PAMCONFIGURATION._serialized_end=1663
-  _PAMCONFIGURATIONS._serialized_start=1665
-  _PAMCONFIGURATIONS._serialized_end=1731
-  _PAMCONTROLLER._serialized_start=1734
-  _PAMCONTROLLER._serialized_end=1989
-  _CONTROLLERRESPONSE._serialized_start=1991
-  _CONTROLLERRESPONSE._serialized_end=2028
-  _PAMCONFIGURATIONCONTROLLER._serialized_start=2030
-  _PAMCONFIGURATIONCONTROLLER._serialized_end=2107
-  _CONFIGURATIONADDREQUEST._serialized_start=2110
-  _CONFIGURATIONADDREQUEST._serialized_end=2273
-  _RELAYACCESSCREDS._serialized_start=2275
-  _RELAYACCESSCREDS._serialized_end=2329
+  _PAMONLINECONTROLLER._serialized_start=262
+  _PAMONLINECONTROLLER._serialized_end=410
+  _PAMWEBRTCCONNECTION._serialized_start=413
+  _PAMWEBRTCCONNECTION._serialized_end=580
+  _PAMONLINECONTROLLERS._serialized_start=582
+  _PAMONLINECONTROLLERS._serialized_end=671
+  _PAMROTATEREQUEST._serialized_start=673
+  _PAMROTATEREQUEST._serialized_end=730
+  _PAMCONTROLLERSRESPONSE._serialized_start=732
+  _PAMCONTROLLERSRESPONSE._serialized_end=797
+  _PAMREMOVECONTROLLER._serialized_start=799
+  _PAMREMOVECONTROLLER._serialized_end=860
+  _PAMREMOVECONTROLLERRESPONSE._serialized_start=862
+  _PAMREMOVECONTROLLERRESPONSE._serialized_end=938
+  _PAMMODIFYREQUEST._serialized_start=940
+  _PAMMODIFYREQUEST._serialized_end=1001
+  _PAMDATAOPERATION._serialized_start=1004
+  _PAMDATAOPERATION._serialized_end=1156
+  _PAMCONFIGURATIONDATA._serialized_start=1158
+  _PAMCONFIGURATIONDATA._serialized_end=1259
+  _PAMELEMENTDATA._serialized_start=1261
+  _PAMELEMENTDATA._serialized_end=1330
+  _PAMELEMENTOPERATIONRESULT._serialized_start=1332
+  _PAMELEMENTOPERATIONRESULT._serialized_end=1441
+  _PAMMODIFYRESULT._serialized_start=1443
+  _PAMMODIFYRESULT._serialized_end=1509
+  _PAMELEMENT._serialized_start=1511
+  _PAMELEMENT._serialized_end=1631
+  _PAMGENERICUIDREQUEST._serialized_start=1633
+  _PAMGENERICUIDREQUEST._serialized_end=1668
+  _PAMGENERICUIDSREQUEST._serialized_start=1670
+  _PAMGENERICUIDSREQUEST._serialized_end=1707
+  _PAMCONFIGURATION._serialized_start=1710
+  _PAMCONFIGURATION._serialized_end=1881
+  _PAMCONFIGURATIONS._serialized_start=1883
+  _PAMCONFIGURATIONS._serialized_end=1949
+  _PAMCONTROLLER._serialized_start=1952
+  _PAMCONTROLLER._serialized_end=2207
+  _CONTROLLERRESPONSE._serialized_start=2209
+  _CONTROLLERRESPONSE._serialized_end=2246
+  _PAMCONFIGURATIONCONTROLLER._serialized_start=2248
+  _PAMCONFIGURATIONCONTROLLER._serialized_end=2325
+  _CONFIGURATIONADDREQUEST._serialized_start=2328
+  _CONFIGURATIONADDREQUEST._serialized_end=2491
 # @@protoc_insertion_point(module_scope)
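(For readers tracing the regenerated module: the EnumTypeWrapper objects above expose the usual protobuf enum helpers. A minimal sketch of how the new controller message types might be consumed; the describe() helper is illustrative only, not part of this change.)

    from keepercommander.proto import pam_pb2

    def describe(message_type):
        # Name() maps a numeric value back to its identifier, e.g. 2 -> 'CMT_DISCOVERY'
        return pam_pb2.ControllerMessageType.Name(message_type)

    assert describe(pam_pb2.CMT_DISCOVERY) == 'CMT_DISCOVERY'
    assert describe(pam_pb2.CMT_CONNECT) == 'CMT_CONNECT'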
diff --git a/keepercommander/record_facades.py b/keepercommander/record_facades.py
index 6f5b3036d..04b5ada0c 100644
--- a/keepercommander/record_facades.py
+++ b/keepercommander/record_facades.py
@@ -87,6 +87,40 @@ def setter(obj, value):
             field.value.clear()
     return setter
 
+def boolean_getter(name):   # type: (str) -> Callable[[TypedRecordFacade], bool]
+    def getter(obj):
+        field = getattr(obj, name)
+        if isinstance(field, TypedField):
+            value = field.value[0] if len(field.value) > 0 else None
+            if value is None:
+                return None
+            elif isinstance(value, bool):
+                return value
+
+            if str(value).lower() in ['true', 'yes', '1', 'on']:
+                return True
+            elif str(value).lower() in ['false', 'no', '0', 'off']:
+                return False
+        return None
+    return getter
+
+def boolean_setter(name):   # type: (str) -> Callable[[Any, str], None]
+    def setter(obj, value):
+        field = getattr(obj, name)
+        if isinstance(field, TypedField):
+            if value is not None:
+                if not isinstance(value, bool):
+                    if str(value).lower() in ['true', 'yes', '1', 'on']:
+                        value = True
+                    elif str(value).lower() in ['false', 'no', '0', 'off']:
+                        value = False
+                if len(field.value) > 0:
+                    field.value[0] = value
+                else:
+                    field.value.append(value)
+            else:
+                field.value.clear()
+    return setter
 
 def string_element_getter(name, element_name):   # type: (str, str) -> Callable[[Any], str]
     def getter(obj):
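(The boolean accessors above mirror the string accessors earlier in record_facades.py and appear designed to be paired through property() on a facade class. A minimal sketch under that assumption; the ExampleFacade class and its '_enabled' field are hypothetical, not part of this change.)

    from keepercommander.record_facades import TypedRecordFacade, boolean_getter, boolean_setter

    class ExampleFacade(TypedRecordFacade):
        # '_enabled' would be a TypedField populated when a record is attached to the facade
        enabled = property(boolean_getter('_enabled'), boolean_setter('_enabled'))

Reading facade.enabled then yields True/False (tolerating 'yes'/'no'/'1'/'0' string values in the underlying field), and assigning facade.enabled = 'on' stores a normalized boolean back into the field.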
diff --git a/keepercommander/utils.py b/keepercommander/utils.py
index 2acd410ba..01b47de87 100644
--- a/keepercommander/utils.py
+++ b/keepercommander/utils.py
@@ -325,7 +325,6 @@ def size_to_str(size):  # type: (int) -> str
         size = size / 1024
     return f'{size:,.2f} Gb'
 
-
 def parse_totp_uri(uri):  # type: (str) -> Dict[str, Union[str, int, None]]
     def parse_int(val):
         return val and int(val)
@@ -363,3 +362,15 @@ def decode_uri_component(component):  # type: (str) -> str
     }
 
     return result
+
+def value_to_boolean(value):
+    """
+    Replacement for distutils.util.strtobool
+    """
+    value = str(value)
+    if value.lower() in ['true', 'yes', 'on', '1']:
+        return True
+    elif value.lower() in ['false', 'no', 'off', '0']:
+        return False
+    else:
+        return None
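(A quick illustration of the helper above, grounded in its branches; note that unlike distutils.util.strtobool, which raised ValueError on unrecognized input, value_to_boolean returns None.)

    from keepercommander.utils import value_to_boolean

    assert value_to_boolean('Yes') is True
    assert value_to_boolean('OFF') is False
    assert value_to_boolean('maybe') is None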
diff --git a/keepercommander/vault.py b/keepercommander/vault.py
index bb3df936d..bed0c78e5 100644
--- a/keepercommander/vault.py
+++ b/keepercommander/vault.py
@@ -362,7 +362,7 @@ def export_host_field(value):  # type: (dict) -> Optional[str]
         port = value.get('port') or ''
         if host or port:
             if port:
-                host += ':' + port
+                host += ':' + str(port)
         return host
 
     @staticmethod
diff --git a/libs/discovery_common-1.0.26-py3-none-any.whl b/libs/discovery_common-1.0.26-py3-none-any.whl
new file mode 100644
index 000000000..2ad348f83
Binary files /dev/null and b/libs/discovery_common-1.0.26-py3-none-any.whl differ
diff --git a/libs/keeper_dag-1.0.20-py3-none-any.whl b/libs/keeper_dag-1.0.20-py3-none-any.whl
new file mode 100644
index 000000000..e08abaa0a
Binary files /dev/null and b/libs/keeper_dag-1.0.20-py3-none-any.whl differ
diff --git a/requirements.txt b/requirements.txt
index 04deedae8..a0d29ef8c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,4 +14,17 @@ requests>=2.30.0; python_version>='3.7'
 cryptography>=39.0.1
 protobuf>=3.19.0
 keeper-secrets-manager-core>=16.6.0
-aiortc; python_version>='3.8' and python_version<'3.13'
\ No newline at end of file
+aiortc; python_version>='3.8' and python_version<'3.13'
+
+pydantic>=2.6.4
+
+# pip uninstall keeper-dag -y
+# python3 setup.py wheel --whlsrc ~/src/keeper-dag --libdir $PWD/libs --reqfiles $PWD/requirements.txt
+# pip install $(ls libs/keeper_dag-*)
+./libs/keeper_dag-1.0.20-py3-none-any.whl
+
+
+# pip uninstall discovery-common -y
+# python3 setup.py wheel --whlsrc ~/src/discovery-common --libdir $PWD/libs --reqfiles $PWD/requirements.txt
+# pip install $(ls libs/discovery_common-*)
+./libs/discovery_common-1.0.26-py3-none-any.whl
diff --git a/setup.py b/setup.py
index 6b40b52bf..6f8e94a0f 100644
--- a/setup.py
+++ b/setup.py
@@ -1,4 +1,102 @@
 from setuptools import setup
+from setuptools.command.install import install as install_command
+import os
+import subprocess
+import shutil
+import re
+
+
+class Wheel(install_command):
+
+    user_options = install_command.user_options + [
+        ('whlsrc=', None, "Build a wheel for the python code that is in this directory. Copy into 'libs' directory."),
+        ('libdir=', None, "The directory to put the whl files."),
+        ('reqfiles=', None, "List of requirement.txt to update."),
+    ]
+
+    def initialize_options(self):
+        install_command.initialize_options(self)
+        self.whlsrc = None
+        self.libdir = None
+        self.reqfiles = None
+
+    def finalize_options(self):
+        install_command.finalize_options(self)
+
+    def run(self):
+        global whlsrc
+        global libdir
+        global reqfiles
+        whlsrc = self.whlsrc
+        libdir = self.libdir
+        reqfiles = self.reqfiles
+
+        if not isinstance(reqfiles, list):
+            reqfiles = [reqfiles]
+
+        current_dir = os.getcwd()
+        try:
+            # Get the existing files in the lib directory.
+            os.chdir(self.libdir)
+            sp = subprocess.run(["ls"], capture_output=True, text=True)
+            existing_whls = []
+            for file in sp.stdout.split("\n"):
+                if file.endswith("whl"):
+                    existing_whls.append(file)
+
+            # Install the required modules and build a wheel.
+            os.chdir(whlsrc)
+            subprocess.run(["pip3", "install", "-r", "requirements.txt"])
+            subprocess.run(["python3", "setup.py", "bdist_wheel"])
+
+            # Find the whl file in the dist folder.
+            os.chdir(os.path.join(whlsrc, "dist"))
+            sp = subprocess.run(["ls"], capture_output=True, text=True)
+            wheel_file = None
+            for file in sp.stdout.split("\n"):
+                if file.endswith("whl"):
+                    wheel_file = file
+                    break
+            if wheel_file is None:
+                raise ValueError(f"Cannot find a whl file in the dist directory of the {whlsrc} project.")
+
+            # Copy the whl to the lib directory.
+            subprocess.run(["cp", wheel_file, self.libdir])
+
+            project_name = wheel_file[:wheel_file.index("-")]
+
+            # Remove old versions of the wheel.
+            os.chdir(self.libdir)
+            for existing_whl in existing_whls:
+                if not existing_whl.startswith(project_name):
+                    continue
+                if existing_whl == wheel_file:
+                    continue
+                os.unlink(existing_whl)
+
+            for req in reqfiles:
+                shutil.copy(req, f"{req}.bak")
+                requirement_data = []
+                with open(req, "r") as fh:
+                    requirement_data = fh.readlines()
+                    fh.close()
+
+                pattern = re.compile(re.escape(project_name) + "-.*?.whl")
+                with open(req, "w") as fh:
+                    for line in requirement_data:
+                        line = re.sub(pattern, wheel_file, line)
+                        fh.write(line)
+                    fh.close()
+                os.unlink(f"{req}.bak")
+
+        finally:
+            os.chdir(current_dir)
+
+
 if __name__ == '__main__':
-    setup()
+    setup(
+        cmdclass={
+            'wheel': Wheel
+        }
+    )
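(The custom `wheel` command defined above is what the comments embedded in requirements.txt invoke, e.g. `python3 setup.py wheel --whlsrc ~/src/keeper-dag --libdir $PWD/libs --reqfiles $PWD/requirements.txt`: run() builds a wheel from the --whlsrc checkout, copies it into --libdir, prunes older wheels for the same project, and rewrites the matching `<project>-*.whl` pin in each --reqfiles entry to the fresh filename.)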
diff --git a/unit-tests/pam/test_pam_rotation.py b/unit-tests/pam/test_pam_rotation.py
new file mode 100644
index 000000000..9d35ea86f
--- /dev/null
+++ b/unit-tests/pam/test_pam_rotation.py
@@ -0,0 +1,506 @@
+import json
+import unittest
+from datetime import datetime
+from unittest.mock import patch, MagicMock
+
+import requests
+from cryptography.hazmat.primitives.asymmetric import ec
+from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
+
+from keepercommander import crypto, utils
+from keepercommander.commands.discoveryrotation import (PAMCreateRecordRotationCommand, PAMListRecordRotationCommand,
+                                                        PAMGatewayListCommand)
+from keepercommander.error import CommandError
+import keepercommander.vault as vault
+
+
+def create_mock_params_and_record(record_type='pamUser'):
+    mock_params = MagicMock()
+    mock_params.rest_context.server_key_id = 8
+    mock_params.session_token = 'base64_encoded_session_token'  # Mock a base64 encoded session token
+    mock_params.record_cache = {'record_uid': MagicMock(record_type='pamUser')}
+    mock_params.subfolder_record_cache = {'folder_uid': ['record_uid']}
+    mock_params.folder_cache = {'folder_uid': MagicMock()}
+    mock_params.record_rotation_cache = {
+        'record_uid': {
+            'pwd_complexity': 'eyJ0eXBlIjogInBhc3N3b3JkX2NvbXBsZXhpdHkiLCAidmFsdWUiOiAiY29tcGxleGl0eV92YWx1ZSJ9',
+            'configuration_uid': 'config_uid',
+            'schedule': '[]',
+            'resourceUid': 'resource_uid',
+            'revision': 1  # Ensure revision is set
+        }
+    }
+    mock_params.rest_context.server_base = 'https://fake.keepersecurity.com'  # Mock URL as string
+
+    mock_typed_record = MagicMock(spec=vault.TypedRecord)
+    mock_typed_record.record_type = record_type
+    mock_typed_record.record_uid = 'record_uid'
+    mock_typed_record.title = 'Mock Title'  # Add the title attribute
+    mock_typed_record.record_key = b'\x00' * 16  # Add the record_key attribute
+
+    return mock_params, mock_typed_record
+
+
+def create_mock_params():
+    mock_params = MagicMock()
+    mock_params.rest_context.server_key_id = 8
+    mock_params.session_token = 'base64_encoded_session_token'
+    mock_params.record_cache = {
+        'record_uid': {
+            'data_unencrypted': json.dumps({'title': 'Mock Title', 'type': 'pamMachine'})
+        }
+    }
+    mock_params.subfolder_record_cache = {'folder_uid': ['record_uid']}
+    mock_params.folder_cache = {'folder_uid': MagicMock()}
+    mock_params.rest_context.server_base = 'https://fake.keepersecurity.com'
+
+    return mock_params
+
+
+class TestPAMCreateRecordRotationCommand(unittest.TestCase):
+
+    def setUp(self):
+        self.command = PAMCreateRecordRotationCommand()
+        self.parser = self.command.get_parser()
+        self.transmission_key = b'transmission_key'
+        self.session_token = b'encrypted_session_token'
+        self.private_key = ec.generate_private_key(ec.SECP256R1())
+        self.public_key = self.private_key.public_key()
+
+        # Serialize and deserialize the public key to ensure compatibility
+        public_key_bytes = self.public_key.public_bytes(Encoding.X962, PublicFormat.UncompressedPoint)
+        loaded_public_key = ec.EllipticCurvePublicKey.from_encoded_point(ec.SECP256R1(), public_key_bytes)
+
+        self.encrypted_transmission_key = crypto.encrypt_ec(self.transmission_key, loaded_public_key)
+        self.encrypted_session_token = crypto.encrypt_aes_v2(self.session_token, self.transmission_key)
+
+    def test_parser(self):
+        args = self.parser.parse_args(['--record', 'record_uid', '--force'])
+        self.assertEqual(args.record_name, 'record_uid')
+        self.assertTrue(args.force)
+
+    @patch('keepercommander.vault.KeeperRecord.load')
+    @patch('keepercommander.commands.discoveryrotation.TunnelDAG')
+    @patch('keepercommander.rest_api.SERVER_PUBLIC_KEYS', {8: ec.generate_private_key(ec.SECP256R1()).public_key()})
+    def test_execute_with_folder(self, mock_TunnelDAG, mock_load):
+        mock_params, mock_typed_record = create_mock_params_and_record()
+
+        mock_load.return_value = mock_typed_record
+
+        mock_dag_instance = mock_TunnelDAG.return_value
+        mock_dag_instance.linking_dag.has_graph = True
+        mock_dag_instance.check_if_resource_has_admin.return_value = True
+        mock_dag_instance.get_all_owners.return_value = ['resource_uid']
+        mock_dag_instance.resource_belongs_to_config.return_value = True
+        mock_dag_instance.user_belongs_to_resource.return_value = True
+        mock_dag_instance.record.record_uid = 'config_uid'  # Ensure it returns a string
+
+        kwargs = {
+            'folder_name': 'folder_uid',
+            'force': True  # Add force to the kwargs
+        }
+
+        self.command.execute(mock_params, **kwargs)
+        self.assertTrue(mock_load.called)
+        self.assertTrue(mock_TunnelDAG.called)
+
+    @patch('keepercommander.vault.KeeperRecord.load', return_value=None)
+    @patch('keepercommander.commands.discoveryrotation.TunnelDAG')
+    @patch('keepercommander.rest_api.SERVER_PUBLIC_KEYS', {8: ec.generate_private_key(ec.SECP256R1()).public_key()})
+    def test_execute_with_no_record(self, mock_TunnelDAG, mock_load):
+        mock_params, _ = create_mock_params_and_record()
+        mock_params.record_cache = {}
+
+        kwargs = {
+            'record_name': 'non_existent_record',
+            'force': True  # Add force to the kwargs
+        }
+
+        with self.assertRaises(CommandError):
+            self.command.execute(mock_params, **kwargs)
+
+    @patch('keepercommander.vault.KeeperRecord.load', return_value=None)
+    @patch('keepercommander.commands.discoveryrotation.TunnelDAG')
+    @patch('keepercommander.rest_api.SERVER_PUBLIC_KEYS', {8: ec.generate_private_key(ec.SECP256R1()).public_key()})
+    def test_execute_with_invalid_password_complexity(self, mock_TunnelDAG, mock_load):
+        mock_params, _ = create_mock_params_and_record()
+
+        kwargs = {
+            'record_name': 'record_uid',
+            'pwd_complexity': 'invalid_complexity',
+            'force': True  # Add force to the kwargs
+        }
+
+        with self.assertRaises(CommandError):
+            self.command.execute(mock_params, **kwargs)
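+
+    # Note: the complexity string '32,5,5,5,5' in the next test appears to encode
+    # length, uppercase, lowercase, digits, and special characters, in that order;
+    # the per-position meaning is an inference from usage, not a documented contract.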
+    @patch('keepercommander.vault.KeeperRecord.load')
+    @patch('keepercommander.commands.discoveryrotation.TunnelDAG')
+    @patch('keepercommander.rest_api.SERVER_PUBLIC_KEYS', {8: ec.generate_private_key(ec.SECP256R1()).public_key()})
+    def test_execute_with_valid_password_complexity(self, mock_TunnelDAG, mock_load):
+        mock_params, mock_typed_record = create_mock_params_and_record()
+
+        mock_load.return_value = mock_typed_record
+
+        mock_dag_instance = mock_TunnelDAG.return_value
+        mock_dag_instance.linking_dag.has_graph = True
+        mock_dag_instance.check_if_resource_has_admin.return_value = True
+        mock_dag_instance.get_all_owners.return_value = ['resource_uid']
+        mock_dag_instance.resource_belongs_to_config.return_value = True
+        mock_dag_instance.user_belongs_to_resource.return_value = True
+        mock_dag_instance.record.record_uid = 'config_uid'  # Ensure it returns a string
+
+        kwargs = {
+            'record_name': 'record_uid',
+            'pwd_complexity': '32,5,5,5,5',
+            'force': True  # Add force to the kwargs
+        }
+
+        self.command.execute(mock_params, **kwargs)
+        self.assertTrue(mock_load.called)
+        self.assertEqual(mock_typed_record.record_key, b'\x00' * 16)
+
+    @patch('keepercommander.vault.KeeperRecord.load')
+    @patch('keepercommander.commands.discoveryrotation.TunnelDAG')
+    @patch('keepercommander.rest_api.SERVER_PUBLIC_KEYS', {8: ec.generate_private_key(ec.SECP256R1()).public_key()})
+    def test_execute_with_valid_record(self, mock_TunnelDAG, mock_load):
+        mock_params, mock_typed_record = create_mock_params_and_record()
+
+        mock_load.return_value = mock_typed_record
+
+        mock_dag_instance = mock_TunnelDAG.return_value
+        mock_dag_instance.linking_dag.has_graph = True
+        mock_dag_instance.check_if_resource_has_admin.return_value = True
+        mock_dag_instance.get_all_owners.return_value = ['resource_uid']
+        mock_dag_instance.resource_belongs_to_config.return_value = True
+        mock_dag_instance.user_belongs_to_resource.return_value = True
+        mock_dag_instance.record.record_uid = 'config_uid'  # Ensure it returns a string
+
+        kwargs = {
+            'record_name': 'record_uid',
+            'force': True  # Add force to the kwargs
+        }
+
+        self.command.execute(mock_params, **kwargs)
+        self.assertTrue(mock_load.called)
+        self.assertTrue(mock_TunnelDAG.called)
+        self.assertEqual(mock_typed_record.record_key, b'\x00' * 16)
+
+
+class TestPAMResourceRotateCommand(unittest.TestCase):
+
+    def setUp(self):
+        self.command = PAMCreateRecordRotationCommand()
+        self.parser = self.command.get_parser()
+
+    def test_parser(self):
+        args = self.parser.parse_args(['--record', "abcdefg", '--enable'])
+        self.assertEqual(args.record_name, 'abcdefg')
+        self.assertTrue(args.enable)
+
+    @patch('keepercommander.vault_extensions.find_records')
+    @patch('keepercommander.vault.KeeperRecord.load')
+    @patch('keepercommander.commands.discoveryrotation.get_keeper_tokens')
+    @patch('keepercommander.commands.discoveryrotation.TunnelDAG')
+    def test_execute_with_enable(self, mock_tunneldag, mock_get_keeper_tokens, mock_load, mock_find_records):
+        mock_dag_instance = mock_tunneldag.return_value
+        mock_dag_instance.linking_dag.has_graph = True
+        mock_dag_instance.resource_belongs_to_config.return_value = True
+
+        mock_get_keeper_tokens.return_value = (b'token', b'encrypted_key', b'transmission_key')
+
+        mock_params, mock_typed_record = create_mock_params_and_record('pamMachine')
+        mock_load.return_value = mock_typed_record
+
+        mock_pam_config_record = MagicMock(spec=vault.TypedRecord)
+        mock_pam_config_record.record_uid = 'config_uid'
+        mock_pam_config_record.record_type = 'pamConfiguration'  # Use a valid PAM configuration record type
+        mock_find_records.return_value = [mock_pam_config_record]
+
+        kwargs = {
+            'record_name': 'record_uid',
+            'enable': True,
+            'config_uid': 'config_uid'
+        }
+
+        self.command.execute(mock_params, **kwargs)
+        self.assertTrue(mock_load.called)
+        self.assertTrue(mock_tunneldag.called)
+        self.assertTrue(mock_get_keeper_tokens.called)
+
+    @patch('keepercommander.vault.KeeperRecord.load', return_value=None)
+    def test_execute_with_invalid_uid(self, mock_load):
+        mock_params, _ = create_mock_params_and_record('pamMachine')
+
+        kwargs = {
+            'record_name': 'invalid_uid',
+            'enable': True
+        }
+
+        with self.assertRaises(CommandError):
+            self.command.execute(mock_params, **kwargs)
+
+    @patch('keepercommander.vault.KeeperRecord.load')
+    def test_execute_with_invalid_record_type(self, mock_load):
+        mock_params, mock_typed_record = create_mock_params_and_record(record_type='invalid_type')
+        mock_load.return_value = mock_typed_record
+
+        kwargs = {
+            'record_name': 'record_uid',
+            'enable': True
+        }
+
+        with self.assertRaises(CommandError):
+            self.command.execute(mock_params, **kwargs)
+
+    @patch('keepercommander.vault_extensions.find_records')
+    @patch('keepercommander.vault.KeeperRecord.load')
+    @patch('keepercommander.commands.discoveryrotation.get_keeper_tokens')
+    @patch('keepercommander.commands.discoveryrotation.TunnelDAG')
+    def test_execute_with_disable(self, mock_tunneldag, mock_get_keeper_tokens, mock_load, mock_find_records):
+        mock_dag_instance = mock_tunneldag.return_value
+        mock_dag_instance.linking_dag.has_graph = True
+        mock_dag_instance.resource_belongs_to_config.return_value = True
+
+        mock_get_keeper_tokens.return_value = (b'token', b'encrypted_key', b'transmission_key')
+
+        mock_params, mock_typed_record = create_mock_params_and_record('pamMachine')
+        mock_load.return_value = mock_typed_record
+
+        mock_pam_config_record = MagicMock(spec=vault.TypedRecord)
+        mock_pam_config_record.record_uid = 'config_uid'
+        mock_pam_config_record.record_type = 'pamConfiguration'  # Use a valid PAM configuration record type
+        mock_find_records.return_value = [mock_pam_config_record]
+
+        kwargs = {
+            'record_name': 'record_uid',
+            'disable': True,
+            'config_uid': 'config_uid'
+        }
+
+        self.command.execute(mock_params, **kwargs)
+        self.assertTrue(mock_load.called)
+        self.assertTrue(mock_tunneldag.called)
+        self.assertTrue(mock_get_keeper_tokens.called)
+
+    @patch('keepercommander.vault_extensions.find_records')
+    @patch('keepercommander.vault.KeeperRecord.load')
+    @patch('keepercommander.commands.discoveryrotation.get_keeper_tokens')
+    @patch('keepercommander.commands.discoveryrotation.TunnelDAG')
+    def test_execute_with_enable_and_admin(self, mock_tunneldag, mock_get_keeper_tokens, mock_load, mock_find_records):
+        mock_dag_instance = mock_tunneldag.return_value
+        mock_dag_instance.linking_dag.has_graph = True
+        mock_dag_instance.resource_belongs_to_config.return_value = True
+
+        mock_get_keeper_tokens.return_value = (b'token', b'encrypted_key', b'transmission_key')
+
+        mock_params, mock_typed_record = create_mock_params_and_record('pamMachine')
+        mock_load.return_value = mock_typed_record
+
+        mock_pam_config_record = MagicMock(spec=vault.TypedRecord)
+        mock_pam_config_record.record_uid = 'config_uid'
+        mock_pam_config_record.record_type = 'pamConfiguration'  # Use a valid PAM configuration record type
+        mock_find_records.return_value = [mock_pam_config_record]
+
+        kwargs = {
+            'record_name': 'record_uid',
+            'enable': True,
+            'config_uid': 'config_uid',
+            'admin': 'admin_uid'
+        }
+
+        self.command.execute(mock_params, **kwargs)
+        self.assertTrue(mock_load.called)
+        self.assertTrue(mock_tunneldag.called)
+        self.assertTrue(mock_get_keeper_tokens.called)
+        mock_dag_instance.link_user_to_resource.assert_called_with('admin_uid', 'record_uid', is_admin=True)
+
+
+class TestPAMListRecordRotationCommand(unittest.TestCase):
+
+    def setUp(self):
+        self.command = PAMListRecordRotationCommand()
+        self.parser = self.command.get_parser()
+
+    def test_parser(self):
+        args = self.parser.parse_args(['--verbose'])
+        self.assertTrue(args.is_verbose)
+
+    @patch('keepercommander.commands.discoveryrotation.router_get_rotation_schedules')
+    @patch('keepercommander.commands.discoveryrotation.router_get_connected_gateways')
+    @patch('keepercommander.commands.discoveryrotation.pam_configurations_get_all')
+    @patch('keepercommander.commands.discoveryrotation.gateway_helper.get_all_gateways')
+    @patch('keepercommander.commands.discoveryrotation.pam_decrypt_configuration_data')
+    @patch('keepercommander.commands.discoveryrotation.dump_report_data')
+    def test_execute(self, mock_dump_report_data, mock_pam_decrypt_configuration_data, mock_get_all_gateways,
+                     mock_pam_configurations_get_all, mock_router_get_connected_gateways,
+                     mock_router_get_rotation_schedules):
+        mock_params = create_mock_params()
+
+        # Mock the return values
+        mock_router_get_rotation_schedules.return_value.schedules = [
+            MagicMock(
+                recordUid=utils.base64_url_decode('record_uid'),
+                controllerUid=utils.base64_url_decode('controller_uid'),
+                configurationUid=utils.base64_url_decode('config_uid'),
+                noSchedule=False,
+                scheduleData='RotateActionJob|daily.0.12.1'
+            )
+        ]
+
+        mock_get_all_gateways.return_value = [
+            MagicMock(controllerUid=utils.base64_url_decode('controller_uid'), controllerName='Controller Name')
+        ]
+
+        mock_router_get_connected_gateways.return_value.controllers = [
+            MagicMock(controllerUid=utils.base64_url_decode('controller_uid'))
+        ]
+
+        mock_pam_configurations_get_all.return_value = [
+            {'record_uid': 'config_uid', 'data_unencrypted': json.dumps({'title': 'Config Title', 'type': 'pamConfig'})}
+        ]
+
+        mock_pam_decrypt_configuration_data.return_value = {
+            'title': 'Config Title',
+            'type': 'pamConfig'
+        }
+
+        kwargs = {'is_verbose': True}
+        self.command.execute(mock_params, **kwargs)
+
+        self.assertTrue(mock_dump_report_data.called)
+        self.assertTrue(mock_router_get_rotation_schedules.called)
+        self.assertTrue(mock_get_all_gateways.called)
+        self.assertTrue(mock_router_get_connected_gateways.called)
+        self.assertTrue(mock_pam_configurations_get_all.called)
+
+    @patch('keepercommander.commands.discoveryrotation.router_get_rotation_schedules')
+    @patch('keepercommander.commands.discoveryrotation.router_get_connected_gateways')
+    @patch('keepercommander.commands.discoveryrotation.pam_configurations_get_all')
+    @patch('keepercommander.commands.discoveryrotation.gateway_helper.get_all_gateways')
+    @patch('keepercommander.commands.discoveryrotation.pam_decrypt_configuration_data')
+    @patch('keepercommander.commands.discoveryrotation.dump_report_data')
+    def test_execute_with_no_schedules(self, mock_dump_report_data, mock_pam_decrypt_configuration_data,
+                                       mock_get_all_gateways, mock_pam_configurations_get_all,
+                                       mock_router_get_connected_gateways, mock_router_get_rotation_schedules):
+        mock_params = create_mock_params()
+
+        # Mock the return values
+        mock_router_get_rotation_schedules.return_value.schedules = []
+
+        mock_get_all_gateways.return_value = []
+
+        mock_router_get_connected_gateways.return_value.controllers = []
+
+        mock_pam_configurations_get_all.return_value = []
+
+        mock_pam_decrypt_configuration_data.return_value = {}
+
+        kwargs = {'is_verbose': True}
+        self.command.execute(mock_params, **kwargs)
+
+        self.assertTrue(mock_dump_report_data.called)
+        self.assertTrue(mock_router_get_rotation_schedules.called)
+        self.assertTrue(mock_get_all_gateways.called)
+        self.assertTrue(mock_router_get_connected_gateways.called)
+        self.assertTrue(mock_pam_configurations_get_all.called)
+
+
+class TestPAMGatewayListCommand(unittest.TestCase):
+
+    def setUp(self):
+        self.command = PAMGatewayListCommand()
+        self.parser = self.command.get_parser()
+
+    def test_parser(self):
+        args = self.parser.parse_args(['--verbose', '--force'])
+        self.assertTrue(args.is_verbose)
+        self.assertTrue(args.is_force)
+
+    @patch('keepercommander.commands.discoveryrotation.router_get_connected_gateways')
+    @patch('keepercommander.commands.discoveryrotation.router_helper.get_router_url')
+    @patch('keepercommander.commands.discoveryrotation.gateway_helper.get_all_gateways')
+    @patch('keepercommander.commands.discoveryrotation.KSMCommand.get_app_record')
+    @patch('keepercommander.commands.discoveryrotation.dump_report_data')
+    def test_execute(self, mock_dump_report_data, mock_get_app_record, mock_get_all_gateways,
+                     mock_get_router_url, mock_router_get_connected_gateways):
+        mock_params = create_mock_params()
+
+        # Mock the return values
+        mock_router_get_connected_gateways.return_value.controllers = [
+            MagicMock(controllerUid=utils.base64_url_decode('controller_uid'))
+        ]
+
+        mock_get_all_gateways.return_value = [
+            MagicMock(
+                applicationUid=utils.base64_url_decode('app_uid'),
+                controllerUid=utils.base64_url_decode('controller_uid'),
+                controllerName='Controller Name',
+                deviceName='Device Name',
+                deviceToken='Device Token',
+                created=int(datetime.now().timestamp() * 1000),
+                lastModified=int(datetime.now().timestamp() * 1000),
+                nodeId='Node ID'
+            )
+        ]
+
+        mock_get_app_record.return_value = {
+            'data_unencrypted': json.dumps({'title': 'App Title'})
+        }
+
+        kwargs = {'is_force': True, 'is_verbose': True}
+        self.command.execute(mock_params, **kwargs)
+
+        self.assertTrue(mock_dump_report_data.called)
+        self.assertTrue(mock_router_get_connected_gateways.called)
+        self.assertTrue(mock_get_all_gateways.called)
+        self.assertTrue(mock_get_router_url.called)
+        self.assertTrue(mock_get_app_record.called)
+
+    @patch('keepercommander.commands.discoveryrotation.router_get_connected_gateways')
+    @patch('keepercommander.commands.discoveryrotation.router_helper.get_router_url')
+    @patch('keepercommander.commands.discoveryrotation.gateway_helper.get_all_gateways')
+    @patch('keepercommander.commands.discoveryrotation.dump_report_data')
+    def test_execute_router_down(self, mock_dump_report_data, mock_get_all_gateways,
+                                 mock_get_router_url, mock_router_get_connected_gateways):
+        mock_params = create_mock_params()
+
+        # Simulate a connection error
+        mock_router_get_connected_gateways.side_effect = requests.exceptions.ConnectionError
+
+        mock_get_all_gateways.return_value = [
+            MagicMock(
+                applicationUid=utils.base64_url_decode('app_uid'),
+                controllerUid=utils.base64_url_decode('controller_uid'),
+                controllerName='Controller Name',
+                deviceName='Device Name',
+                deviceToken='Device Token',
+                created=int(datetime.now().timestamp() * 1000),
+                lastModified=int(datetime.now().timestamp() * 1000),
+                nodeId='Node ID'
+            )
+        ]
+
+        kwargs = {'is_force': True, 'is_verbose': True}
+        self.command.execute(mock_params, **kwargs)
+
+        self.assertTrue(mock_dump_report_data.called)
+        self.assertTrue(mock_router_get_connected_gateways.called)
+        self.assertTrue(mock_get_all_gateways.called)
+        self.assertTrue(mock_get_router_url.called)
+
+    @patch('keepercommander.commands.discoveryrotation.router_get_connected_gateways')
+    @patch('keepercommander.commands.discoveryrotation.router_helper.get_router_url')
+    @patch('keepercommander.commands.discoveryrotation.gateway_helper.get_all_gateways')
+    def test_execute_no_gateways(self, mock_get_all_gateways,
+                                 mock_get_router_url, mock_router_get_connected_gateways):
+        mock_params = create_mock_params()
+
+        mock_router_get_connected_gateways.return_value.controllers = []
+
+        mock_get_all_gateways.return_value = []
+
+        kwargs = {'is_force': True, 'is_verbose': True}
+        self.command.execute(mock_params, **kwargs)
+
+        self.assertTrue(mock_router_get_connected_gateways.called)
+        self.assertTrue(mock_get_all_gateways.called)
+        self.assertTrue(mock_get_router_url.called)
\ No newline at end of file
diff --git a/unit-tests/pam-tunnel/test_pam_tunnel.py b/unit-tests/pam/test_pam_tunnel.py
similarity index 100%
rename from unit-tests/pam-tunnel/test_pam_tunnel.py
rename to unit-tests/pam/test_pam_tunnel.py
diff --git a/unit-tests/pam-tunnel/test_private_tunnel.py b/unit-tests/pam/test_private_tunnel.py
similarity index 98%
rename from unit-tests/pam-tunnel/test_private_tunnel.py
rename to unit-tests/pam/test_private_tunnel.py
index 1754021de..1d4e19ff5 100644
--- a/unit-tests/pam-tunnel/test_private_tunnel.py
+++ b/unit-tests/pam/test_private_tunnel.py
@@ -30,7 +30,6 @@ async def asyncSetUp(self):
         self.private_key, self.private_key_str = new_private_key()
         self.logger = mock.MagicMock(spec=logging)
-        self.kill_server_event = asyncio.Event()
         self.tunnel_symmetric_key = utils.generate_aes_key()
         self.pc = mock.MagicMock(sepc=WebRTCConnection)
         self.pc.endpoint_name = self.endpoint_name
@@ -143,8 +142,9 @@ async def test_forward_data_to_local_error(self):
 
     async def test_process_close_connection_message(self):
         with mock.patch.object(self.pte, 'close_connection', new_callable=mock.AsyncMock) as mock_close:
-            await self.pte.process_control_message(ControlMessage.CloseConnection,
-                                                   int.to_bytes(1, byteorder='big', length=CONNECTION_NO_LENGTH))
+            data = (int.to_bytes(1, byteorder='big', length=CONNECTION_NO_LENGTH) +
+                    int_to_bytes(CloseConnectionReasons.Normal.value))
+            await self.pte.process_control_message(ControlMessage.CloseConnection, data)
             mock_close.assert_called_with(1, CloseConnectionReasons.Normal)
 
     async def test_process_pong_message(self):
@@ -160,9 +160,9 @@ async def test_process_pong_message(self):
     async def test_process_ping_message(self):
         with mock.patch.object(self.pte, 'send_control_message', new_callable=mock.AsyncMock) as mock_send:
             self.pte.logger = mock.MagicMock()
-            await self.pte.process_control_message(ControlMessage.Ping, b'')
-            self.pte.logger.debug.assert_called_with('Endpoint TestEndpoint: Received ping request')
-            mock_send.assert_called_with(ControlMessage.Pong, b'\x00')
+            await self.pte.process_control_message(ControlMessage.Ping, b'\x00\x00\x00\x00')
+            self.pte.logger.debug.assert_called_with('Endpoint TestEndpoint: Received Ping for 0')
+            mock_send.assert_called_with(ControlMessage.Pong, b'\x00\x00\x00\x00')
 
     async def test_start_server(self):
         with mock.patch('asyncio.start_server', new_callable=mock.AsyncMock) as mock_open_connection, \
diff --git a/unit-tests/pam/test_socks_server.py b/unit-tests/pam/test_socks_server.py
new file mode 100644
index 000000000..744ed4a07
--- /dev/null
+++ b/unit-tests/pam/test_socks_server.py
@@ -0,0 +1,204 @@
+import asyncio
+import logging
+import sys
+import unittest
+from unittest import mock
+from unittest.mock import Mock
+
+from keepercommander.commands.tunnel.port_forward.endpoint import WebRTCConnection, \
+    CloseConnectionReasons
+
+if sys.version_info >= (3, 8):
+    from keepercommander.commands.tunnel.port_forward.endpoint import SOCKS5Server
+
+
+    class TestSOCKSServer(unittest.IsolatedAsyncioTestCase):
+        async def asyncSetUp(self):
+
+            # Set up asyncio event loop for testing
+            self.loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(self.loop)
+            self.host = 'localhost'
+            self.port = 8080
+            self.kill_server_event = asyncio.Event()
+            self.connect_task = mock.MagicMock(spec=asyncio.Task)
+            self.logger = mock.MagicMock(spec=logging)
+            self.pc = mock.MagicMock(spec=WebRTCConnection)
+            self.pc.endpoint_name = 'TestEndpoint'
+            self.pc.data_channel.readyState = 'open'
+            self.pc.data_channel.bufferedAmount = 0
+            self.print_ready_event = asyncio.Event()
+            self.pte = SOCKS5Server(self.host, self.port, self.pc, self.print_ready_event, self.logger,
+                                    self.connect_task, self.kill_server_event)
+
+            self.reader = mock.AsyncMock()
+            self.writer = mock.AsyncMock()
+
+        async def asyncTearDown(self):
+            await self.pte.stop_server(CloseConnectionReasons.Normal)  # ensure the server is stopped after test
+
+        # def test_username_password_authenticate(self):
+        #     # Example test for the username/password authentication method
+        #
+        #     # Mock reader and writer streams
+        #     reader = mock.AsyncMock()
+        #     writer = mock.AsyncMock()
+        #
+        #     # Mock the reader to simulate client sending authentication data
+        #     reader.readexactly.side_effect = [
+        #         b'\x01',         # Auth version
+        #         b'\x0A',         # Username length
+        #         b'defaultuser',  # Username
+        #         b'\x0B',         # Password length
+        #         b'defaultpass'   # Password
+        #     ]
+        #
+        #     # Run the coroutine and get the result
+        #     result = asyncio.run(self.pte.username_password_authenticate(reader, writer))
+        #
+        #     # Assert the authentication was successful
+        #     self.assertTrue(result)
+
+        # def test_successful_authentication(self):
+        #     # Setup mock to simulate reading data from the reader
+        #     self.reader.readexactly.side_effect = [
+        #         b'\x01',         # Auth version
+        #         b'\x0A',         # Username length
+        #         b'defaultuser',  # Username
+        #         b'\x0B',         # Password length
+        #         b'defaultpass'   # Password
+        #     ]
+        #
+        #     result = asyncio.run(self.pte.username_password_authenticate(self.reader, self.writer))
+        #     self.assertTrue(result)
+        #     # Check for success response
+        #     self.writer.write.assert_called_with(b'\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00')
+
+        # def test_failed_authentication_wrong_credentials(self):
+        #     # Setup mock with incorrect credentials
+        #     self.reader.readexactly.side_effect = [
+        #         b'\x01',    # Auth version
+        #         b'\x07',    # Incorrect username length
+        #         b'wrong',   # Incorrect username
+        #         b'\x06',    # Incorrect password length
+        #         b'123456'   # Incorrect password
+        #     ]
+        #     result = asyncio.run(self.pte.username_password_authenticate(self.reader, self.writer))
+        #     self.assertFalse(result)
+        #     # Check for failure response
+        #     self.writer.write.assert_called_with(b'\x01\x01\x00\x01\x00\x00\x00\x00\x00\x00')
+        #
+        # def test_failed_authentication_bad_version(self):
+        #     self.reader.readexactly.side_effect = [b'\x02']  # Incorrect auth version
+        #
+        #     result = asyncio.run(self.pte.username_password_authenticate(self.reader, self.writer))
+        #     self.assertFalse(result)
+        #     # This test assumes the function just returns False without sending a specific response for bad version
+
+        async def test_handle_connection(self):
+
+            self.writer.get_extra_info = Mock(return_value=('127.0.0.1', 12345))
+            # Simulate the client's greeting and authentication method request
+            self.reader.read.side_effect = [
+                b'\x05\x01',         # SOCKS version 5, 1 authentication method supported
+                b'\x05\x01\x00\x03'  # SOCKS version 5, 1 method selected, No Auth
+            ]
+            self.reader.readexactly.side_effect = [
+                b'\x00\x02',     # No Auth and Username/Password methods
+                b'\x0b',         # length of the domain name
+                b'example.com',  # Domain name
+                b'\x00\x50'      # port 80
+            ]
+            await self.pte.handle_connection(self.reader, self.writer)
+            self.assertTrue(self.writer.write.called)
+
+        # async def test_handle_connection_with_authentication(self):
+        #     self.writer.get_extra_info = Mock(return_value=('127.0.0.1', 12345))
+        #     # Simulate the client's greeting indicating Username/Password method supported and chosen
+        #     self.reader.readexactly.side_effect = [
+        #         b'\x02',         # Username/Password method
+        #         b'\x01',         # Auth version
+        #         b'\x09',         # Username length
+        #         b'defaultuser',  # Username
+        #         b'\x08',         # Password length
+        #         b'defaultpass',  # Password
+        #         b'\x0b',         # length of the domain name
+        #         b'example.com',  # Domain name
+        #         b'\x00\x50'      # port 80
+        #     ]
+        #
+        #     self.reader.read.side_effect = [
+        #         b'\x05\x01',          # SOCKS version 5, 2 authentication methods supported
+        #         b'\x05\x01\x00\x03',  # SOCKS version 5, 1 method selected, No Auth, domain name address type
+        #     ]
+        #
+        #     await self.pte.handle_connection(self.reader, self.writer)
+        #     # Example assertion, adjust based on your protocol implementation
+        #     # Assert success authentication response
+        #     self.writer.write.assert_any_call(b'\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00')
+
+        async def test_handle_connection_with_unsupported_auth(self):
+            self.writer.get_extra_info = Mock(return_value=('127.0.0.1', 12345))
+            # Simulate the client's greeting with an unsupported authentication method
+            self.reader.readexactly.side_effect = [
+                b'\x03'  # An unsupported authentication method, e.g., GSSAPI
+            ]
+
+            self.reader.read.side_effect = [
+                b'\x05\x01',  # SOCKS version 5, 1 authentication method supported
+            ]
+
+            await self.pte.handle_connection(self.reader, self.writer)
+            self.writer.write.assert_any_call(b'\x05\xff')  # Response indicating no acceptable methods
+
+        async def test_handle_connection_with_invalid_version(self):
+            self.writer.get_extra_info = Mock(return_value=('127.0.0.1', 12345))
+            # Simulate the client's greeting with an invalid SOCKS version
+            self.reader.read.side_effect = [
+                b'\x03\x01',  # Invalid SOCKS version, e.g., SOCKS4
+                b'\x00'       # No Auth method
+            ]
+
+            await self.pte.handle_connection(self.reader, self.writer)
+            # Response indicating no acceptable methods
+            self.writer.write.assert_any_call(b'\x05\x01\x00\x01\x00\x00\x00\x00\x00\x00')
+            self.writer.close.assert_called()  # Connection closed without proceeding
+
+        async def test_handle_connection_with_unsupported_address_type(self):
+            self.writer.get_extra_info = Mock(return_value=('127.0.0.1', 12345))
+            # Simulate a SOCKS connection request with an unsupported address type
+            self.reader.readexactly.side_effect = [
+                b'\x00\x02',  # No Auth and Username/Password methods
+                b'\x05',      # Unsupported address type (e.g., X.25)
+            ]
+            self.reader.read.side_effect = [
+                b'\x05\x02',          # SOCKS version 5, 2 authentication methods supported
+                b'\x05\x01\x00\x05',  # SOCKS version 5, 1 method selected, No Auth
+            ]
+
+            await self.pte.handle_connection(self.reader, self.writer)
+            # Server response error for unsupported address type
+            self.writer.write.assert_any_call(b'\x05\x08\x00\x01\x00\x00\x00\x00\x00\x00')
+
+        async def test_unsupported_command(self):
+            # Simulate a SOCKS request with an unsupported command
+            self.reader.readexactly.side_effect = [
+                b'\x00',         # No Auth method
+                b'\x0b',         # Length of the domain name
+                b'example.com',  # Domain name
+                b'\x00',
+                b'\x50'          # Port 80
+            ]
+
+            self.reader.read.side_effect = [
+                b'\x05\x01',          # SOCKS version 5, 1 authentication method supported
+                b'\x05\x02\x00\x03',  # Unsupported command (0x02 for BIND, as an example), domain name address type
+            ]
+
+            self.writer.get_extra_info = Mock(return_value=('127.0.0.1', 12345))
+
+            await self.pte.handle_connection(self.reader, self.writer)
+            # Check for a response indicating a command not supported error
+            # 07 indicating a command not supported error
+            self.writer.write.assert_any_call(b'\x05\x07\x00\x01\x00\x00\x00\x00\x00\x00')
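(For reference, the literal byte strings these tests exchange follow RFC 1928 SOCKS5 framing; a hedged decoding of the values exercised above, illustrative only and not part of the patch:)

    # Client greeting   b'\x05\x01' ............ VER=5, NMETHODS=1 (method list follows)
    # Request           b'\x05\x01\x00\x03' .... VER=5, CMD=1 (CONNECT), RSV=0, ATYP=3 (domain name)
    #                   b'\x0b' + b'example.com' ... 11-byte domain name
    #                   b'\x00\x50' ............ destination port 80, big-endian
    # Server replies mirror the layout: REP=0x07 is 'command not supported', REP=0x08 is
    # 'address type not supported', and b'\x05\xff' rejects every offered auth method.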