Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat : add parameter of type connection (IASO,GCS..) #86

Merged
merged 5 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 139 additions & 8 deletions openhexa/sdk/pipelines/parameter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
import re
import typing
from openhexa.sdk.workspaces.connection import (
DHIS2Connection,
IASOConnection,
PostgreSQLConnection,
S3Connection,
GCSConnection,
)
from openhexa.sdk.workspaces import workspace
from openhexa.sdk.workspaces.workspace import ConnectionDoesNotExist


class ParameterValueError(Exception):
Expand Down Expand Up @@ -38,7 +47,7 @@ def normalize(value: typing.Any) -> typing.Any:

return value

def validate(self, value: typing.Optional[typing.Any], allow_empty: bool = True) -> typing.Optional[typing.Any]:
def validate(self, value: typing.Optional[typing.Any]) -> typing.Optional[typing.Any]:
"""Validate the provided value for this type."""

if not isinstance(value, self.expected_type):
Expand All @@ -48,11 +57,14 @@ def validate(self, value: typing.Optional[typing.Any], allow_empty: bool = True)

return value

def validate_default(self, value: typing.Optional[typing.Any]):
self.validate(value)

def __str__(self) -> str:
return str(self.expected_type)


class String(ParameterType):
class StringType(ParameterType):
@property
def spec_type(self) -> str:
return "str"
Expand All @@ -73,11 +85,11 @@ def normalize(value: typing.Any) -> typing.Optional[str]:

return normalized_value

def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]:
if not allow_empty and value == "":
def validate_default(self, value: typing.Optional[typing.Any]):
if value == "":
raise ParameterValueError("Empty values are not accepted.")

return super().validate(value, allow_empty)
super().validate_default(value)


class Boolean(ParameterType):
Expand Down Expand Up @@ -125,7 +137,126 @@ def normalize(value: typing.Any) -> typing.Any:
return value


TYPES_BY_PYTHON_TYPE = {str: String, bool: Boolean, int: Integer, float: Float}
class ConnectionParameterType(ParameterType):
@property
def accepts_choice(self) -> bool:
return False

@property
def accepts_multiple(self) -> bool:
return False

def validate_default(self, value: typing.Optional[typing.Any]):
if value is None:
return

if not isinstance(value, str):
raise InvalidParameterError("Default value for connection parameter type should be string.")
elif value == "":
raise ParameterValueError("Empty values are not accepted.")

def validate(self, value: typing.Optional[typing.Any]) -> typing.Optional[str]:
if not isinstance(value, str):
raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})")

try:
return self.to_connection(value)
except ConnectionDoesNotExist as e:
raise ParameterValueError(str(e))

def to_connection(self, value: str) -> typing.Any:
raise NotImplementedError


class PostgreSQLConnectionType(ConnectionParameterType):
@property
def spec_type(self) -> str:
return "postgresql"

@property
def expected_type(self) -> typing.Type:
return PostgreSQLConnectionType

def to_connection(self, value: str) -> typing.Any:
return workspace.postgresql_connection(value)


class S3ConnectionType(ConnectionParameterType):
@property
def spec_type(self) -> str:
return "s3"

@property
def expected_type(self) -> typing.Type:
return S3ConnectionType

def to_connection(self, value: str) -> typing.Any:
return workspace.s3_connection(value)


class GCSConnectionType(ConnectionParameterType):
@property
def spec_type(self) -> str:
return "gcs"

@property
def expected_type(self) -> typing.Type:
return GCSConnectionType

def to_connection(self, value: str) -> typing.Any:
return workspace.gcs_connection(value)


class DHIS2ConnectionType(ConnectionParameterType):
@property
def spec_type(self) -> str:
return "dhis2"

@property
def expected_type(self) -> typing.Type:
return DHIS2ConnectionType

def to_connection(self, value: str) -> typing.Any:
return workspace.dhis2_connection(value)


class IASOConnectionType(ConnectionParameterType):
@property
def spec_type(self) -> str:
return "iaso"

@property
def expected_type(self) -> typing.Type:
return IASOConnectionType

def to_connection(self, value: str) -> typing.Any:
return workspace.iaso_connection(value)


class CustomConnectionType(ConnectionParameterType):
@property
def spec_type(self) -> str:
return "custom"

@property
def expected_type(self) -> typing.Type:
return str

def to_connection(self, value: str) -> typing.Any:
return workspace.postgresql_connection(value)


TYPES_BY_PYTHON_TYPE = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think that it would be a lot of added work to rename String to StringType, Boolean to BooleanType and so on?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me try

str: StringType,
bool: Boolean,
int: Integer,
float: Float,
DHIS2Connection: DHIS2ConnectionType,
PostgreSQLConnection: PostgreSQLConnectionType,
IASOConnection: IASOConnectionType,
S3Connection: S3ConnectionType,
GCSConnection: GCSConnectionType,
}


class InvalidParameterError(Exception):
Expand Down Expand Up @@ -246,9 +377,9 @@ def _validate_default(self, default: typing.Any, multiple: bool):
if not isinstance(default, list):
raise InvalidParameterError("Default values should be lists when using multiple=True")
for default_value in default:
self.type.validate(default_value, allow_empty=False)
self.type.validate_default(default_value)
else:
self.type.validate(default, allow_empty=False)
self.type.validate_default(default)
except ParameterValueError:
raise InvalidParameterError(f"The default value for {self.code} is not valid.")

Expand Down
6 changes: 5 additions & 1 deletion openhexa/sdk/pipelines/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@

from openhexa.sdk.utils import Environments, get_environment

from .parameter import FunctionWithParameter, Parameter, ParameterValueError
from .parameter import (
FunctionWithParameter,
Parameter,
ParameterValueError,
)
from .task import PipelineWithTask
from .utils import get_local_workspace_config

