From 111586443d285dccdc1979f9db4c4276f5e7d607 Mon Sep 17 00:00:00 2001
From: Cheikh Gueye Wane
Date: Fri, 17 Nov 2023 16:27:09 +0000
Subject: [PATCH] feat : add parameter of type connection (IASO,GCS..) (#86)

---
Review notes (ignored by `git am`, not part of the commit message):
- CustomConnectionType.to_connection now resolves through
  workspace.custom_connection() instead of the copy-pasted
  workspace.postgresql_connection() call (confirm the method name against
  the workspace module, which is not part of this patch).
- ConnectionParameterType.validate chains the ConnectionDoesNotExist cause
  explicitly with "raise ... from e".
- Fixed the "dhsi2" local-variable typo in test_validate_dhis2_connection.
- Line breaks and indentation were reconstructed from a whitespace-collapsed
  copy of this patch; hunk sizes match the @@ headers, but the "index" blob
  ids predate the edits listed above. Verify before applying.

 openhexa/sdk/pipelines/parameter.py | 147 ++++++++++++++++++++++++++--
 openhexa/sdk/pipelines/pipeline.py  |   6 +-
 tests/test_parameter.py             | 137 ++++++++++++++++++++++++--
 tests/test_pipeline.py              | 140 +++++++++++++++++++++++++-
 4 files changed, 414 insertions(+), 16 deletions(-)

diff --git a/openhexa/sdk/pipelines/parameter.py b/openhexa/sdk/pipelines/parameter.py
index 8b2b199..49e4716 100644
--- a/openhexa/sdk/pipelines/parameter.py
+++ b/openhexa/sdk/pipelines/parameter.py
@@ -1,5 +1,14 @@
 import re
 import typing
+from openhexa.sdk.workspaces.connection import (
+    DHIS2Connection,
+    IASOConnection,
+    PostgreSQLConnection,
+    S3Connection,
+    GCSConnection,
+)
+from openhexa.sdk.workspaces import workspace
+from openhexa.sdk.workspaces.workspace import ConnectionDoesNotExist
 
 
 class ParameterValueError(Exception):
@@ -38,7 +47,7 @@ def normalize(value: typing.Any) -> typing.Any:
 
         return value
 
-    def validate(self, value: typing.Optional[typing.Any], allow_empty: bool = True) -> typing.Optional[typing.Any]:
+    def validate(self, value: typing.Optional[typing.Any]) -> typing.Optional[typing.Any]:
         """Validate the provided value for this type."""
 
         if not isinstance(value, self.expected_type):
@@ -48,11 +57,14 @@ def validate(self, value: typing.Optional[typing.Any], allow_empty: bool = True)
 
         return value
 
+    def validate_default(self, value: typing.Optional[typing.Any]):
+        self.validate(value)
+
     def __str__(self) -> str:
         return str(self.expected_type)
 
 
-class String(ParameterType):
+class StringType(ParameterType):
     @property
     def spec_type(self) -> str:
         return "str"
@@ -73,11 +85,11 @@ def normalize(value: typing.Any) -> typing.Optional[str]:
 
         return normalized_value
 
-    def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]:
-        if not allow_empty and value == "":
+    def validate_default(self, value: typing.Optional[typing.Any]):
+        if value == "":
             raise ParameterValueError("Empty values are not accepted.")
 
-        return super().validate(value, allow_empty)
+        super().validate_default(value)
 
 
 class Boolean(ParameterType):
@@ -125,7 +137,126 @@ def normalize(value: typing.Any) -> typing.Any:
         return value
 
 
