fix(Connections): fix shallow matching when retrieving connection fields (#209)
cheikhgwane authored Sep 25, 2024
1 parent 80a73be commit 0bfd5f7
Showing 2 changed files with 101 additions and 36 deletions.
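
Context for the fix: in the non-connected (local pipeline run) path, get_connection used to collect a connection's fields by scanning os.environ for every variable that starts with the connection's prefix. When two connections share a prefix, such as polio and polio-test, the scan for polio also swallowed the polio-test variables, and the stray keys then broke the **fields call that builds the connection dataclass. The new _get_local_connection_fields helper instead reads exactly the variables named after the declared fields of the connection class. The snippet below is a minimal standalone sketch of that difference; the identifiers and the known_fields tuple are illustrative, not the SDK's API.

# Minimal sketch of the old vs. new lookup (illustrative names only, not SDK API).
import os

os.environ.update(
    {
        "POLIO": "dhis2",
        "POLIO_URL": "https://play.dhis2.org",
        "POLIO_USERNAME": "admin",
        "POLIO_PASSWORD": "secret",
        "POLIO_TEST": "dhis2",
        "POLIO_TEST_URL": "https://test.dhis2.org",
        "POLIO_TEST_USERNAME": "admin-test",
        "POLIO_TEST_PASSWORD": "secret-test",
    }
)

prefix = "POLIO"

# Old behaviour: a shallow startswith() scan also matches the POLIO_TEST_* variables.
shallow = {
    key[len(prefix) + 1 :].lower(): value
    for key, value in os.environ.items()
    if key.startswith(f"{prefix}_")
}
print(sorted(shallow))
# ['password', 'test', 'test_password', 'test_url', 'test_username', 'url', 'username']

# New behaviour: only the variables named after the connection type's declared fields are read.
known_fields = ("url", "username", "password")  # assumption: fields of a DHIS2-style connection
exact = {name: os.environ[f"{prefix}_{name.upper()}"] for name in known_fields}
print(sorted(exact))
# ['password', 'url', 'username']
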
93 changes: 61 additions & 32 deletions openhexa/sdk/workspaces/current_workspace.py
@@ -5,7 +5,7 @@
 
 import os
 import typing
-from dataclasses import make_dataclass
+from dataclasses import fields, make_dataclass
 from warnings import warn
 
 from openhexa.utils import stringcase
@@ -153,14 +153,43 @@ def tmp_path(self) -> str:
         # We can remove this once we deprecate this way of running pipelines
         return os.environ["WORKSPACE_TMP_PATH"] if "WORKSPACE_TMP_PATH" in os.environ else "/home/hexa/tmp"
 
+    def _get_local_connection_fields(self, env_variable_prefix: str):
+        connection_fields = {}
+        connection_type = os.getenv(env_variable_prefix).upper()
+
+        # Get fields for the connection type
+        _fields = fields(ConnectionClasses[connection_type])
+
+        if _fields:
+            for field in _fields:
+                env_var = f"{env_variable_prefix}_{field.name.upper()}"
+                connection_fields[field.name] = os.getenv(env_var)
+        else:
+            # custom connections
+            prefix = f"{env_variable_prefix}_"
+            connection_fields = {
+                key[len(prefix) :].lower(): val for key, val in os.environ.items() if key.startswith(prefix)
+            }
+
+        # need to map the correct name for s3 and postgres connection to ensure compatibility
+        # with the one coming from the API
+        if connection_type == "S3":
+            connection_fields.pop("secret_access_key")
+            connection_fields["access_key_secret"] = os.getenv(f"{env_variable_prefix}_ACCESS_KEY_SECRET")
+        if connection_type == "POSTGRESQL":
+            connection_fields.pop("database_name")
+            connection_fields["db_name"] = os.getenv(f"{env_variable_prefix}_DB_NAME")
+
+        return connection_fields
+
     def get_connection(
         self, identifier: str
     ) -> typing.Union[
         DHIS2Connection,
         PostgreSQLConnection,
         IASOConnection,
         S3Connection,
         GCSConnection,
         CustomConnection,
         None,
     ]:
@@ -181,7 +210,7 @@ def get_connection(
         ValueError
             If the connection does not exist
         """
-        fields = {}
+        connection_fields = {}
         connection_type = None
         if self._connected:
             response = graphql(
@@ -203,45 +232,45 @@
                 raise ValueError(f"Connection {identifier} does not exist.")
 
             for d in data["fields"]:
-                fields[d.get("code")] = d.get("value")
+                connection_fields[d.get("code")] = d.get("value")
 
-            connection_type = data["type"]
+            connection_type = data["type"].upper()
         else:
             try:
                 env_variable_prefix = stringcase.constcase(identifier.lower())
-                for key, val in os.environ.items():
-                    if key.startswith(f"{env_variable_prefix}_"):
-                        field_name = key[len(f"{env_variable_prefix}_") :].lower()
-                        fields[field_name] = val
-
-                connection_type = os.environ[f"{env_variable_prefix}"]
+                connection_type = os.environ[f"{env_variable_prefix}"].upper()
+                connection_fields = self._get_local_connection_fields(env_variable_prefix)
             except KeyError:
                 raise ValueError
 
         if not connection_type:
             raise ValueError(f"Connection {identifier} does not exist.")
 
-        connection_type = connection_type.upper()
-        if connection_type in ConnectionClasses.keys():
-            if connection_type == "S3":
-                secret_access_key = fields.pop("access_key_secret")
-                return S3Connection(secret_access_key=secret_access_key, **fields)
-
-            if connection_type == "POSTGRESQL":
-                db_name = fields.pop("db_name")
-                port = int(fields.pop("port"))
-                return PostgreSQLConnection(database_name=db_name, port=port, **fields)
-
-            if connection_type == "CUSTOM":
-                dataclass = make_dataclass(
-                    stringcase.pascalcase(identifier),
-                    fields.keys(),
-                    bases=(CustomConnection,),
-                    repr=False,
-                )
-                return dataclass(**fields)
-
-            return ConnectionClasses[connection_type](**fields)
+        # In connected mode (API call) the secret_access_key field and db_name name are
+        # different from the offline ones
+        if connection_type == "S3":
+            secret_access_key = connection_fields.pop("access_key_secret")
+            return S3Connection(secret_access_key=secret_access_key, **connection_fields)
+
+        if connection_type == "POSTGRESQL":
+            db_name = connection_fields.pop("db_name")
+            port = int(connection_fields.pop("port"))
+            return PostgreSQLConnection(
+                database_name=db_name,
+                port=port,
+                **connection_fields,
+            )
+
+        if connection_type == "CUSTOM":
+            dataclass = make_dataclass(
+                stringcase.pascalcase(identifier),
+                connection_fields.keys(),
+                bases=(CustomConnection,),
+                repr=False,
+            )
+            return dataclass(**connection_fields)
+
+        return ConnectionClasses[connection_type](**connection_fields)
 
     def dhis2_connection(self, identifier: str = None, slug: str = None) -> DHIS2Connection:
         """Get a DHIS2 connection by its identifier.
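
A second detail the new helper covers: the local environment variables for S3 and PostgreSQL connections follow the API field codes (..._ACCESS_KEY_SECRET, ..._DB_NAME), while the connection dataclasses expect secret_access_key and database_name. _get_local_connection_fields therefore swaps in the API-style keys so that get_connection can treat the connected and offline paths the same way. Below is a rough sketch of that round trip with a stand-in dataclass, not the SDK's S3Connection; the MY_BUCKET prefix is invented for the example.

# Rough sketch of the S3 name mapping, with a stand-in dataclass
# (the real S3Connection is provided by openhexa.sdk).
import os
from dataclasses import dataclass, fields


@dataclass
class S3ConnectionStub:
    access_key_id: str
    secret_access_key: str
    bucket_name: str


os.environ.update(
    {
        "MY_BUCKET": "s3",
        "MY_BUCKET_ACCESS_KEY_ID": "AKIA-EXAMPLE",
        "MY_BUCKET_ACCESS_KEY_SECRET": "top-secret",  # note: not SECRET_ACCESS_KEY
        "MY_BUCKET_BUCKET_NAME": "my-bucket",
    }
)

prefix = "MY_BUCKET"

# Step 1: read exactly the declared fields -> secret_access_key is None locally.
connection_fields = {f.name: os.getenv(f"{prefix}_{f.name.upper()}") for f in fields(S3ConnectionStub)}

# Step 2: replace it with the API-style key so both code paths look identical downstream.
connection_fields.pop("secret_access_key")
connection_fields["access_key_secret"] = os.getenv(f"{prefix}_ACCESS_KEY_SECRET")

# Step 3: get_connection() pops the API-style key and builds the dataclass.
secret = connection_fields.pop("access_key_secret")
connection = S3ConnectionStub(secret_access_key=secret, **connection_fields)
print(connection.bucket_name, connection.secret_access_key)  # my-bucket top-secret

The PostgreSQL case is analogous, with database_name travelling locally and over the API as db_name.
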
44 changes: 40 additions & 4 deletions tests/test_workspace.py
@@ -67,6 +67,36 @@ def test_workspace_dhis2_connection(self, workspace):
         assert re.search("password", repr(dhis2_connection), re.IGNORECASE) is None
         assert re.search("password", str(dhis2_connection), re.IGNORECASE) is None
 
+    def test_workspace_dhis2_connection_similar_prefix(self, workspace):
+        """Base test case for DHIS2 connections."""
+        identifier = "polio"
+        identifier_2 = "polio-test"
+
+        env_variable_prefix = stringcase.constcase(identifier)
+        env_variable_prefix_2 = stringcase.constcase(identifier_2)
+
+        url = "https://test.dhis2.org/"
+        username = "dhis2"
+        password = "dhis2_pwd"
+
+        with mock.patch.dict(
+            os.environ,
+            {
+                f"{env_variable_prefix}": "dhis2",
+                f"{env_variable_prefix}_URL": url,
+                f"{env_variable_prefix}_USERNAME": username,
+                f"{env_variable_prefix}_PASSWORD": password,
+                f"{env_variable_prefix_2}": "dhis2",
+                f"{env_variable_prefix_2}_URL": "url_2",
+                f"{env_variable_prefix_2}_USERNAME": "username_2",
+                f"{env_variable_prefix_2}_PASSWORD": "password_2",
+            },
+        ):
+            dhis2_connection = workspace.dhis2_connection(identifier=identifier)
+            assert dhis2_connection.url == url
+            assert dhis2_connection.username == username
+            assert dhis2_connection.password == password
+
     def test_workspace_postgresql_connection_not_exist(self, workspace):
         """Does not exist test case for PostgreSQL connections."""
         identifier = "polio-ff3a0d"
@@ -316,16 +346,21 @@ def test_workspace_get_connection(self, workspace):
         """Test get connection."""
         data = {
             "connectionBySlug": {
-                "type": "CUSTOM",
-                "fields": [{"code": "field_1", "value": "field_1_value"}],
+                "type": "S3",
+                "fields": [
+                    {"code": "bucket_name", "value": "bucket_name"},
+                    {"code": "access_key_id", "value": "access_key_id"},
+                    {"code": "access_key_secret", "value": "secret_access_key"},
+                ],
             }
         }
 
         with mock.patch(
             "openhexa.sdk.workspaces.current_workspace.graphql",
             return_value=data,
         ):
-            connection = workspace.get_connection("random")
-            assert isinstance(connection, CustomConnection)
+            connection = workspace.get_connection("s3-connection")
+            assert isinstance(connection, S3Connection)
 
     def test_workspace_dhis2_connection_not_exist(self, workspace):
         """Does not exist test case for DHIS2 connections."""
@@ -401,6 +436,7 @@ def test_workspace_s3_connection(monkeypatch, workspace):
     data = S3Connection(access_key_id, secret_access_key, bucket_name)
     with mock.patch.object(workspace, "get_connection", return_value=data):
         s3_connection = workspace.s3_connection(identifier=identifier)
+
         assert s3_connection.secret_access_key == secret_access_key
         assert s3_connection.access_key_id == access_key_id
         assert s3_connection.bucket_name == bucket_name
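
For completeness, the custom-connection branch of the helper: a custom connection has no fixed schema, so _get_local_connection_fields still gathers every variable under the prefix, and get_connection then builds an ad-hoc dataclass with make_dataclass. The sketch below follows that shape under the same caveats (invented names, and without the SDK's CustomConnection base class).

# Sketch of the custom-connection fallback: no fixed schema, so everything under
# the prefix is collected and turned into an ad-hoc dataclass (names illustrative).
import os
from dataclasses import make_dataclass

os.environ.update(
    {
        "MY_API": "custom",
        "MY_API_TOKEN": "abc123",
        "MY_API_ENDPOINT": "https://example.org/api",
    }
)

prefix_ = "MY_API_"
connection_fields = {
    key[len(prefix_) :].lower(): value for key, value in os.environ.items() if key.startswith(prefix_)
}
# Build an ad-hoc dataclass the same way get_connection() does for CUSTOM connections.
CustomConnectionStub = make_dataclass("MyApi", connection_fields.keys(), repr=False)
connection = CustomConnectionStub(**connection_fields)
print(connection.token, connection.endpoint)  # abc123 https://example.org/api

Note that this fallback necessarily keeps matching by prefix, since the fields of a custom connection are not known in advance.
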
