Skip to content

Commit

Permalink
fix: fixes the default value to be of type list (#211)
Browse files Browse the repository at this point in the history
* fix: fixes the default value to be of type list
Closes HEXA-1037 : Default values for parameters set as "multiple" fails on run of the pipeline

---------

Co-authored-by: Quentin Gérôme <[email protected]>
  • Loading branch information
nazarfil and qgerome authored Oct 2, 2024
1 parent 0bfd5f7 commit 151e3af
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 93 deletions.
9 changes: 4 additions & 5 deletions openhexa/cli/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import os
import tempfile
import typing
from dataclasses import asdict
from importlib.metadata import version
from pathlib import Path
from zipfile import ZipFile
Expand All @@ -21,7 +20,7 @@

from openhexa.cli.settings import settings
from openhexa.sdk.pipelines import get_local_workspace_config
from openhexa.sdk.pipelines.runtime import get_pipeline_metadata
from openhexa.sdk.pipelines.runtime import get_pipeline
from openhexa.utils import create_requests_session, stringcase


Expand Down Expand Up @@ -195,7 +194,7 @@ def list_pipelines():
return data["pipelines"]["items"]


def get_pipeline(pipeline_code: str) -> dict[str, typing.Any]:
def get_pipeline_from_code(pipeline_code: str) -> dict[str, typing.Any]:
"""Get a single pipeline."""
if settings.current_workspace is None:
raise NoActiveWorkspaceError
Expand Down Expand Up @@ -543,7 +542,7 @@ def upload_pipeline(
raise NoActiveWorkspaceError

directory = pipeline_directory_path.absolute()
pipeline = get_pipeline_metadata(directory)
pipeline = get_pipeline(directory)
zip_file = generate_zip_file(directory)

if settings.debug:
Expand Down Expand Up @@ -574,7 +573,7 @@ def upload_pipeline(
"description": description,
"externalLink": link,
"zipfile": base64_content,
"parameters": [asdict(p) for p in pipeline.parameters],
"parameters": [p.to_dict() for p in pipeline.parameters],
"timeout": pipeline.timeout,
}
},
Expand Down
10 changes: 5 additions & 5 deletions openhexa/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@
download_pipeline_sourcecode,
ensure_is_pipeline_dir,
get_library_versions,
get_pipeline,
get_pipeline_from_code,
get_workspace,
list_pipelines,
run_pipeline,
upload_pipeline,
)
from openhexa.cli.settings import settings, setup_logging
from openhexa.sdk.pipelines.exceptions import PipelineNotFound
from openhexa.sdk.pipelines.runtime import get_pipeline_metadata
from openhexa.sdk.pipelines.runtime import get_pipeline