-TYPES_BY_PYTHON_TYPE = {str: String, bool: Boolean, int: Integer, float: Float}
+class ConnectionParameterType(ParameterType):
+    @property
+    def accepts_choice(self) -> bool:
+        return False
+
+    @property
+    def accepts_multiple(self) -> bool:
+        return False
+
+    def validate_default(self, value: typing.Optional[typing.Any]):
+        if value is None:
+            return
+
+        if not isinstance(value, str):
+            raise InvalidParameterError("Default value for connection parameter type should be string.")
+        elif value == "":
+            raise ParameterValueError("Empty values are not accepted.")
+
+    def validate(self, value: typing.Optional[typing.Any]) -> typing.Optional[str]:
+        if not isinstance(value, str):
+            raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})")
+
+        try:
+            return self.to_connection(value)
+        except ConnectionDoesNotExist as e:
+            raise ParameterValueError(str(e)) from e
+
+    def to_connection(self, value: str) -> typing.Any:
+        raise NotImplementedError
+
+
+class PostgreSQLConnectionType(ConnectionParameterType):
+    @property
+    def spec_type(self) -> str:
+        return "postgresql"
+
+    @property
+    def expected_type(self) -> typing.Type:
+        return PostgreSQLConnectionType
+
+    def to_connection(self, value: str) -> typing.Any:
+        return workspace.postgresql_connection(value)
+
+
+class S3ConnectionType(ConnectionParameterType):
+    @property
+    def spec_type(self) -> str:
+        return "s3"
+
+    @property
+    def expected_type(self) -> typing.Type:
+        return S3ConnectionType
+
+    def to_connection(self, value: str) -> typing.Any:
+        return workspace.s3_connection(value)
+
+
+class GCSConnectionType(ConnectionParameterType):
+    @property
+    def spec_type(self) -> str:
+        return "gcs"
+
+    @property
+    def expected_type(self) -> typing.Type:
+        return GCSConnectionType
+
+    def to_connection(self, value: str) -> typing.Any:
+        return workspace.gcs_connection(value)
+
+
+class DHIS2ConnectionType(ConnectionParameterType):
+    @property
+    def spec_type(self) -> str:
+        return "dhis2"
+
+    @property
+    def expected_type(self) -> typing.Type:
+        return DHIS2ConnectionType
+
+    def to_connection(self, value: str) -> typing.Any:
+        return workspace.dhis2_connection(value)
+
+
+class IASOConnectionType(ConnectionParameterType):
+    @property
+    def spec_type(self) -> str:
+        return "iaso"
+
+    @property
+    def expected_type(self) -> typing.Type:
+        return IASOConnectionType
+
+    def to_connection(self, value: str) -> typing.Any:
+        return workspace.iaso_connection(value)
+
+
+class CustomConnectionType(ConnectionParameterType):
+    @property
+    def spec_type(self) -> str:
+        return "custom"
+
+    @property
+    def expected_type(self) -> typing.Type:
+        return str
+
+    def to_connection(self, value: str) -> typing.Any:
+        return workspace.custom_connection(value)
+
+
+TYPES_BY_PYTHON_TYPE = {
+    str: StringType,
+    bool: Boolean,
+    int: Integer,
+    float: Float,
+    DHIS2Connection: DHIS2ConnectionType,
+    PostgreSQLConnection: PostgreSQLConnectionType,
+    IASOConnection: IASOConnectionType,
+    S3Connection: S3ConnectionType,
+    GCSConnection: GCSConnectionType,
+}
 
 
 class InvalidParameterError(Exception):
@@ -246,9 +377,9 @@ def _validate_default(self, default: typing.Any, multiple: bool):
                 if not isinstance(default, list):
                     raise InvalidParameterError("Default values should be lists when using multiple=True")
                 for default_value in default:
-                    self.type.validate(default_value, allow_empty=False)
+                    self.type.validate_default(default_value)
             else:
-                self.type.validate(default, allow_empty=False)
+                self.type.validate_default(default)
         except ParameterValueError:
             raise InvalidParameterError(f"The default value for {self.code} is not valid.")
 
diff --git a/openhexa/sdk/pipelines/pipeline.py b/openhexa/sdk/pipelines/pipeline.py
index 622016a..0d2b32a 100644
--- a/openhexa/sdk/pipelines/pipeline.py
+++ b/openhexa/sdk/pipelines/pipeline.py
@@ -16,7 +16,11 @@
 
 from openhexa.sdk.utils import Environments, get_environment
 
-from .parameter import FunctionWithParameter, Parameter, ParameterValueError
+from .parameter import (
+    FunctionWithParameter,
+    Parameter,
+    ParameterValueError,
+)
 from .task import PipelineWithTask
 from .utils import get_local_workspace_config
 
diff --git a/tests/test_parameter.py b/tests/test_parameter.py
index d382efc..2dc4ac6 100644
--- a/tests/test_parameter.py
+++ b/tests/test_parameter.py
@@ -1,4 +1,14 @@
 import pytest
