Skip to content

Commit

Permalink
feat : add parameter of type connection (IASO,GCS..) (#86)
Browse files Browse the repository at this point in the history
  • Loading branch information
cheikhgwane authored Nov 17, 2023
1 parent 03aa815 commit 1115864
Show file tree
Hide file tree
Showing 4 changed files with 414 additions and 16 deletions.
147 changes: 139 additions & 8 deletions openhexa/sdk/pipelines/parameter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
import re
import typing
from openhexa.sdk.workspaces.connection import (
DHIS2Connection,
IASOConnection,
PostgreSQLConnection,
S3Connection,
GCSConnection,
)
from openhexa.sdk.workspaces import workspace
from openhexa.sdk.workspaces.workspace import ConnectionDoesNotExist


class ParameterValueError(Exception):
Expand Down Expand Up @@ -38,7 +47,7 @@ def normalize(value: typing.Any) -> typing.Any:

return value

def validate(self, value: typing.Optional[typing.Any], allow_empty: bool = True) -> typing.Optional[typing.Any]:
def validate(self, value: typing.Optional[typing.Any]) -> typing.Optional[typing.Any]:
"""Validate the provided value for this type."""

if not isinstance(value, self.expected_type):
Expand All @@ -48,11 +57,14 @@ def validate(self, value: typing.Optional[typing.Any], allow_empty: bool = True)

return value

def validate_default(self, value: typing.Optional[typing.Any]):
self.validate(value)

def __str__(self) -> str:
return str(self.expected_type)


class String(ParameterType):
class StringType(ParameterType):
@property
def spec_type(self) -> str:
return "str"
Expand All @@ -73,11 +85,11 @@ def normalize(value: typing.Any) -> typing.Optional[str]:

return normalized_value

def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]:
if not allow_empty and value == "":
def validate_default(self, value: typing.Optional[typing.Any]):
if value == "":
raise ParameterValueError("Empty values are not accepted.")

return super().validate(value, allow_empty)
super().validate_default(value)


class Boolean(ParameterType):
Expand Down Expand Up @@ -125,7 +137,126 @@ def normalize(value: typing.Any) -> typing.Any:
return value


TYPES_BY_PYTHON_TYPE = {str: String, bool: Boolean, int: Integer, float: Float}
class ConnectionParameterType(ParameterType):
@property
def accepts_choice(self) -> bool:
return False

@property
def accepts_multiple(self) -> bool:
return False

def validate_default(self, value: typing.Optional[typing.Any]):
if value is None:
return

if not isinstance(value, str):
raise InvalidParameterError("Default value for connection parameter type should be string.")
elif value == "":
raise ParameterValueError("Empty values are not accepted.")

def validate(self, value: typing.Optional[typing.Any]) -> typing.Optional[str]:
if not isinstance(value, str):
raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})")

try:
return self.to_connection(value)
except ConnectionDoesNotExist as e:
raise ParameterValueError(str(e))

def to_connection(self, value: str) -> typing.Any:
raise NotImplementedError


class PostgreSQLConnectionType(ConnectionParameterType):
@property
def spec_type(self) -> str:
return "postgresql"

@property
def expected_type(self) -> typing.Type:
return PostgreSQLConnectionType

def to_connection(self, value: str) -> typing.Any:
return workspace.postgresql_connection(value)


class S3ConnectionType(ConnectionParameterType):
@property
def spec_type(self) -> str:
return "s3"

@property
def expected_type(self) -> typing.Type:
return S3ConnectionType

def to_connection(self, value: str) -> typing.Any:
return workspace.s3_connection(value)


class GCSConnectionType(ConnectionParameterType):
@property
def spec_type(self) -> str:
return "gcs"

@property
def expected_type(self) -> typing.Type:
return GCSConnectionType

def to_connection(self, value: str) -> typing.Any:
return workspace.gcs_connection(value)


class DHIS2ConnectionType(ConnectionParameterType):
@property
def spec_type(self) -> str:
return "dhis2"