Expand Down
137 changes: 131 additions & 6 deletions tests/test_parameter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
import pytest
import os
import stringcase

from openhexa.sdk.workspaces.connection import (
DHIS2Connection,
IASOConnection,
PostgreSQLConnection,
GCSConnection,
S3Connection,
)

from openhexa.sdk.pipelines.parameter import (
Boolean,
Expand All @@ -8,14 +18,21 @@
InvalidParameterError,
Parameter,
ParameterValueError,
String,
StringType,
PostgreSQLConnectionType,
GCSConnectionType,
S3ConnectionType,
IASOConnectionType,
DHIS2ConnectionType,
parameter,
)

from unittest import mock


def test_parameter_types_normalize():
# String
string_parameter_type = String()
# StringType
string_parameter_type = StringType()
assert string_parameter_type.normalize("a string") == "a string"
assert string_parameter_type.normalize(" a string ") == "a string"
assert string_parameter_type.normalize("") is None
Expand All @@ -39,8 +56,8 @@ def test_parameter_types_normalize():


def test_parameter_types_validate():
# String
string_parameter_type = String()
# StringType
string_parameter_type = StringType()
assert string_parameter_type.validate("a string") == "a string"
with pytest.raises(ParameterValueError):
string_parameter_type.validate(86)
Expand All @@ -65,6 +82,114 @@ def test_parameter_types_validate():
boolean_parameter_type.validate(86)


def test_validate_postgres_connection():
identifier = "polio-ff3a0d"
env_variable_prefix = stringcase.constcase(identifier)
host = "https://172.17.0.1"
port = "5432"
username = "dhis2"
password = "dhis2_pwd"
database_name = "polio"
with mock.patch.dict(
os.environ,
{
f"{env_variable_prefix}_HOST": host,
f"{env_variable_prefix}_USERNAME": username,
f"{env_variable_prefix}_PASSWORD": password,
f"{env_variable_prefix}_PORT": port,
f"{env_variable_prefix}_DB_NAME": database_name,
},
):
postgres_parameter_type = PostgreSQLConnectionType()
assert postgres_parameter_type.validate(identifier) == PostgreSQLConnection(
host, int(port), username, password, database_name
)
with pytest.raises(ParameterValueError):
postgres_parameter_type.validate(86)


def test_validate_dhis2_connection():
identifier = "dhis2-connection-id"
env_variable_prefix = stringcase.constcase(identifier)
url = "https://test.dhis2.org/"
username = "dhis2"
password = "dhis2_pwd"

with mock.patch.dict(
os.environ,
{
f"{env_variable_prefix}_URL": url,
f"{env_variable_prefix}_USERNAME": username,
f"{env_variable_prefix}_PASSWORD": password,
},
):
dhsi2_parameter_type = DHIS2ConnectionType()
assert dhsi2_parameter_type.validate(identifier) == DHIS2Connection(url, username, password)
with pytest.raises(ParameterValueError):
dhsi2_parameter_type.validate(86)


def test_validate_iaso_connection():
identifier = "iaso-connection-id"
env_variable_prefix = stringcase.constcase(identifier)
url = "https://test.iaso.org/"
username = "iaso"
password = "iaso_pwd"

with mock.patch.dict(
os.environ,
{
f"{env_variable_prefix}_URL": url,
f"{env_variable_prefix}_USERNAME": username,
f"{env_variable_prefix}_PASSWORD": password,
},
):
iaso_parameter_type = IASOConnectionType()
assert iaso_parameter_type.validate(identifier) == IASOConnection(url, username, password)
with pytest.raises(ParameterValueError):
iaso_parameter_type.validate(86)


def test_validate_gcs_connection():
identifier = "gcs-connection-id"
env_variable_prefix = stringcase.constcase(identifier)
service_account_key = "HqQBxH0BAI3zF7kANUNlGg"
bucket_name = "test"

with mock.patch.dict(
os.environ,
{
f"{env_variable_prefix}_SERVICE_ACCOUNT_KEY": service_account_key,
f"{env_variable_prefix}_BUCKET_NAME": bucket_name,
},
):
gcs_parameter_type = GCSConnectionType()
assert gcs_parameter_type.validate(identifier) == GCSConnection(service_account_key, bucket_name)
with pytest.raises(ParameterValueError):
gcs_parameter_type.validate(86)


def test_validate_s3_connection():
identifier = "s3-connection-id"
env_variable_prefix = stringcase.constcase(identifier)
secret_access_key = "HqQBxH0BAI3zF7kANUNlGg"
access_key_id = "84hVntMaMSYP/RSW9ex04w"
bucket_name = "test"

with mock.patch.dict(
os.environ,
{
f"{env_variable_prefix}_SECRET_ACCESS_KEY": secret_access_key,
f"{env_variable_prefix}_ACCESS_KEY_ID": access_key_id,
f"{env_variable_prefix}_BUCKET_NAME": bucket_name,
},
):
s3_parameter_type = S3ConnectionType()
assert s3_parameter_type.validate(identifier) == S3Connection(access_key_id, secret_access_key, bucket_name)
with pytest.raises(ParameterValueError):
s3_parameter_type.validate(86)


def test_parameter_init():
# Wrong type
with pytest.raises(InvalidParameterError):
Expand Down Expand Up @@ -228,7 +353,7 @@ def a_function():
assert function_parameters[0].multiple is False

assert function_parameters[1].code == "arg2"
assert isinstance(function_parameters[1].type, String)
assert isinstance(function_parameters[1].type, StringType)
assert function_parameters[1].name == "Arg 2"
assert function_parameters[1].help == "Help 2"
assert function_parameters[1].default == ["yo"]
Expand Down
Loading
Loading