Skip to content

Commit

Permalink
refacto : inject connection on Parameter.validate
Browse files Browse the repository at this point in the history
  • Loading branch information
cheikhgwane committed Nov 17, 2023
1 parent 34e617c commit fbe3f00
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 181 deletions.
154 changes: 47 additions & 107 deletions openhexa/sdk/pipelines/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
S3Connection,
GCSConnection,
)
from openhexa.sdk.workspaces import workspace


class ParameterValueError(Exception):
Expand Down Expand Up @@ -55,6 +56,11 @@ def validate(self, value: typing.Optional[typing.Any], allow_empty: bool = True)

return value

def validate_default(
self, value: typing.Optional[typing.Any], allow_empty: bool = True
) -> typing.Optional[typing.Any]:
return self.validate(value, allow_empty=allow_empty)

def __str__(self) -> str:
return str(self.expected_type)

Expand Down Expand Up @@ -132,15 +138,7 @@ def normalize(value: typing.Any) -> typing.Any:
return value


class PostgreSQLConnectionType(ParameterType):
@property
def spec_type(self) -> str:
return "postgresql"

@property
def expected_type(self) -> typing.Type:
return PostgreSQLConnectionType

class ConnectionParameterType(ParameterType):
@property
def accepts_choice(self) -> bool:
return False
Expand All @@ -149,44 +147,54 @@ def accepts_choice(self) -> bool:
def accepts_multiple(self) -> bool:
return False

def validate_default(
self, value: typing.Optional[typing.Any], allow_empty: bool = True
) -> typing.Optional[typing.Any]:
if value is None:
return

if not isinstance(value, str):
raise InvalidParameterError("Default value for connection parameter type should be string.")
elif value == "":
raise ParameterValueError("Empty values are not accepted.")

def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]:
if not allow_empty and value == "":
raise ParameterValueError("Empty values are not accepted.")

if not isinstance(value, str):
raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})")

return value


class S3ConnectionType(ParameterType):
class PostgreSQLConnectionType(ConnectionParameterType):
@property
def spec_type(self) -> str:
return "s3"
return "postgresql"

@property
def expected_type(self) -> typing.Type:
return S3ConnectionType
return PostgreSQLConnectionType

def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]:
super().validate(value, allow_empty=allow_empty)
return workspace.postgresql_connection(value)


class S3ConnectionType(ConnectionParameterType):
@property
def accepts_choice(self) -> bool:
return False
def spec_type(self) -> str:
return "s3"

@property
def accepts_multiple(self) -> bool:
return False
def expected_type(self) -> typing.Type:
return S3ConnectionType

def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]:
if not allow_empty and value == "":
raise ParameterValueError("Empty values are not accepted.")

if not isinstance(value, str):
raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})")

return value
super().validate(value, allow_empty=allow_empty)
return workspace.s3_connection(value)


class GCSConnectionType(ParameterType):
class GCSConnectionType(ConnectionParameterType):
@property
def spec_type(self) -> str:
return "gcs"
Expand All @@ -195,25 +203,12 @@ def spec_type(self) -> str:
def expected_type(self) -> typing.Type:
return GCSConnectionType

@property
def accepts_choice(self) -> bool:
return False

@property
def accepts_multiple(self) -> bool:
return False

def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]:
if not allow_empty and value == "":
raise ParameterValueError("Empty values are not accepted.")
super().validate(value, allow_empty=allow_empty)
return workspace.gcs_connection(value)

if not isinstance(value, str):
raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})")

return value


class DHIS2ConnectionType(ParameterType):
class DHIS2ConnectionType(ConnectionParameterType):
@property
def spec_type(self) -> str:
return "dhis2"
Expand All @@ -222,25 +217,12 @@ def spec_type(self) -> str:
def expected_type(self) -> typing.Type:
return DHIS2ConnectionType

@property
def accepts_choice(self) -> bool:
return False

@property
def accepts_multiple(self) -> bool:
return False

def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]:
if not allow_empty and value == "":
raise ParameterValueError("Empty values are not accepted.")

if not isinstance(value, str):
raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})")

return value
super().validate(value, allow_empty=allow_empty)
return workspace.dhis2_connection(value)


class IASOConnectionType(ParameterType):
class IASOConnectionType(ConnectionParameterType):
@property
def spec_type(self) -> str:
return "iaso"
Expand All @@ -249,25 +231,12 @@ def spec_type(self) -> str:
def expected_type(self) -> typing.Type:
return IASOConnectionType