@property
def expected_type(self) -> typing.Type:
return DHIS2ConnectionType

def to_connection(self, value: str) -> typing.Any:
return workspace.dhis2_connection(value)


class IASOConnectionType(ConnectionParameterType):
@property
def spec_type(self) -> str:
return "iaso"

@property
def expected_type(self) -> typing.Type:
return IASOConnectionType

def to_connection(self, value: str) -> typing.Any:
return workspace.iaso_connection(value)


class CustomConnectionType(ConnectionParameterType):
@property
def spec_type(self) -> str:
return "custom"

@property
def expected_type(self) -> typing.Type:
return str

def to_connection(self, value: str) -> typing.Any:
return workspace.postgresql_connection(value)


TYPES_BY_PYTHON_TYPE = {
str: StringType,
bool: Boolean,
int: Integer,
float: Float,
DHIS2Connection: DHIS2ConnectionType,
PostgreSQLConnection: PostgreSQLConnectionType,
IASOConnection: IASOConnectionType,
S3Connection: S3ConnectionType,
GCSConnection: GCSConnectionType,
}


class InvalidParameterError(Exception):
Expand Down Expand Up @@ -246,9 +377,9 @@ def _validate_default(self, default: typing.Any, multiple: bool):
if not isinstance(default, list):
raise InvalidParameterError("Default values should be lists when using multiple=True")
for default_value in default:
self.type.validate(default_value, allow_empty=False)
self.type.validate_default(default_value)
else:
self.type.validate(default, allow_empty=False)
self.type.validate_default(default)
except ParameterValueError:
raise InvalidParameterError(f"The default value for {self.code} is not valid.")

Expand Down
6 changes: 5 additions & 1 deletion openhexa/sdk/pipelines/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@

from openhexa.sdk.utils import Environments, get_environment

from .parameter import FunctionWithParameter, Parameter, ParameterValueError
from .parameter import (
FunctionWithParameter,
Parameter,
ParameterValueError,
)
from .task import PipelineWithTask
from .utils import get_local_workspace_config

Expand Down
137 changes: 131 additions & 6 deletions tests/test_parameter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
import pytest
import os
import stringcase

from openhexa.sdk.workspaces.connection import (
DHIS2Connection,
IASOConnection,
PostgreSQLConnection,
GCSConnection,
S3Connection,
)

from openhexa.sdk.pipelines.parameter import (
Boolean,
Expand All @@ -8,14 +18,21 @@
InvalidParameterError,
Parameter,
ParameterValueError,
String,
StringType,
PostgreSQLConnectionType,
GCSConnectionType,
S3ConnectionType,
IASOConnectionType,
DHIS2ConnectionType,
parameter,
)

from unittest import mock


def test_parameter_types_normalize():
# String
string_parameter_type = String()
# StringType
string_parameter_type = StringType()
assert string_parameter_type.normalize("a string") == "a string"
assert string_parameter_type.normalize(" a string ") == "a string"
assert string_parameter_type.normalize("") is None
Expand All @@ -39,8 +56,8 @@ def test_parameter_types_normalize():


def test_parameter_types_validate():
# String
string_parameter_type = String()
# StringType
string_parameter_type = StringType()
assert string_parameter_type.validate("a string") == "a string"
with pytest.raises(ParameterValueError):
string_parameter_type.validate(86)
Expand All @@ -65,6 +82,114 @@ def test_parameter_types_validate():
boolean_parameter_type.validate(86)


