From 9324b6563ab772d87f25f7a050ecbdd3840ef0c2 Mon Sep 17 00:00:00 2001 From: cheikhgwane Date: Wed, 8 Nov 2023 13:13:04 +0000 Subject: [PATCH 1/5] feat : add parameter of type connection (IASO,GCS..) --- openhexa/sdk/pipelines/parameter.py | 109 +++++++++++++++++++++++++++- tests/test_parameter.py | 42 +++++++++++ 2 files changed, 150 insertions(+), 1 deletion(-) diff --git a/openhexa/sdk/pipelines/parameter.py b/openhexa/sdk/pipelines/parameter.py index 8b2b199..da17459 100644 --- a/openhexa/sdk/pipelines/parameter.py +++ b/openhexa/sdk/pipelines/parameter.py @@ -125,7 +125,114 @@ def normalize(value: typing.Any) -> typing.Any: return value -TYPES_BY_PYTHON_TYPE = {str: String, bool: Boolean, int: Integer, float: Float} +class PostgreSQLConnectionType(ParameterType): + @property + def spec_type(self) -> str: + return "postgresql" + + @property + def expected_type(self) -> typing.Type: + return str + + def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: + if not allow_empty and value == "": + raise ParameterValueError("Empty values are not accepted.") + + return super().validate(value, allow_empty) + + +class S3ConnectionType(ParameterType): + @property + def spec_type(self) -> str: + return "s3" + + @property + def expected_type(self) -> typing.Type: + return str + + def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: + if not allow_empty and value == "": + raise ParameterValueError("Empty values are not accepted.") + + return super().validate(value, allow_empty) + + +class GCSConnectionType(ParameterType): + @property + def spec_type(self) -> str: + return "gcs" + + @property + def expected_type(self) -> typing.Type: + return str + + def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: + if not allow_empty and value == "": + raise ParameterValueError("Empty values are not accepted.") + + return super().validate(value, allow_empty) + + +class DHIS2ConnectionType(ParameterType): + @property + def spec_type(self) -> str: + return "dhis2" + + @property + def expected_type(self) -> typing.Type: + return str + + def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: + if not allow_empty and value == "": + raise ParameterValueError("Empty values are not accepted.") + + return super().validate(value, allow_empty) + + +class IASOConnectionType(ParameterType): + @property + def spec_type(self) -> str: + return "iaso" + + @property + def expected_type(self) -> typing.Type: + return str + + def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: + if not allow_empty and value == "": + raise ParameterValueError("Empty values are not accepted.") + + return super().validate(value, allow_empty) + + +class CustomConnectionType(ParameterType): + @property + def spec_type(self) -> str: + return "custom" + + @property + def expected_type(self) -> typing.Type: + return str + + def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: + if not allow_empty and value == "": + raise ParameterValueError("Empty values are not accepted.") + + return super().validate(value, allow_empty) + + +TYPES_BY_PYTHON_TYPE = { + str: String, + bool: Boolean, + int: Integer, + float: Float, + "dhis2": DHIS2ConnectionType, + "postgresql": PostgreSQLConnectionType, + "iaso": IASOConnectionType, + "s3": S3ConnectionType, + "gcs": GCSConnectionType, + "custom": CustomConnectionType, +} class InvalidParameterError(Exception): diff --git a/tests/test_parameter.py b/tests/test_parameter.py index d382efc..c08ae25 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -9,6 +9,12 @@ Parameter, ParameterValueError, String, + PostgreSQLConnectionType, + GCSConnectionType, + S3ConnectionType, + IASOConnectionType, + DHIS2ConnectionType, + CustomConnectionType, parameter, ) @@ -64,6 +70,42 @@ def test_parameter_types_validate(): with pytest.raises(ParameterValueError): boolean_parameter_type.validate(86) + # PostgreSQL Connection + postgres_parameter_type = PostgreSQLConnectionType() + assert postgres_parameter_type.validate("postgres_connection_identifier") == "postgres_connection_identifier" + with pytest.raises(ParameterValueError): + postgres_parameter_type.validate(86) + + # IASO Connection + iaso_parameter_type = IASOConnectionType() + assert postgres_parameter_type.validate("iaso_connection_identifier") == "iaso_connection_identifier" + with pytest.raises(ParameterValueError): + iaso_parameter_type.validate(86) + + # GCS Connection + gcs_parameter_type = GCSConnectionType() + assert postgres_parameter_type.validate("gcs_connection_identifier") == "gcs_connection_identifier" + with pytest.raises(ParameterValueError): + gcs_parameter_type.validate(86) + + # S3 Connection + s3_parameter_type = S3ConnectionType() + assert s3_parameter_type.validate("s3_connection_identifier") == "s3_connection_identifier" + with pytest.raises(ParameterValueError): + s3_parameter_type.validate(86) + + # DHIS2 Connection + dhsi2_parameter_type = DHIS2ConnectionType() + assert dhsi2_parameter_type.validate("dhis2_connection_identifier") == "dhis2_connection_identifier" + with pytest.raises(ParameterValueError): + dhsi2_parameter_type.validate(86) + + # Custom Connection + custom_parameter_type = CustomConnectionType() + assert custom_parameter_type.validate("custom_connection_identifier") == "custom_connection_identifier" + with pytest.raises(ParameterValueError): + custom_parameter_type.validate(86) + def test_parameter_init(): # Wrong type From 34e617cbdb1bc55bce8d69c69730b959f392d6bf Mon Sep 17 00:00:00 2001 From: cheikhgwane Date: Wed, 15 Nov 2023 10:42:23 +0000 Subject: [PATCH 2/5] feat : inject connection instance instead of identifier --- openhexa/sdk/pipelines/parameter.py | 120 +++++++++++++++++++---- openhexa/sdk/pipelines/pipeline.py | 14 ++- openhexa/sdk/pipelines/utils.py | 30 ++++++ tests/test_pipeline.py | 143 +++++++++++++++++++++++++++- 4 files changed, 286 insertions(+), 21 deletions(-) diff --git a/openhexa/sdk/pipelines/parameter.py b/openhexa/sdk/pipelines/parameter.py index da17459..1de2544 100644 --- a/openhexa/sdk/pipelines/parameter.py +++ b/openhexa/sdk/pipelines/parameter.py @@ -1,5 +1,12 @@ import re import typing +from openhexa.sdk.workspaces.connection import ( + DHIS2Connection, + IASOConnection, + PostgreSQLConnection, + S3Connection, + GCSConnection, +) class ParameterValueError(Exception): @@ -132,13 +139,24 @@ def spec_type(self) -> str: @property def expected_type(self) -> typing.Type: - return str + return PostgreSQLConnectionType + + @property + def accepts_choice(self) -> bool: + return False + + @property + def accepts_multiple(self) -> bool: + return False def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: if not allow_empty and value == "": raise ParameterValueError("Empty values are not accepted.") - return super().validate(value, allow_empty) + if not isinstance(value, str): + raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})") + + return value class S3ConnectionType(ParameterType): @@ -148,13 +166,24 @@ def spec_type(self) -> str: @property def expected_type(self) -> typing.Type: - return str + return S3ConnectionType + + @property + def accepts_choice(self) -> bool: + return False + + @property + def accepts_multiple(self) -> bool: + return False def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: if not allow_empty and value == "": raise ParameterValueError("Empty values are not accepted.") - return super().validate(value, allow_empty) + if not isinstance(value, str): + raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})") + + return value class GCSConnectionType(ParameterType): @@ -164,13 +193,24 @@ def spec_type(self) -> str: @property def expected_type(self) -> typing.Type: - return str + return GCSConnectionType + + @property + def accepts_choice(self) -> bool: + return False + + @property + def accepts_multiple(self) -> bool: + return False def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: if not allow_empty and value == "": raise ParameterValueError("Empty values are not accepted.") - return super().validate(value, allow_empty) + if not isinstance(value, str): + raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})") + + return value class DHIS2ConnectionType(ParameterType): @@ -180,13 +220,24 @@ def spec_type(self) -> str: @property def expected_type(self) -> typing.Type: - return str + return DHIS2ConnectionType + + @property + def accepts_choice(self) -> bool: + return False + + @property + def accepts_multiple(self) -> bool: + return False def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: if not allow_empty and value == "": raise ParameterValueError("Empty values are not accepted.") - return super().validate(value, allow_empty) + if not isinstance(value, str): + raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})") + + return value class IASOConnectionType(ParameterType): @@ -196,13 +247,24 @@ def spec_type(self) -> str: @property def expected_type(self) -> typing.Type: - return str + return IASOConnectionType + + @property + def accepts_choice(self) -> bool: + return False + + @property + def accepts_multiple(self) -> bool: + return False def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: if not allow_empty and value == "": raise ParameterValueError("Empty values are not accepted.") - return super().validate(value, allow_empty) + if not isinstance(value, str): + raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})") + + return value class CustomConnectionType(ParameterType): @@ -214,10 +276,21 @@ def spec_type(self) -> str: def expected_type(self) -> typing.Type: return str + @property + def accepts_choice(self) -> bool: + return False + + @property + def accepts_multiple(self) -> bool: + return False + def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: if not allow_empty and value == "": raise ParameterValueError("Empty values are not accepted.") + if not isinstance(value, str): + raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})") + return super().validate(value, allow_empty) @@ -226,12 +299,11 @@ def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = Tr bool: Boolean, int: Integer, float: Float, - "dhis2": DHIS2ConnectionType, - "postgresql": PostgreSQLConnectionType, - "iaso": IASOConnectionType, - "s3": S3ConnectionType, - "gcs": GCSConnectionType, - "custom": CustomConnectionType, + DHIS2Connection: DHIS2ConnectionType, + PostgreSQLConnection: PostgreSQLConnectionType, + IASOConnection: IASOConnectionType, + S3Connection: S3ConnectionType, + GCSConnection: GCSConnectionType, } @@ -450,3 +522,19 @@ def get_all_parameters(self): return [self.parameter, *self.function.get_all_parameters()] return [self.parameter] + + +def is_connection_parameter(param: Parameter): + return any( + [ + isinstance(param.type, type) + for type in [ + DHIS2ConnectionType, + PostgreSQLConnectionType, + IASOConnectionType, + S3ConnectionType, + GCSConnectionType, + CustomConnectionType, + ] + ] + ) diff --git a/openhexa/sdk/pipelines/pipeline.py b/openhexa/sdk/pipelines/pipeline.py index 622016a..d7511d8 100644 --- a/openhexa/sdk/pipelines/pipeline.py +++ b/openhexa/sdk/pipelines/pipeline.py @@ -16,9 +16,14 @@ from openhexa.sdk.utils import Environments, get_environment -from .parameter import FunctionWithParameter, Parameter, ParameterValueError +from .parameter import ( + FunctionWithParameter, + Parameter, + ParameterValueError, + is_connection_parameter, +) from .task import PipelineWithTask -from .utils import get_local_workspace_config +from .utils import get_local_workspace_config, get_connection_by_type logger = getLogger(__name__) @@ -97,7 +102,10 @@ def run(self, config: typing.Dict[str, typing.Any]): for parameter in self.parameters: value = config.pop(parameter.code, None) validated_value = parameter.validate(value) - validated_config[parameter.code] = validated_value + if is_connection_parameter(parameter): + validated_config[parameter.code] = get_connection_by_type(parameter.type, validated_value) + else: + validated_config[parameter.code] = validated_value if len(config) > 0: raise ParameterValueError(f"The provided config contains invalid key(s): {', '.join(list(config.keys()))}") diff --git a/openhexa/sdk/pipelines/utils.py b/openhexa/sdk/pipelines/utils.py index b431cc5..6936c7b 100644 --- a/openhexa/sdk/pipelines/utils.py +++ b/openhexa/sdk/pipelines/utils.py @@ -4,6 +4,16 @@ import stringcase import yaml +from openhexa.sdk.workspaces import workspace +from .parameter import ( + DHIS2ConnectionType, + PostgreSQLConnectionType, + IASOConnectionType, + S3ConnectionType, + GCSConnectionType, + CustomConnectionType, +) + class LocalWorkspaceConfigError(Exception): pass @@ -148,3 +158,23 @@ def get_local_workspace_config(path: Path): if key != "type": env_vars[stringcase.constcase(f"{slug}_{key.lower()}")] = str(value) return env_vars + + +def get_connection_by_type(type: any, identifier: str): + if isinstance(type, DHIS2ConnectionType): + return workspace.dhis2_connection(identifier) + + if isinstance(type, PostgreSQLConnectionType): + return workspace.postgresql_connection(identifier) + + if isinstance(type, IASOConnectionType): + return workspace.iaso_connection(identifier) + + if isinstance(type, S3ConnectionType): + return workspace.s3_connection(identifier) + + if isinstance(type, GCSConnectionType): + return workspace.gcs_connection(identifier) + + if isinstance(type, CustomConnectionType): + return workspace.custom_connection(identifier) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 15f54a7..7105f06 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,8 +1,18 @@ -from unittest.mock import Mock +from unittest.mock import Mock, patch import pytest +import stringcase +import os -from openhexa.sdk.pipelines.parameter import Parameter, ParameterValueError +from openhexa.sdk.pipelines.parameter import ( + Parameter, + ParameterValueError, + DHIS2Connection, + IASOConnection, + PostgreSQLConnection, + GCSConnection, + S3Connection, +) from openhexa.sdk.pipelines.pipeline import Pipeline @@ -34,6 +44,135 @@ def test_pipeline_run_extra_config(): pipeline.run({"arg1": "ok", "arg2": "extra"}) +def test_pipeline_run_connection_dhis2_parameter_config(): + identifier = "dhis2-connection-id" + env_variable_prefix = stringcase.constcase(identifier) + url = "https://test.dhis2.org/" + username = "dhis2" + password = "dhis2_pwd" + + with patch.dict( + os.environ, + { + f"{env_variable_prefix}_URL": url, + f"{env_variable_prefix}_USERNAME": username, + f"{env_variable_prefix}_PASSWORD": password, + }, + ): + pipeline_func = Mock() + parameter_1 = Parameter( + "connection_param", name="this is a test for connection parameter", type=DHIS2Connection + ) + pipeline = Pipeline("code", "pipeline", pipeline_func, [parameter_1]) + pipeline.run({"connection_param": identifier}) + assert pipeline.name == "pipeline" + pipeline_func.assert_called_once_with(connection_param=DHIS2Connection(url, username, password)) + + +def test_pipeline_run_connection_iaso_parameter_config(): + identifier = "iaso-connection-id" + env_variable_prefix = stringcase.constcase(identifier) + url = "https://test.iaso.org/" + username = "iaso" + password = "iaso_pwd" + + with patch.dict( + os.environ, + { + f"{env_variable_prefix}_URL": url, + f"{env_variable_prefix}_USERNAME": username, + f"{env_variable_prefix}_PASSWORD": password, + }, + ): + pipeline_func = Mock() + parameter_1 = Parameter("connection_param", name="this is a test for connection parameter", type=IASOConnection) + pipeline = Pipeline("code", "pipeline", pipeline_func, [parameter_1]) + pipeline.run({"connection_param": identifier}) + + assert pipeline.name == "pipeline" + pipeline_func.assert_called_once_with(connection_param=IASOConnection(url, username, password)) + + +def test_pipeline_run_connection_gcs_parameter_config(): + identifier = "gcs-connection-id" + env_variable_prefix = stringcase.constcase(identifier) + service_account_key = "HqQBxH0BAI3zF7kANUNlGg" + bucket_name = "test" + + with patch.dict( + os.environ, + { + f"{env_variable_prefix}_SERVICE_ACCOUNT_KEY": service_account_key, + f"{env_variable_prefix}_BUCKET_NAME": bucket_name, + }, + ): + pipeline_func = Mock() + parameter_1 = Parameter("connection_param", name="this is a test for connection parameter", type=GCSConnection) + pipeline = Pipeline("code", "pipeline", pipeline_func, [parameter_1]) + pipeline.run({"connection_param": identifier}) + + assert pipeline.name == "pipeline" + pipeline_func.assert_called_once_with(connection_param=GCSConnection(service_account_key, bucket_name)) + + +def test_pipeline_run_connection_s3_parameter_config(): + identifier = "s3-connection-id" + env_variable_prefix = stringcase.constcase(identifier) + secret_access_key = "HqQBxH0BAI3zF7kANUNlGg" + access_key_id = "84hVntMaMSYP/RSW9ex04w" + bucket_name = "test" + + with patch.dict( + os.environ, + { + f"{env_variable_prefix}_SECRET_ACCESS_KEY": secret_access_key, + f"{env_variable_prefix}_ACCESS_KEY_ID": access_key_id, + f"{env_variable_prefix}_BUCKET_NAME": bucket_name, + }, + ): + pipeline_func = Mock() + parameter_1 = Parameter("connection_param", name="this is a test for connection parameter", type=S3Connection) + pipeline = Pipeline("code", "pipeline", pipeline_func, [parameter_1]) + pipeline.run({"connection_param": identifier}) + + assert pipeline.name == "pipeline" + pipeline_func.assert_called_once_with( + connection_param=S3Connection(access_key_id, secret_access_key, bucket_name) + ) + + +def test_pipeline_run_connection_postgres_parameter_config(): + identifier = "postgres-connection-id" + env_variable_prefix = stringcase.constcase(identifier) + host = "https://127.0.0.1" + port = "5432" + username = "hexa_sdk" + password = "hexa_sdk_pwd" + database_name = "hexa_sdk" + + with patch.dict( + os.environ, + { + f"{env_variable_prefix}_HOST": host, + f"{env_variable_prefix}_USERNAME": username, + f"{env_variable_prefix}_PASSWORD": password, + f"{env_variable_prefix}_PORT": port, + f"{env_variable_prefix}_DB_NAME": database_name, + }, + ): + pipeline_func = Mock() + parameter_1 = Parameter( + "connection_param", name="this is a test for connection parameter", type=PostgreSQLConnection + ) + pipeline = Pipeline("code", "pipeline", pipeline_func, [parameter_1]) + pipeline.run({"connection_param": identifier}) + + assert pipeline.name == "pipeline" + pipeline_func.assert_called_once_with( + connection_param=PostgreSQLConnection(host, int(port), username, password, database_name) + ) + + def test_pipeline_parameters_spec(): pipeline_func = Mock() parameter_1 = Parameter("arg1", type=str) From fbe3f000f22f25ad3d67042180f1b54424629b24 Mon Sep 17 00:00:00 2001 From: cheikhgwane Date: Fri, 17 Nov 2023 13:47:52 +0000 Subject: [PATCH 3/5] refacto : inject connection on Parameter.validate --- openhexa/sdk/pipelines/parameter.py | 154 +++++++++------------------- openhexa/sdk/pipelines/pipeline.py | 8 +- openhexa/sdk/pipelines/utils.py | 30 ------ tests/test_parameter.py | 153 ++++++++++++++++++++------- tests/test_pipeline.py | 5 +- 5 files changed, 169 insertions(+), 181 deletions(-) diff --git a/openhexa/sdk/pipelines/parameter.py b/openhexa/sdk/pipelines/parameter.py index 1de2544..bb3379f 100644 --- a/openhexa/sdk/pipelines/parameter.py +++ b/openhexa/sdk/pipelines/parameter.py @@ -7,6 +7,7 @@ S3Connection, GCSConnection, ) +from openhexa.sdk.workspaces import workspace class ParameterValueError(Exception): @@ -55,6 +56,11 @@ def validate(self, value: typing.Optional[typing.Any], allow_empty: bool = True) return value + def validate_default( + self, value: typing.Optional[typing.Any], allow_empty: bool = True + ) -> typing.Optional[typing.Any]: + return self.validate(value, allow_empty=allow_empty) + def __str__(self) -> str: return str(self.expected_type) @@ -132,15 +138,7 @@ def normalize(value: typing.Any) -> typing.Any: return value -class PostgreSQLConnectionType(ParameterType): - @property - def spec_type(self) -> str: - return "postgresql" - - @property - def expected_type(self) -> typing.Type: - return PostgreSQLConnectionType - +class ConnectionParameterType(ParameterType): @property def accepts_choice(self) -> bool: return False @@ -149,6 +147,17 @@ def accepts_choice(self) -> bool: def accepts_multiple(self) -> bool: return False + def validate_default( + self, value: typing.Optional[typing.Any], allow_empty: bool = True + ) -> typing.Optional[typing.Any]: + if value is None: + return + + if not isinstance(value, str): + raise InvalidParameterError("Default value for connection parameter type should be string.") + elif value == "": + raise ParameterValueError("Empty values are not accepted.") + def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: if not allow_empty and value == "": raise ParameterValueError("Empty values are not accepted.") @@ -156,37 +165,36 @@ def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = Tr if not isinstance(value, str): raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})") - return value - -class S3ConnectionType(ParameterType): +class PostgreSQLConnectionType(ConnectionParameterType): @property def spec_type(self) -> str: - return "s3" + return "postgresql" @property def expected_type(self) -> typing.Type: - return S3ConnectionType + return PostgreSQLConnectionType + + def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: + super().validate(value, allow_empty=allow_empty) + return workspace.postgresql_connection(value) + +class S3ConnectionType(ConnectionParameterType): @property - def accepts_choice(self) -> bool: - return False + def spec_type(self) -> str: + return "s3" @property - def accepts_multiple(self) -> bool: - return False + def expected_type(self) -> typing.Type: + return S3ConnectionType def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: - if not allow_empty and value == "": - raise ParameterValueError("Empty values are not accepted.") - - if not isinstance(value, str): - raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})") - - return value + super().validate(value, allow_empty=allow_empty) + return workspace.s3_connection(value) -class GCSConnectionType(ParameterType): +class GCSConnectionType(ConnectionParameterType): @property def spec_type(self) -> str: return "gcs" @@ -195,25 +203,12 @@ def spec_type(self) -> str: def expected_type(self) -> typing.Type: return GCSConnectionType - @property - def accepts_choice(self) -> bool: - return False - - @property - def accepts_multiple(self) -> bool: - return False - def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: - if not allow_empty and value == "": - raise ParameterValueError("Empty values are not accepted.") + super().validate(value, allow_empty=allow_empty) + return workspace.gcs_connection(value) - if not isinstance(value, str): - raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})") - return value - - -class DHIS2ConnectionType(ParameterType): +class DHIS2ConnectionType(ConnectionParameterType): @property def spec_type(self) -> str: return "dhis2" @@ -222,25 +217,12 @@ def spec_type(self) -> str: def expected_type(self) -> typing.Type: return DHIS2ConnectionType - @property - def accepts_choice(self) -> bool: - return False - - @property - def accepts_multiple(self) -> bool: - return False - def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: - if not allow_empty and value == "": - raise ParameterValueError("Empty values are not accepted.") - - if not isinstance(value, str): - raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})") - - return value + super().validate(value, allow_empty=allow_empty) + return workspace.dhis2_connection(value) -class IASOConnectionType(ParameterType): +class IASOConnectionType(ConnectionParameterType): @property def spec_type(self) -> str: return "iaso" @@ -249,25 +231,12 @@ def spec_type(self) -> str: def expected_type(self) -> typing.Type: return IASOConnectionType - @property - def accepts_choice(self) -> bool: - return False - - @property - def accepts_multiple(self) -> bool: - return False - def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: - if not allow_empty and value == "": - raise ParameterValueError("Empty values are not accepted.") + super().validate(value, allow_empty=allow_empty) + return workspace.iaso_connection(value) - if not isinstance(value, str): - raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})") - - return value - -class CustomConnectionType(ParameterType): +class CustomConnectionType(ConnectionParameterType): @property def spec_type(self) -> str: return "custom" @@ -276,22 +245,9 @@ def spec_type(self) -> str: def expected_type(self) -> typing.Type: return str - @property - def accepts_choice(self) -> bool: - return False - - @property - def accepts_multiple(self) -> bool: - return False - def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: - if not allow_empty and value == "": - raise ParameterValueError("Empty values are not accepted.") - - if not isinstance(value, str): - raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})") - - return super().validate(value, allow_empty) + super().validate(value, allow_empty=allow_empty) + return workspace.postgresql_connection(value) TYPES_BY_PYTHON_TYPE = { @@ -425,9 +381,9 @@ def _validate_default(self, default: typing.Any, multiple: bool): if not isinstance(default, list): raise InvalidParameterError("Default values should be lists when using multiple=True") for default_value in default: - self.type.validate(default_value, allow_empty=False) + self.type.validate_default(default_value, allow_empty=False) else: - self.type.validate(default, allow_empty=False) + self.type.validate_default(default, allow_empty=False) except ParameterValueError: raise InvalidParameterError(f"The default value for {self.code} is not valid.") @@ -522,19 +478,3 @@ def get_all_parameters(self): return [self.parameter, *self.function.get_all_parameters()] return [self.parameter] - - -def is_connection_parameter(param: Parameter): - return any( - [ - isinstance(param.type, type) - for type in [ - DHIS2ConnectionType, - PostgreSQLConnectionType, - IASOConnectionType, - S3ConnectionType, - GCSConnectionType, - CustomConnectionType, - ] - ] - ) diff --git a/openhexa/sdk/pipelines/pipeline.py b/openhexa/sdk/pipelines/pipeline.py index d7511d8..0d2b32a 100644 --- a/openhexa/sdk/pipelines/pipeline.py +++ b/openhexa/sdk/pipelines/pipeline.py @@ -20,10 +20,9 @@ FunctionWithParameter, Parameter, ParameterValueError, - is_connection_parameter, ) from .task import PipelineWithTask -from .utils import get_local_workspace_config, get_connection_by_type +from .utils import get_local_workspace_config logger = getLogger(__name__) @@ -102,10 +101,7 @@ def run(self, config: typing.Dict[str, typing.Any]): for parameter in self.parameters: value = config.pop(parameter.code, None) validated_value = parameter.validate(value) - if is_connection_parameter(parameter): - validated_config[parameter.code] = get_connection_by_type(parameter.type, validated_value) - else: - validated_config[parameter.code] = validated_value + validated_config[parameter.code] = validated_value if len(config) > 0: raise ParameterValueError(f"The provided config contains invalid key(s): {', '.join(list(config.keys()))}") diff --git a/openhexa/sdk/pipelines/utils.py b/openhexa/sdk/pipelines/utils.py index 6936c7b..b431cc5 100644 --- a/openhexa/sdk/pipelines/utils.py +++ b/openhexa/sdk/pipelines/utils.py @@ -4,16 +4,6 @@ import stringcase import yaml -from openhexa.sdk.workspaces import workspace -from .parameter import ( - DHIS2ConnectionType, - PostgreSQLConnectionType, - IASOConnectionType, - S3ConnectionType, - GCSConnectionType, - CustomConnectionType, -) - class LocalWorkspaceConfigError(Exception): pass @@ -158,23 +148,3 @@ def get_local_workspace_config(path: Path): if key != "type": env_vars[stringcase.constcase(f"{slug}_{key.lower()}")] = str(value) return env_vars - - -def get_connection_by_type(type: any, identifier: str): - if isinstance(type, DHIS2ConnectionType): - return workspace.dhis2_connection(identifier) - - if isinstance(type, PostgreSQLConnectionType): - return workspace.postgresql_connection(identifier) - - if isinstance(type, IASOConnectionType): - return workspace.iaso_connection(identifier) - - if isinstance(type, S3ConnectionType): - return workspace.s3_connection(identifier) - - if isinstance(type, GCSConnectionType): - return workspace.gcs_connection(identifier) - - if isinstance(type, CustomConnectionType): - return workspace.custom_connection(identifier) diff --git a/tests/test_parameter.py b/tests/test_parameter.py index c08ae25..411b954 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -1,4 +1,14 @@ import pytest +import os +import stringcase + +from openhexa.sdk.workspaces.connection import ( + DHIS2Connection, + IASOConnection, + PostgreSQLConnection, + GCSConnection, + S3Connection, +) from openhexa.sdk.pipelines.parameter import ( Boolean, @@ -14,10 +24,11 @@ S3ConnectionType, IASOConnectionType, DHIS2ConnectionType, - CustomConnectionType, parameter, ) +from unittest import mock + def test_parameter_types_normalize(): # String @@ -70,41 +81,113 @@ def test_parameter_types_validate(): with pytest.raises(ParameterValueError): boolean_parameter_type.validate(86) - # PostgreSQL Connection - postgres_parameter_type = PostgreSQLConnectionType() - assert postgres_parameter_type.validate("postgres_connection_identifier") == "postgres_connection_identifier" - with pytest.raises(ParameterValueError): - postgres_parameter_type.validate(86) - - # IASO Connection - iaso_parameter_type = IASOConnectionType() - assert postgres_parameter_type.validate("iaso_connection_identifier") == "iaso_connection_identifier" - with pytest.raises(ParameterValueError): - iaso_parameter_type.validate(86) - - # GCS Connection - gcs_parameter_type = GCSConnectionType() - assert postgres_parameter_type.validate("gcs_connection_identifier") == "gcs_connection_identifier" - with pytest.raises(ParameterValueError): - gcs_parameter_type.validate(86) - - # S3 Connection - s3_parameter_type = S3ConnectionType() - assert s3_parameter_type.validate("s3_connection_identifier") == "s3_connection_identifier" - with pytest.raises(ParameterValueError): - s3_parameter_type.validate(86) - # DHIS2 Connection - dhsi2_parameter_type = DHIS2ConnectionType() - assert dhsi2_parameter_type.validate("dhis2_connection_identifier") == "dhis2_connection_identifier" - with pytest.raises(ParameterValueError): - dhsi2_parameter_type.validate(86) - - # Custom Connection - custom_parameter_type = CustomConnectionType() - assert custom_parameter_type.validate("custom_connection_identifier") == "custom_connection_identifier" - with pytest.raises(ParameterValueError): - custom_parameter_type.validate(86) +def test_validate_postgres_connection(): + identifier = "polio-ff3a0d" + env_variable_prefix = stringcase.constcase(identifier) + host = "https://172.17.0.1" + port = "5432" + username = "dhis2" + password = "dhis2_pwd" + database_name = "polio" + with mock.patch.dict( + os.environ, + { + f"{env_variable_prefix}_HOST": host, + f"{env_variable_prefix}_USERNAME": username, + f"{env_variable_prefix}_PASSWORD": password, + f"{env_variable_prefix}_PORT": port, + f"{env_variable_prefix}_DB_NAME": database_name, + }, + ): + postgres_parameter_type = PostgreSQLConnectionType() + assert postgres_parameter_type.validate(identifier) == PostgreSQLConnection( + host, int(port), username, password, database_name + ) + with pytest.raises(ParameterValueError): + postgres_parameter_type.validate(86) + + +def test_validate_dhis2_connection(): + identifier = "dhis2-connection-id" + env_variable_prefix = stringcase.constcase(identifier) + url = "https://test.dhis2.org/" + username = "dhis2" + password = "dhis2_pwd" + + with mock.patch.dict( + os.environ, + { + f"{env_variable_prefix}_URL": url, + f"{env_variable_prefix}_USERNAME": username, + f"{env_variable_prefix}_PASSWORD": password, + }, + ): + dhsi2_parameter_type = DHIS2ConnectionType() + assert dhsi2_parameter_type.validate(identifier) == DHIS2Connection(url, username, password) + with pytest.raises(ParameterValueError): + dhsi2_parameter_type.validate(86) + + +def test_validate_iaso_connection(): + identifier = "iaso-connection-id" + env_variable_prefix = stringcase.constcase(identifier) + url = "https://test.iaso.org/" + username = "iaso" + password = "iaso_pwd" + + with mock.patch.dict( + os.environ, + { + f"{env_variable_prefix}_URL": url, + f"{env_variable_prefix}_USERNAME": username, + f"{env_variable_prefix}_PASSWORD": password, + }, + ): + iaso_parameter_type = IASOConnectionType() + assert iaso_parameter_type.validate(identifier) == IASOConnection(url, username, password) + with pytest.raises(ParameterValueError): + iaso_parameter_type.validate(86) + + +def test_validate_gcs_connection(): + identifier = "gcs-connection-id" + env_variable_prefix = stringcase.constcase(identifier) + service_account_key = "HqQBxH0BAI3zF7kANUNlGg" + bucket_name = "test" + + with mock.patch.dict( + os.environ, + { + f"{env_variable_prefix}_SERVICE_ACCOUNT_KEY": service_account_key, + f"{env_variable_prefix}_BUCKET_NAME": bucket_name, + }, + ): + gcs_parameter_type = GCSConnectionType() + assert gcs_parameter_type.validate(identifier) == GCSConnection(service_account_key, bucket_name) + with pytest.raises(ParameterValueError): + gcs_parameter_type.validate(86) + + +def test_validate_s3_connection(): + identifier = "s3-connection-id" + env_variable_prefix = stringcase.constcase(identifier) + secret_access_key = "HqQBxH0BAI3zF7kANUNlGg" + access_key_id = "84hVntMaMSYP/RSW9ex04w" + bucket_name = "test" + + with mock.patch.dict( + os.environ, + { + f"{env_variable_prefix}_SECRET_ACCESS_KEY": secret_access_key, + f"{env_variable_prefix}_ACCESS_KEY_ID": access_key_id, + f"{env_variable_prefix}_BUCKET_NAME": bucket_name, + }, + ): + s3_parameter_type = S3ConnectionType() + assert s3_parameter_type.validate(identifier) == S3Connection(access_key_id, secret_access_key, bucket_name) + with pytest.raises(ParameterValueError): + s3_parameter_type.validate(86) def test_parameter_init(): diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 7105f06..ebb96b2 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -4,15 +4,14 @@ import stringcase import os -from openhexa.sdk.pipelines.parameter import ( - Parameter, - ParameterValueError, +from openhexa.sdk.workspaces.connection import ( DHIS2Connection, IASOConnection, PostgreSQLConnection, GCSConnection, S3Connection, ) +from openhexa.sdk.pipelines.parameter import Parameter, ParameterValueError from openhexa.sdk.pipelines.pipeline import Pipeline From f3946d63fc2a6d428ec6dd50b7ecdc66103bb6e4 Mon Sep 17 00:00:00 2001 From: cheikhgwane Date: Fri, 17 Nov 2023 14:06:13 +0000 Subject: [PATCH 4/5] chore : rename supported types --- openhexa/sdk/pipelines/parameter.py | 4 ++-- tests/test_parameter.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/openhexa/sdk/pipelines/parameter.py b/openhexa/sdk/pipelines/parameter.py index bb3379f..6a1da01 100644 --- a/openhexa/sdk/pipelines/parameter.py +++ b/openhexa/sdk/pipelines/parameter.py @@ -65,7 +65,7 @@ def __str__(self) -> str: return str(self.expected_type) -class String(ParameterType): +class StringType(ParameterType): @property def spec_type(self) -> str: return "str" @@ -251,7 +251,7 @@ def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = Tr TYPES_BY_PYTHON_TYPE = { - str: String, + str: StringType, bool: Boolean, int: Integer, float: Float, diff --git a/tests/test_parameter.py b/tests/test_parameter.py index 411b954..2dc4ac6 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -18,7 +18,7 @@ InvalidParameterError, Parameter, ParameterValueError, - String, + StringType, PostgreSQLConnectionType, GCSConnectionType, S3ConnectionType, @@ -31,8 +31,8 @@ def test_parameter_types_normalize(): - # String - string_parameter_type = String() + # StringType + string_parameter_type = StringType() assert string_parameter_type.normalize("a string") == "a string" assert string_parameter_type.normalize(" a string ") == "a string" assert string_parameter_type.normalize("") is None @@ -56,8 +56,8 @@ def test_parameter_types_normalize(): def test_parameter_types_validate(): - # String - string_parameter_type = String() + # StringType + string_parameter_type = StringType() assert string_parameter_type.validate("a string") == "a string" with pytest.raises(ParameterValueError): string_parameter_type.validate(86) @@ -353,7 +353,7 @@ def a_function(): assert function_parameters[0].multiple is False assert function_parameters[1].code == "arg2" - assert isinstance(function_parameters[1].type, String) + assert isinstance(function_parameters[1].type, StringType) assert function_parameters[1].name == "Arg 2" assert function_parameters[1].help == "Help 2" assert function_parameters[1].default == ["yo"] From d263c2c568c982ca7a12aad599c06cd59b3da4a8 Mon Sep 17 00:00:00 2001 From: pvanliefland Date: Fri, 17 Nov 2023 17:18:51 +0100 Subject: [PATCH 5/5] refactor: remove allow_empty, share validate() between connection parameter types --- openhexa/sdk/pipelines/parameter.py | 54 +++++++++++++---------------- 1 file changed, 25 insertions(+), 29 deletions(-) diff --git a/openhexa/sdk/pipelines/parameter.py b/openhexa/sdk/pipelines/parameter.py index 6a1da01..49e4716 100644 --- a/openhexa/sdk/pipelines/parameter.py +++ b/openhexa/sdk/pipelines/parameter.py @@ -8,6 +8,7 @@ GCSConnection, ) from openhexa.sdk.workspaces import workspace +from openhexa.sdk.workspaces.workspace import ConnectionDoesNotExist class ParameterValueError(Exception): @@ -46,7 +47,7 @@ def normalize(value: typing.Any) -> typing.Any: return value - def validate(self, value: typing.Optional[typing.Any], allow_empty: bool = True) -> typing.Optional[typing.Any]: + def validate(self, value: typing.Optional[typing.Any]) -> typing.Optional[typing.Any]: """Validate the provided value for this type.""" if not isinstance(value, self.expected_type): @@ -56,10 +57,8 @@ def validate(self, value: typing.Optional[typing.Any], allow_empty: bool = True) return value - def validate_default( - self, value: typing.Optional[typing.Any], allow_empty: bool = True - ) -> typing.Optional[typing.Any]: - return self.validate(value, allow_empty=allow_empty) + def validate_default(self, value: typing.Optional[typing.Any]): + self.validate(value) def __str__(self) -> str: return str(self.expected_type) @@ -86,11 +85,11 @@ def normalize(value: typing.Any) -> typing.Optional[str]: return normalized_value - def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: - if not allow_empty and value == "": + def validate_default(self, value: typing.Optional[typing.Any]): + if value == "": raise ParameterValueError("Empty values are not accepted.") - return super().validate(value, allow_empty) + super().validate_default(value) class Boolean(ParameterType): @@ -147,9 +146,7 @@ def accepts_choice(self) -> bool: def accepts_multiple(self) -> bool: return False - def validate_default( - self, value: typing.Optional[typing.Any], allow_empty: bool = True - ) -> typing.Optional[typing.Any]: + def validate_default(self, value: typing.Optional[typing.Any]): if value is None: return @@ -158,13 +155,18 @@ def validate_default( elif value == "": raise ParameterValueError("Empty values are not accepted.") - def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: - if not allow_empty and value == "": - raise ParameterValueError("Empty values are not accepted.") - + def validate(self, value: typing.Optional[typing.Any]) -> typing.Optional[str]: if not isinstance(value, str): raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})") + try: + return self.to_connection(value) + except ConnectionDoesNotExist as e: + raise ParameterValueError(str(e)) + + def to_connection(self, value: str) -> typing.Any: + raise NotImplementedError + class PostgreSQLConnectionType(ConnectionParameterType): @property @@ -175,8 +177,7 @@ def spec_type(self) -> str: def expected_type(self) -> typing.Type: return PostgreSQLConnectionType - def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: - super().validate(value, allow_empty=allow_empty) + def to_connection(self, value: str) -> typing.Any: return workspace.postgresql_connection(value) @@ -189,8 +190,7 @@ def spec_type(self) -> str: def expected_type(self) -> typing.Type: return S3ConnectionType - def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: - super().validate(value, allow_empty=allow_empty) + def to_connection(self, value: str) -> typing.Any: return workspace.s3_connection(value) @@ -203,8 +203,7 @@ def spec_type(self) -> str: def expected_type(self) -> typing.Type: return GCSConnectionType - def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: - super().validate(value, allow_empty=allow_empty) + def to_connection(self, value: str) -> typing.Any: return workspace.gcs_connection(value) @@ -217,8 +216,7 @@ def spec_type(self) -> str: def expected_type(self) -> typing.Type: return DHIS2ConnectionType - def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: - super().validate(value, allow_empty=allow_empty) + def to_connection(self, value: str) -> typing.Any: return workspace.dhis2_connection(value) @@ -231,8 +229,7 @@ def spec_type(self) -> str: def expected_type(self) -> typing.Type: return IASOConnectionType - def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: - super().validate(value, allow_empty=allow_empty) + def to_connection(self, value: str) -> typing.Any: return workspace.iaso_connection(value) @@ -245,8 +242,7 @@ def spec_type(self) -> str: def expected_type(self) -> typing.Type: return str - def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]: - super().validate(value, allow_empty=allow_empty) + def to_connection(self, value: str) -> typing.Any: return workspace.postgresql_connection(value) @@ -381,9 +377,9 @@ def _validate_default(self, default: typing.Any, multiple: bool): if not isinstance(default, list): raise InvalidParameterError("Default values should be lists when using multiple=True") for default_value in default: - self.type.validate_default(default_value, allow_empty=False) + self.type.validate_default(default_value) else: - self.type.validate_default(default, allow_empty=False) + self.type.validate_default(default) except ParameterValueError: raise InvalidParameterError(f"The default value for {self.code} is not valid.")