From fbe3f000f22f25ad3d67042180f1b54424629b24 Mon Sep 17 00:00:00 2001 From: cheikhgwane Date: Fri, 17 Nov 2023 13:47:52 +0000 Subject: [PATCH] refacto : inject connection on Parameter.validate --- openhexa/sdk/pipelines/parameter.py | 154 +++++++++------------------- openhexa/sdk/pipelines/pipeline.py | 8 +- openhexa/sdk/pipelines/utils.py | 30 ------ tests/test_parameter.py | 153 ++++++++++++++++++++------- tests/test_pipeline.py | 5 +- 5 files changed, 169 insertions(+), 181 deletions(-) diff --git a/openhexa/sdk/pipelines/parameter.py b/openhexa/sdk/pipelines/parameter.py index 1de2544..bb3379f 100644 --- a/openhexa/sdk/pipelines/parameter.py +++ b/openhexa/sdk/pipelines/parameter.py @@ -7,6 +7,7 @@ S3Connection, GCSConnection, ) +from openhexa.sdk.workspaces import workspace class ParameterValueError(Exception): @@ -55,6 +56,11 @@ def validate(self, value: typing.Optional[typing.Any], allow_empty: bool = True) return value + def validate_default( + self, value: typing.Optional[typing.Any], allow_empty: bool = True + ) -> typing.Optional[typing.Any]: + return self.validate(value, allow_empty=allow_empty) + def __str__(self) -> str: return str(self.expected_type) @@ -132,15 +138,7 @@ def normalize(value: typing.Any) -> typing.Any: return value -class PostgreSQLConnectionType(ParameterType): - @property - def spec_type(self) -> str: - return "postgresql" - - @property - def expected_type(self) -> typing.Type: - return PostgreSQLConnectionType - +class ConnectionParameterType(ParameterType): @property def accepts_choice(self) -> bool: return False @@ -149,6 +147,17 @@ def accepts_choice(self) -> bool: def accepts_multiple(self) -> bool: return False + def validate_default( + self, value: typing.Optional[typing.Any], allow_empty: bool = True + ) -> typing.Optional[typing.Any]: + if value is None: + return + + if not isinstance(value, str): + raise InvalidParameterError("Default value for connection parameter type should be string.") + elif value == "": + raise ParameterValueError("Empty values are not accepted.") + def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: if not allow_empty and value == "": raise ParameterValueError("Empty values are not accepted.") @@ -156,37 +165,36 @@ def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = Tr if not isinstance(value, str): raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})") - return value - -class S3ConnectionType(ParameterType): +class PostgreSQLConnectionType(ConnectionParameterType): @property def spec_type(self) -> str: - return "s3" + return "postgresql" @property def expected_type(self) -> typing.Type: - return S3ConnectionType + return PostgreSQLConnectionType + + def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: + super().validate(value, allow_empty=allow_empty) + return workspace.postgresql_connection(value) + +class S3ConnectionType(ConnectionParameterType): @property - def accepts_choice(self) -> bool: - return False + def spec_type(self) -> str: + return "s3" @property - def accepts_multiple(self) -> bool: - return False + def expected_type(self) -> typing.Type: + return S3ConnectionType def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: - if not allow_empty and value == "": - raise ParameterValueError("Empty values are not accepted.") - - if not isinstance(value, str): - raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})") - - return value + super().validate(value, allow_empty=allow_empty) + return workspace.s3_connection(value) -class GCSConnectionType(ParameterType): +class GCSConnectionType(ConnectionParameterType): @property def spec_type(self) -> str: return "gcs" @@ -195,25 +203,12 @@ def spec_type(self) -> str: def expected_type(self) -> typing.Type: return GCSConnectionType - @property - def accepts_choice(self) -> bool: - return False - - @property - def accepts_multiple(self) -> bool: - return False - def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: - if not allow_empty and value == "": - raise ParameterValueError("Empty values are not accepted.") + super().validate(value, allow_empty=allow_empty) + return workspace.gcs_connection(value) - if not isinstance(value, str): - raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})") - return value - - -class DHIS2ConnectionType(ParameterType): +class DHIS2ConnectionType(ConnectionParameterType): @property def spec_type(self) -> str: return "dhis2" @@ -222,25 +217,12 @@ def spec_type(self) -> str: def expected_type(self) -> typing.Type: return DHIS2ConnectionType - @property - def accepts_choice(self) -> bool: - return False - - @property - def accepts_multiple(self) -> bool: - return False - def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: - if not allow_empty and value == "": - raise ParameterValueError("Empty values are not accepted.") - - if not isinstance(value, str): - raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})") - - return value + super().validate(value, allow_empty=allow_empty) + return workspace.dhis2_connection(value) -class IASOConnectionType(ParameterType): +class IASOConnectionType(ConnectionParameterType): @property def spec_type(self) -> str: return "iaso" @@ -249,25 +231,12 @@ def spec_type(self) -> str: def expected_type(self) -> typing.Type: return IASOConnectionType - @property - def accepts_choice(self) -> bool: - return False - - @property - def accepts_multiple(self) -> bool: - return False - def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: - if not allow_empty and value == "": - raise ParameterValueError("Empty values are not accepted.") + super().validate(value, allow_empty=allow_empty) + return workspace.iaso_connection(value) - if not isinstance(value, str): - raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})") - - return value - -class CustomConnectionType(ParameterType): +class CustomConnectionType(ConnectionParameterType): @property def spec_type(self) -> str: return "custom" @@ -276,22 +245,9 @@ def spec_type(self) -> str: def expected_type(self) -> typing.Type: return str - @property - def accepts_choice(self) -> bool: - return False - - @property - def accepts_multiple(self) -> bool: - return False - def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: - if not allow_empty and value == "": - raise ParameterValueError("Empty values are not accepted.") - - if not isinstance(value, str): - raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})") - - return super().validate(value, allow_empty) + super().validate(value, allow_empty=allow_empty) + return workspace.postgresql_connection(value) TYPES_BY_PYTHON_TYPE = { @@ -425,9 +381,9 @@ def _validate_default(self, default: typing.Any, multiple: bool): if not isinstance(default, list): raise InvalidParameterError("Default values should be lists when using multiple=True") for default_value in default: - self.type.validate(default_value, allow_empty=False) + self.type.validate_default(default_value, allow_empty=False) else: - self.type.validate(default, allow_empty=False) + self.type.validate_default(default, allow_empty=False) except ParameterValueError: raise InvalidParameterError(f"The default value for {self.code} is not valid.") @@ -522,19 +478,3 @@ def get_all_parameters(self): return [self.parameter, *self.function.get_all_parameters()] return [self.parameter] - - -def is_connection_parameter(param: Parameter): - return any( - [ - isinstance(param.type, type) - for type in [ - DHIS2ConnectionType, - PostgreSQLConnectionType, - IASOConnectionType, - S3ConnectionType, - GCSConnectionType, - CustomConnectionType, - ] - ] - ) diff --git a/openhexa/sdk/pipelines/pipeline.py b/openhexa/sdk/pipelines/pipeline.py index d7511d8..0d2b32a 100644 --- a/openhexa/sdk/pipelines/pipeline.py +++ b/openhexa/sdk/pipelines/pipeline.py @@ -20,10 +20,9 @@ FunctionWithParameter, Parameter, ParameterValueError, - is_connection_parameter, ) from .task import PipelineWithTask -from .utils import get_local_workspace_config, get_connection_by_type +from .utils import get_local_workspace_config logger = getLogger(__name__) @@ -102,10 +101,7 @@ def run(self, config: typing.Dict[str, typing.Any]): for parameter in self.parameters: value = config.pop(parameter.code, None) validated_value = parameter.validate(value) - if is_connection_parameter(parameter): - validated_config[parameter.code] = get_connection_by_type(parameter.type, validated_value) - else: - validated_config[parameter.code] = validated_value + validated_config[parameter.code] = validated_value if len(config) > 0: raise ParameterValueError(f"The provided config contains invalid key(s): {', '.join(list(config.keys()))}") diff --git a/openhexa/sdk/pipelines/utils.py b/openhexa/sdk/pipelines/utils.py index 6936c7b..b431cc5 100644 --- a/openhexa/sdk/pipelines/utils.py +++ b/openhexa/sdk/pipelines/utils.py @@ -4,16 +4,6 @@ import stringcase import yaml -from openhexa.sdk.workspaces import workspace -from .parameter import ( - DHIS2ConnectionType, - PostgreSQLConnectionType, - IASOConnectionType, - S3ConnectionType, - GCSConnectionType, - CustomConnectionType, -) - class LocalWorkspaceConfigError(Exception): pass @@ -158,23 +148,3 @@ def get_local_workspace_config(path: Path): if key != "type": env_vars[stringcase.constcase(f"{slug}_{key.lower()}")] = str(value) return env_vars - - -def get_connection_by_type(type: any, identifier: str): - if isinstance(type, DHIS2ConnectionType): - return workspace.dhis2_connection(identifier) - - if isinstance(type, PostgreSQLConnectionType): - return workspace.postgresql_connection(identifier) - - if isinstance(type, IASOConnectionType): - return workspace.iaso_connection(identifier) - - if isinstance(type, S3ConnectionType): - return workspace.s3_connection(identifier) - - if isinstance(type, GCSConnectionType): - return workspace.gcs_connection(identifier) - - if isinstance(type, CustomConnectionType): - return workspace.custom_connection(identifier) diff --git a/tests/test_parameter.py b/tests/test_parameter.py index c08ae25..411b954 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -1,4 +1,14 @@ import pytest +import os +import stringcase + +from openhexa.sdk.workspaces.connection import ( + DHIS2Connection, + IASOConnection, + PostgreSQLConnection, + GCSConnection, + S3Connection, +) from openhexa.sdk.pipelines.parameter import ( Boolean, @@ -14,10 +24,11 @@ S3ConnectionType, IASOConnectionType, DHIS2ConnectionType, - CustomConnectionType, parameter, ) +from unittest import mock + def test_parameter_types_normalize(): # String @@ -70,41 +81,113 @@ def test_parameter_types_validate(): with pytest.raises(ParameterValueError): boolean_parameter_type.validate(86) - # PostgreSQL Connection - postgres_parameter_type = PostgreSQLConnectionType() - assert postgres_parameter_type.validate("postgres_connection_identifier") == "postgres_connection_identifier" - with pytest.raises(ParameterValueError): - postgres_parameter_type.validate(86) - - # IASO Connection - iaso_parameter_type = IASOConnectionType() - assert postgres_parameter_type.validate("iaso_connection_identifier") == "iaso_connection_identifier" - with pytest.raises(ParameterValueError): - iaso_parameter_type.validate(86) - - # GCS Connection - gcs_parameter_type = GCSConnectionType() - assert postgres_parameter_type.validate("gcs_connection_identifier") == "gcs_connection_identifier" - with pytest.raises(ParameterValueError): - gcs_parameter_type.validate(86) - - # S3 Connection - s3_parameter_type = S3ConnectionType() - assert s3_parameter_type.validate("s3_connection_identifier") == "s3_connection_identifier" - with pytest.raises(ParameterValueError): - s3_parameter_type.validate(86) - # DHIS2 Connection - dhsi2_parameter_type = DHIS2ConnectionType() - assert dhsi2_parameter_type.validate("dhis2_connection_identifier") == "dhis2_connection_identifier" - with pytest.raises(ParameterValueError): - dhsi2_parameter_type.validate(86) - - # Custom Connection - custom_parameter_type = CustomConnectionType() - assert custom_parameter_type.validate("custom_connection_identifier") == "custom_connection_identifier" - with pytest.raises(ParameterValueError): - custom_parameter_type.validate(86) +def test_validate_postgres_connection(): + identifier = "polio-ff3a0d" + env_variable_prefix = stringcase.constcase(identifier) + host = "https://172.17.0.1" + port = "5432" + username = "dhis2" + password = "dhis2_pwd" + database_name = "polio" + with mock.patch.dict( + os.environ, + { + f"{env_variable_prefix}_HOST": host, + f"{env_variable_prefix}_USERNAME": username, + f"{env_variable_prefix}_PASSWORD": password, + f"{env_variable_prefix}_PORT": port, + f"{env_variable_prefix}_DB_NAME": database_name, + }, + ): + postgres_parameter_type = PostgreSQLConnectionType() + assert postgres_parameter_type.validate(identifier) == PostgreSQLConnection( + host, int(port), username, password, database_name + ) + with pytest.raises(ParameterValueError): + postgres_parameter_type.validate(86) + + +def test_validate_dhis2_connection(): + identifier = "dhis2-connection-id" + env_variable_prefix = stringcase.constcase(identifier) + url = "https://test.dhis2.org/" + username = "dhis2" + password = "dhis2_pwd" + + with mock.patch.dict( + os.environ, + { + f"{env_variable_prefix}_URL": url, + f"{env_variable_prefix}_USERNAME": username, + f"{env_variable_prefix}_PASSWORD": password, + }, + ): + dhsi2_parameter_type = DHIS2ConnectionType() + assert dhsi2_parameter_type.validate(identifier) == DHIS2Connection(url, username, password) + with pytest.raises(ParameterValueError): + dhsi2_parameter_type.validate(86) + + +def test_validate_iaso_connection(): + identifier = "iaso-connection-id" + env_variable_prefix = stringcase.constcase(identifier) + url = "https://test.iaso.org/" + username = "iaso" + password = "iaso_pwd" + + with mock.patch.dict( + os.environ, + { + f"{env_variable_prefix}_URL": url, + f"{env_variable_prefix}_USERNAME": username, + f"{env_variable_prefix}_PASSWORD": password, + }, + ): + iaso_parameter_type = IASOConnectionType() + assert iaso_parameter_type.validate(identifier) == IASOConnection(url, username, password) + with pytest.raises(ParameterValueError): + iaso_parameter_type.validate(86) + + +def test_validate_gcs_connection(): + identifier = "gcs-connection-id" + env_variable_prefix = stringcase.constcase(identifier) + service_account_key = "HqQBxH0BAI3zF7kANUNlGg" + bucket_name = "test" + + with mock.patch.dict( + os.environ, + { + f"{env_variable_prefix}_SERVICE_ACCOUNT_KEY": service_account_key, + f"{env_variable_prefix}_BUCKET_NAME": bucket_name, + }, + ): + gcs_parameter_type = GCSConnectionType() + assert gcs_parameter_type.validate(identifier) == GCSConnection(service_account_key, bucket_name) + with pytest.raises(ParameterValueError): + gcs_parameter_type.validate(86) + + +def test_validate_s3_connection(): + identifier = "s3-connection-id" + env_variable_prefix = stringcase.constcase(identifier) + secret_access_key = "HqQBxH0BAI3zF7kANUNlGg" + access_key_id = "84hVntMaMSYP/RSW9ex04w" + bucket_name = "test" + + with mock.patch.dict( + os.environ, + { + f"{env_variable_prefix}_SECRET_ACCESS_KEY": secret_access_key, + f"{env_variable_prefix}_ACCESS_KEY_ID": access_key_id, + f"{env_variable_prefix}_BUCKET_NAME": bucket_name, + }, + ): + s3_parameter_type = S3ConnectionType() + assert s3_parameter_type.validate(identifier) == S3Connection(access_key_id, secret_access_key, bucket_name) + with pytest.raises(ParameterValueError): + s3_parameter_type.validate(86) def test_parameter_init(): diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 7105f06..ebb96b2 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -4,15 +4,14 @@ import stringcase import os -from openhexa.sdk.pipelines.parameter import ( - Parameter, - ParameterValueError, +from openhexa.sdk.workspaces.connection import ( DHIS2Connection, IASOConnection, PostgreSQLConnection, GCSConnection, S3Connection, ) +from openhexa.sdk.pipelines.parameter import Parameter, ParameterValueError from openhexa.sdk.pipelines.pipeline import Pipeline