+import os
+import stringcase
+
+from openhexa.sdk.workspaces.connection import (
+    DHIS2Connection,
+    IASOConnection,
+    PostgreSQLConnection,
+    GCSConnection,
+    S3Connection,
+)
 
 from openhexa.sdk.pipelines.parameter import (
     Boolean,
@@ -8,14 +18,21 @@
     InvalidParameterError,
     Parameter,
     ParameterValueError,
-    String,
+    StringType,
+    PostgreSQLConnectionType,
+    GCSConnectionType,
+    S3ConnectionType,
+    IASOConnectionType,
+    DHIS2ConnectionType,
     parameter,
 )
 
+from unittest import mock
+
 
 def test_parameter_types_normalize():
-    # String
-    string_parameter_type = String()
+    # StringType
+    string_parameter_type = StringType()
     assert string_parameter_type.normalize("a string") == "a string"
     assert string_parameter_type.normalize(" a string ") == "a string"
     assert string_parameter_type.normalize("") is None
@@ -39,8 +56,8 @@ def test_parameter_types_normalize():
 
 
 def test_parameter_types_validate():
-    # String
-    string_parameter_type = String()
+    # StringType
+    string_parameter_type = StringType()
     assert string_parameter_type.validate("a string") == "a string"
     with pytest.raises(ParameterValueError):
         string_parameter_type.validate(86)
@@ -65,6 +82,114 @@ def test_parameter_types_validate():
         boolean_parameter_type.validate(86)
 
 
+def test_validate_postgres_connection():
+    identifier = "polio-ff3a0d"
+    env_variable_prefix = stringcase.constcase(identifier)
+    host = "https://172.17.0.1"
+    port = "5432"
+    username = "dhis2"
+    password = "dhis2_pwd"
+    database_name = "polio"
+    with mock.patch.dict(
+        os.environ,
+        {
+            f"{env_variable_prefix}_HOST": host,
+            f"{env_variable_prefix}_USERNAME": username,
+            f"{env_variable_prefix}_PASSWORD": password,
+            f"{env_variable_prefix}_PORT": port,
+            f"{env_variable_prefix}_DB_NAME": database_name,
+        },
+    ):
+        postgres_parameter_type = PostgreSQLConnectionType()
+        assert postgres_parameter_type.validate(identifier) == PostgreSQLConnection(
+            host, int(port), username, password, database_name
+        )
+        with pytest.raises(ParameterValueError):
+            postgres_parameter_type.validate(86)
+
+
+def test_validate_dhis2_connection():
+    identifier = "dhis2-connection-id"
+    env_variable_prefix = stringcase.constcase(identifier)
+    url = "https://test.dhis2.org/"
+    username = "dhis2"
+    password = "dhis2_pwd"
+
+    with mock.patch.dict(
+        os.environ,
+        {
+            f"{env_variable_prefix}_URL": url,
+            f"{env_variable_prefix}_USERNAME": username,
+            f"{env_variable_prefix}_PASSWORD": password,
+        },
+    ):
+        dhis2_parameter_type = DHIS2ConnectionType()
+        assert dhis2_parameter_type.validate(identifier) == DHIS2Connection(url, username, password)
+        with pytest.raises(ParameterValueError):
+            dhis2_parameter_type.validate(86)
+
+
+def test_validate_iaso_connection():
+    identifier = "iaso-connection-id"
+    env_variable_prefix = stringcase.constcase(identifier)
+    url = "https://test.iaso.org/"
+    username = "iaso"
+    password = "iaso_pwd"
+
+    with mock.patch.dict(
+        os.environ,
+        {
+            f"{env_variable_prefix}_URL": url,
+            f"{env_variable_prefix}_USERNAME": username,
+            f"{env_variable_prefix}_PASSWORD": password,
+        },
+    ):
+        iaso_parameter_type = IASOConnectionType()
+        assert iaso_parameter_type.validate(identifier) == IASOConnection(url, username, password)
+        with pytest.raises(ParameterValueError):
+            iaso_parameter_type.validate(86)
+
+
+def test_validate_gcs_connection():
+    identifier = "gcs-connection-id"
+    env_variable_prefix = stringcase.constcase(identifier)
+    service_account_key = "HqQBxH0BAI3zF7kANUNlGg"
+    bucket_name = "test"
+
+    with mock.patch.dict(
+        os.environ,
+        {
+            f"{env_variable_prefix}_SERVICE_ACCOUNT_KEY": service_account_key,
+            f"{env_variable_prefix}_BUCKET_NAME": bucket_name,
+        },
+    ):
+        gcs_parameter_type = GCSConnectionType()
+        assert gcs_parameter_type.validate(identifier) == GCSConnection(service_account_key, bucket_name)
+        with pytest.raises(ParameterValueError):
+            gcs_parameter_type.validate(86)
+
+
+def test_validate_s3_connection():
+    identifier = "s3-connection-id"
+    env_variable_prefix = stringcase.constcase(identifier)
+    secret_access_key = "HqQBxH0BAI3zF7kANUNlGg"
+    access_key_id = "84hVntMaMSYP/RSW9ex04w"
+    bucket_name = "test"
+
+    with mock.patch.dict(
+        os.environ,
+        {
+            f"{env_variable_prefix}_SECRET_ACCESS_KEY": secret_access_key,
+            f"{env_variable_prefix}_ACCESS_KEY_ID": access_key_id,
+            f"{env_variable_prefix}_BUCKET_NAME": bucket_name,
+        },
+    ):
+        s3_parameter_type = S3ConnectionType()
+        assert s3_parameter_type.validate(identifier) == S3Connection(access_key_id, secret_access_key, bucket_name)
+        with pytest.raises(ParameterValueError):
+            s3_parameter_type.validate(86)
+
+
 def test_parameter_init():
     # Wrong type
     with pytest.raises(InvalidParameterError):