def test_validate_postgres_connection():
identifier = "polio-ff3a0d"
env_variable_prefix = stringcase.constcase(identifier)
host = "https://172.17.0.1"
port = "5432"
username = "dhis2"
password = "dhis2_pwd"
database_name = "polio"
with mock.patch.dict(
os.environ,
{
f"{env_variable_prefix}_HOST": host,
f"{env_variable_prefix}_USERNAME": username,
f"{env_variable_prefix}_PASSWORD": password,
f"{env_variable_prefix}_PORT": port,
f"{env_variable_prefix}_DB_NAME": database_name,
},
):
postgres_parameter_type = PostgreSQLConnectionType()
assert postgres_parameter_type.validate(identifier) == PostgreSQLConnection(
host, int(port), username, password, database_name
)
with pytest.raises(ParameterValueError):
postgres_parameter_type.validate(86)


def test_validate_dhis2_connection():
identifier = "dhis2-connection-id"
env_variable_prefix = stringcase.constcase(identifier)
url = "https://test.dhis2.org/"
username = "dhis2"
password = "dhis2_pwd"

with mock.patch.dict(
os.environ,
{
f"{env_variable_prefix}_URL": url,
f"{env_variable_prefix}_USERNAME": username,
f"{env_variable_prefix}_PASSWORD": password,
},
):
dhsi2_parameter_type = DHIS2ConnectionType()
assert dhsi2_parameter_type.validate(identifier) == DHIS2Connection(url, username, password)
with pytest.raises(ParameterValueError):
dhsi2_parameter_type.validate(86)


def test_validate_iaso_connection():
identifier = "iaso-connection-id"
env_variable_prefix = stringcase.constcase(identifier)
url = "https://test.iaso.org/"
username = "iaso"
password = "iaso_pwd"

with mock.patch.dict(
os.environ,
{
f"{env_variable_prefix}_URL": url,
f"{env_variable_prefix}_USERNAME": username,
f"{env_variable_prefix}_PASSWORD": password,
},
):
iaso_parameter_type = IASOConnectionType()
assert iaso_parameter_type.validate(identifier) == IASOConnection(url, username, password)
with pytest.raises(ParameterValueError):
iaso_parameter_type.validate(86)


def test_validate_gcs_connection():
identifier = "gcs-connection-id"
env_variable_prefix = stringcase.constcase(identifier)
service_account_key = "HqQBxH0BAI3zF7kANUNlGg"
bucket_name = "test"

with mock.patch.dict(
os.environ,
{
f"{env_variable_prefix}_SERVICE_ACCOUNT_KEY": service_account_key,
f"{env_variable_prefix}_BUCKET_NAME": bucket_name,
},
):
gcs_parameter_type = GCSConnectionType()
assert gcs_parameter_type.validate(identifier) == GCSConnection(service_account_key, bucket_name)
with pytest.raises(ParameterValueError):
gcs_parameter_type.validate(86)


def test_validate_s3_connection():
identifier = "s3-connection-id"
env_variable_prefix = stringcase.constcase(identifier)
secret_access_key = "HqQBxH0BAI3zF7kANUNlGg"
access_key_id = "84hVntMaMSYP/RSW9ex04w"
bucket_name = "test"

with mock.patch.dict(
os.environ,
{
f"{env_variable_prefix}_SECRET_ACCESS_KEY": secret_access_key,
f"{env_variable_prefix}_ACCESS_KEY_ID": access_key_id,
f"{env_variable_prefix}_BUCKET_NAME": bucket_name,
},
):
s3_parameter_type = S3ConnectionType()
assert s3_parameter_type.validate(identifier) == S3Connection(access_key_id, secret_access_key, bucket_name)
with pytest.raises(ParameterValueError):
s3_parameter_type.validate(86)


def test_parameter_init():
# Wrong type
with pytest.raises(InvalidParameterError):
Expand Down Expand Up @@ -228,7 +353,7 @@ def a_function():
assert function_parameters[0].multiple is False

assert function_parameters[1].code == "arg2"
assert isinstance(function_parameters[1].type, String)
assert isinstance(function_parameters[1].type, StringType)
assert function_parameters[1].name == "Arg 2"
assert function_parameters[1].help == "Help 2"
assert function_parameters[1].default == ["yo"]
Expand Down
Loading

0 comments on commit 1115864

Please sign in to comment.