@property
def accepts_choice(self) -> bool:
return False

@property
def accepts_multiple(self) -> bool:
return False

def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]:
if not allow_empty and value == "":
raise ParameterValueError("Empty values are not accepted.")
super().validate(value, allow_empty=allow_empty)
return workspace.iaso_connection(value)

if not isinstance(value, str):
raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})")

return value


class CustomConnectionType(ParameterType):
class CustomConnectionType(ConnectionParameterType):
@property
def spec_type(self) -> str:
return "custom"
Expand All @@ -276,22 +245,9 @@ def spec_type(self) -> str:
def expected_type(self) -> typing.Type:
return str

@property
def accepts_choice(self) -> bool:
return False

@property
def accepts_multiple(self) -> bool:
return False

def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]:
if not allow_empty and value == "":
raise ParameterValueError("Empty values are not accepted.")

if not isinstance(value, str):
raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})")

return super().validate(value, allow_empty)
super().validate(value, allow_empty=allow_empty)
return workspace.postgresql_connection(value)


TYPES_BY_PYTHON_TYPE = {
Expand Down Expand Up @@ -425,9 +381,9 @@ def _validate_default(self, default: typing.Any, multiple: bool):
if not isinstance(default, list):
raise InvalidParameterError("Default values should be lists when using multiple=True")
for default_value in default:
self.type.validate(default_value, allow_empty=False)
self.type.validate_default(default_value, allow_empty=False)
else:
self.type.validate(default, allow_empty=False)
self.type.validate_default(default, allow_empty=False)
except ParameterValueError:
raise InvalidParameterError(f"The default value for {self.code} is not valid.")

Expand Down Expand Up @@ -522,19 +478,3 @@ def get_all_parameters(self):
return [self.parameter, *self.function.get_all_parameters()]

return [self.parameter]


def is_connection_parameter(param: Parameter):
return any(
[
isinstance(param.type, type)
for type in [
DHIS2ConnectionType,
PostgreSQLConnectionType,
IASOConnectionType,
S3ConnectionType,
GCSConnectionType,
CustomConnectionType,
]
]
)
8 changes: 2 additions & 6 deletions openhexa/sdk/pipelines/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@
FunctionWithParameter,
Parameter,
ParameterValueError,
is_connection_parameter,
)
from .task import PipelineWithTask
from .utils import get_local_workspace_config, get_connection_by_type
from .utils import get_local_workspace_config

logger = getLogger(__name__)

Expand Down Expand Up @@ -102,10 +101,7 @@ def run(self, config: typing.Dict[str, typing.Any]):
for parameter in self.parameters:
value = config.pop(parameter.code, None)
validated_value = parameter.validate(value)
if is_connection_parameter(parameter):
validated_config[parameter.code] = get_connection_by_type(parameter.type, validated_value)
else:
validated_config[parameter.code] = validated_value
validated_config[parameter.code] = validated_value

if len(config) > 0:
raise ParameterValueError(f"The provided config contains invalid key(s): {', '.join(list(config.keys()))}")
Expand Down
30 changes: 0 additions & 30 deletions openhexa/sdk/pipelines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,6 @@
import stringcase
import yaml

from openhexa.sdk.workspaces import workspace
from .parameter import (
DHIS2ConnectionType,
PostgreSQLConnectionType,
IASOConnectionType,
S3ConnectionType,
GCSConnectionType,
CustomConnectionType,
)


class LocalWorkspaceConfigError(Exception):
pass
Expand Down Expand Up @@ -158,23 +148,3 @@ def get_local_workspace_config(path: Path):
if key != "type":
env_vars[stringcase.constcase(f"{slug}_{key.lower()}")] = str(value)
return env_vars


def get_connection_by_type(type: any, identifier: str):
if isinstance(type, DHIS2ConnectionType):
return workspace.dhis2_connection(identifier)

if isinstance(type, PostgreSQLConnectionType):
return workspace.postgresql_connection(identifier)

if isinstance(type, IASOConnectionType):
return workspace.iaso_connection(identifier)

if isinstance(type, S3ConnectionType):
return workspace.s3_connection(identifier)

if isinstance(type, GCSConnectionType):
return workspace.gcs_connection(identifier)

if isinstance(type, CustomConnectionType):
return workspace.custom_connection(identifier)
Loading

0 comments on commit fbe3f00

Please sign in to comment.