diff --git a/openhexa/sdk/workspaces/current_workspace.py b/openhexa/sdk/workspaces/current_workspace.py index f132759..b444ad9 100644 --- a/openhexa/sdk/workspaces/current_workspace.py +++ b/openhexa/sdk/workspaces/current_workspace.py @@ -5,7 +5,7 @@ import os import typing -from dataclasses import make_dataclass +from dataclasses import fields, make_dataclass from warnings import warn from openhexa.utils import stringcase @@ -153,6 +153,35 @@ def tmp_path(self) -> str: # We can remove this once we deprecate this way of running pipelines return os.environ["WORKSPACE_TMP_PATH"] if "WORKSPACE_TMP_PATH" in os.environ else "/home/hexa/tmp" + def _get_local_connection_fields(self, env_variable_prefix: str): + connection_fields = {} + connection_type = os.getenv(env_variable_prefix).upper() + + # Get fields for the connection type + _fields = fields(ConnectionClasses[connection_type]) + + if _fields: + for field in _fields: + env_var = f"{env_variable_prefix}_{field.name.upper()}" + connection_fields[field.name] = os.getenv(env_var) + else: + # custom connections + prefix = f"{env_variable_prefix}_" + connection_fields = { + key[len(prefix) :].lower(): val for key, val in os.environ.items() if key.startswith(prefix) + } + + # need to map the correct name for s3 and postgres connection to ensure compatibility + # with the one coming from the API + if connection_type == "S3": + connection_fields.pop("secret_access_key") + connection_fields["access_key_secret"] = os.getenv(f"{env_variable_prefix}_ACCESS_KEY_SECRET") + if connection_type == "POSTGRESQL": + connection_fields.pop("database_name") + connection_fields["db_name"] = os.getenv(f"{env_variable_prefix}_DB_NAME") + + return connection_fields + def get_connection( self, identifier: str ) -> typing.Union[ @@ -160,7 +189,7 @@ def get_connection( PostgreSQLConnection, IASOConnection, S3Connection, - S3Connection, + GCSConnection, CustomConnection, None, ]: @@ -181,7 +210,7 @@ def get_connection( ValueError If the connection does not exist """ - fields = {} + connection_fields = {} connection_type = None if self._connected: response = graphql( @@ -203,45 +232,45 @@ def get_connection( raise ValueError(f"Connection {identifier} does not exist.") for d in data["fields"]: - fields[d.get("code")] = d.get("value") + connection_fields[d.get("code")] = d.get("value") - connection_type = data["type"] + connection_type = data["type"].upper() else: try: env_variable_prefix = stringcase.constcase(identifier.lower()) - for key, val in os.environ.items(): - if key.startswith(f"{env_variable_prefix}_"): - field_name = key[len(f"{env_variable_prefix}_") :].lower() - fields[field_name] = val - - connection_type = os.environ[f"{env_variable_prefix}"] + connection_type = os.environ[f"{env_variable_prefix}"].upper() + connection_fields = self._get_local_connection_fields(env_variable_prefix) except KeyError: raise ValueError if not connection_type: raise ValueError(f"Connection {identifier} does not exist.") - connection_type = connection_type.upper() - if connection_type in ConnectionClasses.keys(): - if connection_type == "S3": - secret_access_key = fields.pop("access_key_secret") - return S3Connection(secret_access_key=secret_access_key, **fields) - - if connection_type == "POSTGRESQL": - db_name = fields.pop("db_name") - port = int(fields.pop("port")) - return PostgreSQLConnection(database_name=db_name, port=port, **fields) - - if connection_type == "CUSTOM": - dataclass = make_dataclass( - stringcase.pascalcase(identifier), - fields.keys(), - bases=(CustomConnection,), - repr=False, - ) - return dataclass(**fields) - - return ConnectionClasses[connection_type](**fields) + # In connected mode (API call) the secret_access_key field and db_name name are + # different from the offline ones + if connection_type == "S3": + secret_access_key = connection_fields.pop("access_key_secret") + return S3Connection(secret_access_key=secret_access_key, **connection_fields) + + if connection_type == "POSTGRESQL": + db_name = connection_fields.pop("db_name") + port = int(connection_fields.pop("port")) + return PostgreSQLConnection( + database_name=db_name, + port=port, + **connection_fields, + ) + + if connection_type == "CUSTOM": + dataclass = make_dataclass( + stringcase.pascalcase(identifier), + connection_fields.keys(), + bases=(CustomConnection,), + repr=False, + ) + return dataclass(**connection_fields) + + return ConnectionClasses[connection_type](**connection_fields) def dhis2_connection(self, identifier: str = None, slug: str = None) -> DHIS2Connection: """Get a DHIS2 connection by its identifier. diff --git a/tests/test_workspace.py b/tests/test_workspace.py index bdfb6f3..10986df 100644 --- a/tests/test_workspace.py +++ b/tests/test_workspace.py @@ -67,6 +67,36 @@ def test_workspace_dhis2_connection(self, workspace): assert re.search("password", repr(dhis2_connection), re.IGNORECASE) is None assert re.search("password", str(dhis2_connection), re.IGNORECASE) is None + def test_workspace_dhis2_connection_similar_prefix(self, workspace): + """Base test case for DHIS2 connections.""" + identifier = "polio" + identifier_2 = "polio-test" + + env_variable_prefix = stringcase.constcase(identifier) + env_variable_prefix_2 = stringcase.constcase(identifier_2) + + url = "https://test.dhis2.org/" + username = "dhis2" + password = "dhis2_pwd" + + with mock.patch.dict( + os.environ, + { + f"{env_variable_prefix}": "dhis2", + f"{env_variable_prefix}_URL": url, + f"{env_variable_prefix}_USERNAME": username, + f"{env_variable_prefix}_PASSWORD": password, + f"{env_variable_prefix_2}": "dhis2", + f"{env_variable_prefix_2}_URL": "url_2", + f"{env_variable_prefix_2}_USERNAME": "username_2", + f"{env_variable_prefix_2}_PASSWORD": "password_2", + }, + ): + dhis2_connection = workspace.dhis2_connection(identifier=identifier) + assert dhis2_connection.url == url + assert dhis2_connection.username == username + assert dhis2_connection.password == password + def test_workspace_postgresql_connection_not_exist(self, workspace): """Does not exist test case for PostgreSQL connections.""" identifier = "polio-ff3a0d" @@ -316,16 +346,21 @@ def test_workspace_get_connection(self, workspace): """Test get connection.""" data = { "connectionBySlug": { - "type": "CUSTOM", - "fields": [{"code": "field_1", "value": "field_1_value"}], + "type": "S3", + "fields": [ + {"code": "bucket_name", "value": "bucket_name"}, + {"code": "access_key_id", "value": "access_key_id"}, + {"code": "access_key_secret", "value": "secret_access_key"}, + ], } } + with mock.patch( "openhexa.sdk.workspaces.current_workspace.graphql", return_value=data, ): - connection = workspace.get_connection("random") - assert isinstance(connection, CustomConnection) + connection = workspace.get_connection("s3-connection") + assert isinstance(connection, S3Connection) def test_workspace_dhis2_connection_not_exist(self, workspace): """Does not exist test case for DHIS2 connections.""" @@ -401,6 +436,7 @@ def test_workspace_s3_connection(monkeypatch, workspace): data = S3Connection(access_key_id, secret_access_key, bucket_name) with mock.patch.object(workspace, "get_connection", return_value=data): s3_connection = workspace.s3_connection(identifier=identifier) + assert s3_connection.secret_access_key == secret_access_key assert s3_connection.access_key_id == access_key_id assert s3_connection.bucket_name == bucket_name