@@ -228,7 +353,7 @@ def a_function():
     assert function_parameters[0].multiple is False
 
     assert function_parameters[1].code == "arg2"
-    assert isinstance(function_parameters[1].type, String)
+    assert isinstance(function_parameters[1].type, StringType)
     assert function_parameters[1].name == "Arg 2"
     assert function_parameters[1].help == "Help 2"
     assert function_parameters[1].default == ["yo"]
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
index 15f54a7..ebb96b2 100644
--- a/tests/test_pipeline.py
+++ b/tests/test_pipeline.py
@@ -1,7 +1,16 @@
-from unittest.mock import Mock
+from unittest.mock import Mock, patch
 
 import pytest
+import stringcase
+import os
 
+from openhexa.sdk.workspaces.connection import (
+    DHIS2Connection,
+    IASOConnection,
+    PostgreSQLConnection,
+    GCSConnection,
+    S3Connection,
+)
 from openhexa.sdk.pipelines.parameter import Parameter, ParameterValueError
 from openhexa.sdk.pipelines.pipeline import Pipeline
 
@@ -34,6 +43,135 @@ def test_pipeline_run_extra_config():
         pipeline.run({"arg1": "ok", "arg2": "extra"})
 
 
+def test_pipeline_run_connection_dhis2_parameter_config():
+    identifier = "dhis2-connection-id"
+    env_variable_prefix = stringcase.constcase(identifier)
+    url = "https://test.dhis2.org/"
+    username = "dhis2"
+    password = "dhis2_pwd"
+
+    with patch.dict(
+        os.environ,
+        {
+            f"{env_variable_prefix}_URL": url,
+            f"{env_variable_prefix}_USERNAME": username,
+            f"{env_variable_prefix}_PASSWORD": password,
+        },
+    ):
+        pipeline_func = Mock()
+        parameter_1 = Parameter(
+            "connection_param", name="this is a test for connection parameter", type=DHIS2Connection
+        )
+        pipeline = Pipeline("code", "pipeline", pipeline_func, [parameter_1])
+        pipeline.run({"connection_param": identifier})
+        assert pipeline.name == "pipeline"
+        pipeline_func.assert_called_once_with(connection_param=DHIS2Connection(url, username, password))
+
+
+def test_pipeline_run_connection_iaso_parameter_config():
+    identifier = "iaso-connection-id"
+    env_variable_prefix = stringcase.constcase(identifier)
+    url = "https://test.iaso.org/"
+    username = "iaso"
+    password = "iaso_pwd"
+
+    with patch.dict(
+        os.environ,
+        {
+            f"{env_variable_prefix}_URL": url,
+            f"{env_variable_prefix}_USERNAME": username,
+            f"{env_variable_prefix}_PASSWORD": password,
+        },
+    ):
+        pipeline_func = Mock()
+        parameter_1 = Parameter("connection_param", name="this is a test for connection parameter", type=IASOConnection)
+        pipeline = Pipeline("code", "pipeline", pipeline_func, [parameter_1])
+        pipeline.run({"connection_param": identifier})
+
+        assert pipeline.name == "pipeline"
+        pipeline_func.assert_called_once_with(connection_param=IASOConnection(url, username, password))
+
+
+def test_pipeline_run_connection_gcs_parameter_config():
+    identifier = "gcs-connection-id"
+    env_variable_prefix = stringcase.constcase(identifier)
+    service_account_key = "HqQBxH0BAI3zF7kANUNlGg"
+    bucket_name = "test"
+
+    with patch.dict(
+        os.environ,
+        {
+            f"{env_variable_prefix}_SERVICE_ACCOUNT_KEY": service_account_key,
+            f"{env_variable_prefix}_BUCKET_NAME": bucket_name,
+        },
+    ):
+        pipeline_func = Mock()
+        parameter_1 = Parameter("connection_param", name="this is a test for connection parameter", type=GCSConnection)
+        pipeline = Pipeline("code", "pipeline", pipeline_func, [parameter_1])
+        pipeline.run({"connection_param": identifier})
+
+        assert pipeline.name == "pipeline"
+        pipeline_func.assert_called_once_with(connection_param=GCSConnection(service_account_key, bucket_name))
+
+
+def test_pipeline_run_connection_s3_parameter_config():
+    identifier = "s3-connection-id"
+    env_variable_prefix = stringcase.constcase(identifier)
+    secret_access_key = "HqQBxH0BAI3zF7kANUNlGg"
+    access_key_id = "84hVntMaMSYP/RSW9ex04w"
+    bucket_name = "test"
+
+    with patch.dict(
+        os.environ,
+        {
+            f"{env_variable_prefix}_SECRET_ACCESS_KEY": secret_access_key,
+            f"{env_variable_prefix}_ACCESS_KEY_ID": access_key_id,
+            f"{env_variable_prefix}_BUCKET_NAME": bucket_name,
+        },
+    ):
+        pipeline_func = Mock()
+        parameter_1 = Parameter("connection_param", name="this is a test for connection parameter", type=S3Connection)
+        pipeline = Pipeline("code", "pipeline", pipeline_func, [parameter_1])
+        pipeline.run({"connection_param": identifier})
+
+        assert pipeline.name == "pipeline"
+        pipeline_func.assert_called_once_with(
+            connection_param=S3Connection(access_key_id, secret_access_key, bucket_name)
+        )
+
+
+def test_pipeline_run_connection_postgres_parameter_config():
+    identifier = "postgres-connection-id"
+    env_variable_prefix = stringcase.constcase(identifier)
+    host = "https://127.0.0.1"
+    port = "5432"
+    username = "hexa_sdk"
+    password = "hexa_sdk_pwd"
+    database_name = "hexa_sdk"
+
+    with patch.dict(
+        os.environ,
+        {
+            f"{env_variable_prefix}_HOST": host,
+            f"{env_variable_prefix}_USERNAME": username,
+            f"{env_variable_prefix}_PASSWORD": password,
+            f"{env_variable_prefix}_PORT": port,
+            f"{env_variable_prefix}_DB_NAME": database_name,
+        },
+    ):
+        pipeline_func = Mock()
+        parameter_1 = Parameter(
+            "connection_param", name="this is a test for connection parameter", type=PostgreSQLConnection
+        )
+        pipeline = Pipeline("code", "pipeline", pipeline_func, [parameter_1])
+        pipeline.run({"connection_param": identifier})
+
+        assert pipeline.name == "pipeline"
+        pipeline_func.assert_called_once_with(
+            connection_param=PostgreSQLConnection(host, int(port), username, password, database_name)
+        )
+
+
 def test_pipeline_parameters_spec():
     pipeline_func = Mock()
     parameter_1 = Parameter("arg1", type=str)