def validate_url(ctx, param, value):
Expand Down Expand Up @@ -283,7 +283,7 @@ def pipelines_push(
ensure_is_pipeline_dir(path)

try:
pipeline = get_pipeline_metadata(path)
pipeline = get_pipeline(path)
except PipelineNotFound:
_terminate(
f"❌ No function with openhexa.sdk pipeline decorator found in {click.style(path, bold=True)}.",
Expand All @@ -296,7 +296,7 @@ def pipelines_push(
if settings.debug:
click.echo(workspace_pipelines)

if get_pipeline(pipeline.code) is None:
if get_pipeline_from_code(pipeline.code) is None:
click.echo(
f"Pipeline {click.style(pipeline.code, bold=True)} does not exist in workspace {click.style(workspace, bold=True)}"
)
Expand Down Expand Up @@ -374,7 +374,7 @@ def pipelines_delete(code: str):
err=True,
)
else:
pipeline = get_pipeline(code)
pipeline = get_pipeline_from_code(code)
if pipeline is None:
_terminate(
f"❌ Pipeline {click.style(code, bold=True)} does not exist in workspace {click.style(settings.current_workspace, bold=True)}"
Expand Down
25 changes: 23 additions & 2 deletions openhexa/sdk/pipelines/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,19 @@ def validate(self, value: typing.Any) -> typing.Any:
else:
return self._validate_single(value)

def to_dict(self) -> dict[str, typing.Any]:
"""Return a dictionary representation of the Parameter instance."""
return {
"code": self.code,
"type": self.type.spec_type,
"name": self.name,
"choices": self.choices,
"help": self.help,
"default": self.default,
"required": self.required,
"multiple": self.multiple,
}

def _validate_single(self, value: typing.Any):
# Normalize empty values to None and handles default
normalized_value = self.type.normalize(value)
Expand Down Expand Up @@ -487,8 +500,16 @@ def _validate_default(self, default: typing.Any, multiple: bool):
except ParameterValueError:
raise InvalidParameterError(f"The default value for {self.code} is not valid.")

if self.choices is not None and default not in self.choices:
raise InvalidParameterError(f"The default value for {self.code} is not included in the provided choices.")
if self.choices is not None:
if isinstance(default, list):
if not all(d in self.choices for d in default):
raise InvalidParameterError(
f"The default list of values for {self.code} is not included in the provided choices."
)
elif default not in self.choices:
raise InvalidParameterError(
f"The default value for {self.code} is not included in the provided choices."
)

def parameter_spec(self) -> dict[str, typing.Any]:
"""Build specification for this parameter, to be provided to the OpenHEXA backend."""
Expand Down
11 changes: 11 additions & 0 deletions openhexa/sdk/pipelines/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,17 @@ def parameters_spec(self) -> list[dict[str, typing.Any]]:
"""Return the individual specifications of all the parameters of this pipeline."""
return [arg.parameter_spec() for arg in self.parameters]

def to_dict(self):
"""Return a dictionary representation of the pipeline."""
return {
"code": self.code,
"name": self.name,
"parameters": [p.to_dict() for p in self.parameters],
"timeout": self.timeout,
"function": self.function.__dict__ if self.function else None,
"tasks": [t.__dict__ for t in self.tasks],
}

def _get_available_tasks(self) -> list[Task]:
return [task for task in self.tasks if task.is_ready()]

Expand Down
71 changes: 19 additions & 52 deletions openhexa/sdk/pipelines/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,53 +14,19 @@

import requests

from openhexa.sdk.pipelines.exceptions import InvalidParameterError, PipelineNotFound
from openhexa.sdk.pipelines.parameter import TYPES_BY_PYTHON_TYPE
from openhexa.sdk.pipelines.utils import validate_pipeline_parameter_code
from openhexa.sdk.pipelines.exceptions import PipelineNotFound
from openhexa.sdk.pipelines.parameter import TYPES_BY_PYTHON_TYPE, Parameter

from .pipeline import Pipeline


@dataclass
class PipelineParameterSpecs:
"""Specification of a pipeline parameter."""

code: string
type: string
name: string
choices: list[typing.Union[str, int, float]]
help: string
default: typing.Any
required: bool = True
multiple: bool = False

def __post_init__(self):
"""Validate the parameter and set default values."""
if self.default and self.choices and self.default not in self.choices:
raise ValueError(f"Default value '{self.default}' not in choices {self.choices}")
validate_pipeline_parameter_code(self.code)
if self.required is None:
self.required = True
if self.multiple is None:
self.multiple = False


@dataclass
class Argument:
"""Argument of a decorator."""

name: string
types: list[typing.Any] = field(default_factory=list)


@dataclass
class PipelineSpecs:
"""Specification of a pipeline."""

code: string
name: string
timeout: int = None
parameters: list[PipelineParameterSpecs] = field(default_factory=list)
default_value: typing.Any = None


def import_pipeline(pipeline_dir_path: str):
Expand Down Expand Up @@ -124,10 +90,10 @@ def _get_decorator_arg_value(decorator, arg: Argument, index: int):
try:
return decorator.args[index].value
except IndexError:
return None
return arg.default_value


def _get_decorator_spec(decorator, args: tuple[Argument], key=None):
def _get_decorator_spec(decorator, args: tuple[Argument]):
d = {"name": decorator.func.id, "args": {}}

for i, arg in enumerate(args):
Expand All @@ -136,8 +102,8 @@ def _get_decorator_spec(decorator, args: tuple[Argument], key=None):
return d


def get_pipeline_metadata(pipeline_path: Path) -> PipelineSpecs:
"""Return the pipeline metadata from the pipeline code.
def get_pipeline(pipeline_path: Path) -> Pipeline:
"""Return the pipeline with metadata and parameters from the pipeline code.
Args:
pipeline_path (Path): Path to the pipeline directory
Expand All @@ -150,7 +116,7 @@ def get_pipeline_metadata(pipeline_path: Path) -> PipelineSpecs:
Returns
-------
typing.Tuple[PipelineSpecs, typing.List[PipelineParameterSpecs]]: A tuple containing the pipeline specs and the list of parameters specs.
Pipeline: The pipeline object with parameters and metadata.
"""
tree = ast.parse(open(Path(pipeline_path) / "pipeline.py").read())
pipeline = None
Expand All @@ -170,7 +136,7 @@ def get_pipeline_metadata(pipeline_path: Path) -> PipelineSpecs:
Argument("timeout", [ast.Constant]),
),
)
pipeline = PipelineSpecs(**pipeline_decorator_spec["args"])
pipelines_parameters = []
for parameter_decorator in _get_decorators_by_name(node, "parameter"):
param_decorator_spec = _get_decorator_spec(
parameter_decorator,
Expand All @@ -180,19 +146,20 @@ def get_pipeline_metadata(pipeline_path: Path) -> PipelineSpecs:
Argument("name", [ast.Constant]),
Argument("choices", [ast.List]),
Argument("help", [ast.Constant]),
Argument("default", [ast.Constant]),
Argument("required", [ast.Constant]),
Argument("multiple", [ast.Constant]),
Argument("default", [ast.Constant, ast.List]),
Argument("required", [ast.Constant], default_value=True),
Argument("multiple", [ast.Constant], default_value=False),
),
)
parameter_args = param_decorator_spec["args"]
try:
args = param_decorator_spec["args"]
inst = TYPES_BY_PYTHON_TYPE[args["type"]]()
args["type"] = inst.spec_type

pipeline.parameters.append(PipelineParameterSpecs(**args))
type_class = TYPES_BY_PYTHON_TYPE[parameter_args.pop("type")]()
except KeyError:
raise InvalidParameterError(f"Invalid parameter type {args['type']}")
raise ValueError(f"Unsupported parameter type: {parameter_args['type']}")
parameter = Parameter(type=type_class.expected_type, **parameter_args)
pipelines_parameters.append(parameter)

pipeline = Pipeline(parameters=pipelines_parameters, function=None, **pipeline_decorator_spec["args"])

if pipeline is None:
raise PipelineNotFound("No function with openhexa.sdk pipeline decorator found.")
Expand Down
1 change: 1 addition & 0 deletions openhexa/utils/stringcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Coming from https://github.com/okunishinishi/python-stringcase
"""

import re


Expand Down
Loading

0 comments on commit 151e3af

Please sign in to comment.