Skip to content

Commit

Permalink
feat : inject connection instance instead of identifier
Browse files Browse the repository at this point in the history
  • Loading branch information
cheikhgwane committed Nov 15, 2023
1 parent 9324b65 commit 34e617c
Show file tree
Hide file tree
Showing 4 changed files with 286 additions and 21 deletions.
120 changes: 104 additions & 16 deletions openhexa/sdk/pipelines/parameter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
import re
import typing
from openhexa.sdk.workspaces.connection import (
DHIS2Connection,
IASOConnection,
PostgreSQLConnection,
S3Connection,
GCSConnection,
)


class ParameterValueError(Exception):
Expand Down Expand Up @@ -132,13 +139,24 @@ def spec_type(self) -> str:

@property
def expected_type(self) -> typing.Type:
return str
return PostgreSQLConnectionType

@property
def accepts_choice(self) -> bool:
return False

@property
def accepts_multiple(self) -> bool:
return False

def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]:
if not allow_empty and value == "":
raise ParameterValueError("Empty values are not accepted.")

return super().validate(value, allow_empty)
if not isinstance(value, str):
raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})")

return value


class S3ConnectionType(ParameterType):
Expand All @@ -148,13 +166,24 @@ def spec_type(self) -> str:

@property
def expected_type(self) -> typing.Type:
return str
return S3ConnectionType

@property
def accepts_choice(self) -> bool:
return False

@property
def accepts_multiple(self) -> bool:
return False

def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]:
if not allow_empty and value == "":
raise ParameterValueError("Empty values are not accepted.")

return super().validate(value, allow_empty)
if not isinstance(value, str):
raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})")

return value


class GCSConnectionType(ParameterType):
Expand All @@ -164,13 +193,24 @@ def spec_type(self) -> str:

@property
def expected_type(self) -> typing.Type:
return str
return GCSConnectionType

@property
def accepts_choice(self) -> bool:
return False

@property
def accepts_multiple(self) -> bool:
return False

def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]:
if not allow_empty and value == "":
raise ParameterValueError("Empty values are not accepted.")

return super().validate(value, allow_empty)
if not isinstance(value, str):
raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})")

return value


class DHIS2ConnectionType(ParameterType):
Expand All @@ -180,13 +220,24 @@ def spec_type(self) -> str:

@property
def expected_type(self) -> typing.Type:
return str
return DHIS2ConnectionType

@property
def accepts_choice(self) -> bool:
return False

@property
def accepts_multiple(self) -> bool:
return False

def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]:
if not allow_empty and value == "":
raise ParameterValueError("Empty values are not accepted.")

return super().validate(value, allow_empty)
if not isinstance(value, str):
raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})")

return value


class IASOConnectionType(ParameterType):
Expand All @@ -196,13 +247,24 @@ def spec_type(self) -> str:

@property
def expected_type(self) -> typing.Type:
return str
return IASOConnectionType

@property
def accepts_choice(self) -> bool:
return False

@property
def accepts_multiple(self) -> bool:
return False

def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]:
if not allow_empty and value == "":
raise ParameterValueError("Empty values are not accepted.")

return super().validate(value, allow_empty)
if not isinstance(value, str):
raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})")

return value


class CustomConnectionType(ParameterType):
Expand All @@ -214,10 +276,21 @@ def spec_type(self) -> str:
def expected_type(self) -> typing.Type:
return str

@property
def accepts_choice(self) -> bool:
return False

@property
def accepts_multiple(self) -> bool:
return False

def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = True) -> typing.Optional[str]:
if not allow_empty and value == "":
raise ParameterValueError("Empty values are not accepted.")

if not isinstance(value, str):
raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})")

return super().validate(value, allow_empty)


Expand All @@ -226,12 +299,11 @@ def validate(self, value: typing.Optional[typing.Any], *, allow_empty: bool = Tr
bool: Boolean,
int: Integer,
float: Float,
"dhis2": DHIS2ConnectionType,
"postgresql": PostgreSQLConnectionType,
"iaso": IASOConnectionType,
"s3": S3ConnectionType,
"gcs": GCSConnectionType,
"custom": CustomConnectionType,
DHIS2Connection: DHIS2ConnectionType,
PostgreSQLConnection: PostgreSQLConnectionType,
IASOConnection: IASOConnectionType,
S3Connection: S3ConnectionType,
GCSConnection: GCSConnectionType,
}


Expand Down Expand Up @@ -450,3 +522,19 @@ def get_all_parameters(self):
return [self.parameter, *self.function.get_all_parameters()]

return [self.parameter]


def is_connection_parameter(param: Parameter):
return any(
[
isinstance(param.type, type)
for type in [
DHIS2ConnectionType,
PostgreSQLConnectionType,
IASOConnectionType,
S3ConnectionType,
GCSConnectionType,
CustomConnectionType,
]
]
)
14 changes: 11 additions & 3 deletions openhexa/sdk/pipelines/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,14 @@

from openhexa.sdk.utils import Environments, get_environment

from .parameter import FunctionWithParameter, Parameter, ParameterValueError
from .parameter import (
FunctionWithParameter,
Parameter,
ParameterValueError,
is_connection_parameter,
)
from .task import PipelineWithTask
from .utils import get_local_workspace_config
from .utils import get_local_workspace_config, get_connection_by_type

logger = getLogger(__name__)

Expand Down Expand Up @@ -97,7 +102,10 @@ def run(self, config: typing.Dict[str, typing.Any]):
for parameter in self.parameters:
value = config.pop(parameter.code, None)
validated_value = parameter.validate(value)
validated_config[parameter.code] = validated_value
if is_connection_parameter(parameter):
validated_config[parameter.code] = get_connection_by_type(parameter.type, validated_value)
else:
validated_config[parameter.code] = validated_value

if len(config) > 0:
raise ParameterValueError(f"The provided config contains invalid key(s): {', '.join(list(config.keys()))}")
Expand Down
30 changes: 30 additions & 0 deletions openhexa/sdk/pipelines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@
import stringcase
import yaml

from openhexa.sdk.workspaces import workspace
from .parameter import (
DHIS2ConnectionType,
PostgreSQLConnectionType,
IASOConnectionType,
S3ConnectionType,
GCSConnectionType,
CustomConnectionType,
)


class LocalWorkspaceConfigError(Exception):
pass
Expand Down Expand Up @@ -148,3 +158,23 @@ def get_local_workspace_config(path: Path):
if key != "type":
env_vars[stringcase.constcase(f"{slug}_{key.lower()}")] = str(value)
return env_vars


def get_connection_by_type(type: any, identifier: str):
if isinstance(type, DHIS2ConnectionType):
return workspace.dhis2_connection(identifier)

if isinstance(type, PostgreSQLConnectionType):
return workspace.postgresql_connection(identifier)

if isinstance(type, IASOConnectionType):
return workspace.iaso_connection(identifier)

if isinstance(type, S3ConnectionType):
return workspace.s3_connection(identifier)

if isinstance(type, GCSConnectionType):
return workspace.gcs_connection(identifier)

if isinstance(type, CustomConnectionType):
return workspace.custom_connection(identifier)
Loading

0 comments on commit 34e617c

Please sign in to comment.