From 151e3af419d5ea142781ce306c6c5bad3dc36239 Mon Sep 17 00:00:00 2001 From: Nazar F Date: Wed, 2 Oct 2024 11:26:07 +0200 Subject: [PATCH] fix: fixes the default value to be of type list (#211) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: fixes the default value to be of type list Closes HEXA-1037 : Default values for parameters set as "multiple" fails on run of the pipeline --------- Co-authored-by: Quentin Gérôme --- openhexa/cli/api.py | 9 ++- openhexa/cli/cli.py | 10 ++-- openhexa/sdk/pipelines/parameter.py | 25 +++++++- openhexa/sdk/pipelines/pipeline.py | 11 ++++ openhexa/sdk/pipelines/runtime.py | 71 ++++++---------------- openhexa/utils/stringcase.py | 1 + tests/test_ast.py | 91 ++++++++++++++++++++--------- tests/test_parameter.py | 2 +- 8 files changed, 127 insertions(+), 93 deletions(-) diff --git a/openhexa/cli/api.py b/openhexa/cli/api.py index 24685a6..a140cda 100644 --- a/openhexa/cli/api.py +++ b/openhexa/cli/api.py @@ -8,7 +8,6 @@ import os import tempfile import typing -from dataclasses import asdict from importlib.metadata import version from pathlib import Path from zipfile import ZipFile @@ -21,7 +20,7 @@ from openhexa.cli.settings import settings from openhexa.sdk.pipelines import get_local_workspace_config -from openhexa.sdk.pipelines.runtime import get_pipeline_metadata +from openhexa.sdk.pipelines.runtime import get_pipeline from openhexa.utils import create_requests_session, stringcase @@ -195,7 +194,7 @@ def list_pipelines(): return data["pipelines"]["items"] -def get_pipeline(pipeline_code: str) -> dict[str, typing.Any]: +def get_pipeline_from_code(pipeline_code: str) -> dict[str, typing.Any]: """Get a single pipeline.""" if settings.current_workspace is None: raise NoActiveWorkspaceError @@ -543,7 +542,7 @@ def upload_pipeline( raise NoActiveWorkspaceError directory = pipeline_directory_path.absolute() - pipeline = get_pipeline_metadata(directory) + pipeline = get_pipeline(directory) zip_file = generate_zip_file(directory) if settings.debug: @@ -574,7 +573,7 @@ def upload_pipeline( "description": description, "externalLink": link, "zipfile": base64_content, - "parameters": [asdict(p) for p in pipeline.parameters], + "parameters": [p.to_dict() for p in pipeline.parameters], "timeout": pipeline.timeout, } }, diff --git a/openhexa/cli/cli.py b/openhexa/cli/cli.py index 6ff05bc..30cf788 100644 --- a/openhexa/cli/cli.py +++ b/openhexa/cli/cli.py @@ -21,7 +21,7 @@ download_pipeline_sourcecode, ensure_is_pipeline_dir, get_library_versions, - get_pipeline, + get_pipeline_from_code, get_workspace, list_pipelines, run_pipeline, @@ -29,7 +29,7 @@ ) from openhexa.cli.settings import settings, setup_logging from openhexa.sdk.pipelines.exceptions import PipelineNotFound -from openhexa.sdk.pipelines.runtime import get_pipeline_metadata +from openhexa.sdk.pipelines.runtime import get_pipeline def validate_url(ctx, param, value): @@ -283,7 +283,7 @@ def pipelines_push( ensure_is_pipeline_dir(path) try: - pipeline = get_pipeline_metadata(path) + pipeline = get_pipeline(path) except PipelineNotFound: _terminate( f"❌ No function with openhexa.sdk pipeline decorator found in {click.style(path, bold=True)}.", @@ -296,7 +296,7 @@ def pipelines_push( if settings.debug: click.echo(workspace_pipelines) - if get_pipeline(pipeline.code) is None: + if get_pipeline_from_code(pipeline.code) is None: click.echo( f"Pipeline {click.style(pipeline.code, bold=True)} does not exist in workspace {click.style(workspace, bold=True)}" ) @@ -374,7 +374,7 @@ def pipelines_delete(code: str): err=True, ) else: - pipeline = get_pipeline(code) + pipeline = get_pipeline_from_code(code) if pipeline is None: _terminate( f"❌ Pipeline {click.style(code, bold=True)} does not exist in workspace {click.style(settings.current_workspace, bold=True)}" diff --git a/openhexa/sdk/pipelines/parameter.py b/openhexa/sdk/pipelines/parameter.py index 203e217..f6451f2 100644 --- a/openhexa/sdk/pipelines/parameter.py +++ b/openhexa/sdk/pipelines/parameter.py @@ -429,6 +429,19 @@ def validate(self, value: typing.Any) -> typing.Any: else: return self._validate_single(value) + def to_dict(self) -> dict[str, typing.Any]: + """Return a dictionary representation of the Parameter instance.""" + return { + "code": self.code, + "type": self.type.spec_type, + "name": self.name, + "choices": self.choices, + "help": self.help, + "default": self.default, + "required": self.required, + "multiple": self.multiple, + } + def _validate_single(self, value: typing.Any): # Normalize empty values to None and handles default normalized_value = self.type.normalize(value) @@ -487,8 +500,16 @@ def _validate_default(self, default: typing.Any, multiple: bool): except ParameterValueError: raise InvalidParameterError(f"The default value for {self.code} is not valid.") - if self.choices is not None and default not in self.choices: - raise InvalidParameterError(f"The default value for {self.code} is not included in the provided choices.") + if self.choices is not None: + if isinstance(default, list): + if not all(d in self.choices for d in default): + raise InvalidParameterError( + f"The default list of values for {self.code} is not included in the provided choices." + ) + elif default not in self.choices: + raise InvalidParameterError( + f"The default value for {self.code} is not included in the provided choices." + ) def parameter_spec(self) -> dict[str, typing.Any]: """Build specification for this parameter, to be provided to the OpenHEXA backend.""" diff --git a/openhexa/sdk/pipelines/pipeline.py b/openhexa/sdk/pipelines/pipeline.py index b9b3a6c..1cacbb0 100644 --- a/openhexa/sdk/pipelines/pipeline.py +++ b/openhexa/sdk/pipelines/pipeline.py @@ -167,6 +167,17 @@ def parameters_spec(self) -> list[dict[str, typing.Any]]: """Return the individual specifications of all the parameters of this pipeline.""" return [arg.parameter_spec() for arg in self.parameters] + def to_dict(self): + """Return a dictionary representation of the pipeline.""" + return { + "code": self.code, + "name": self.name, + "parameters": [p.to_dict() for p in self.parameters], + "timeout": self.timeout, + "function": self.function.__dict__ if self.function else None, + "tasks": [t.__dict__ for t in self.tasks], + } + def _get_available_tasks(self) -> list[Task]: return [task for task in self.tasks if task.is_ready()] diff --git a/openhexa/sdk/pipelines/runtime.py b/openhexa/sdk/pipelines/runtime.py index 9b7d54e..e7d1706 100644 --- a/openhexa/sdk/pipelines/runtime.py +++ b/openhexa/sdk/pipelines/runtime.py @@ -14,53 +14,19 @@ import requests -from openhexa.sdk.pipelines.exceptions import InvalidParameterError, PipelineNotFound -from openhexa.sdk.pipelines.parameter import TYPES_BY_PYTHON_TYPE -from openhexa.sdk.pipelines.utils import validate_pipeline_parameter_code +from openhexa.sdk.pipelines.exceptions import PipelineNotFound +from openhexa.sdk.pipelines.parameter import TYPES_BY_PYTHON_TYPE, Parameter from .pipeline import Pipeline -@dataclass -class PipelineParameterSpecs: - """Specification of a pipeline parameter.""" - - code: string - type: string - name: string - choices: list[typing.Union[str, int, float]] - help: string - default: typing.Any - required: bool = True - multiple: bool = False - - def __post_init__(self): - """Validate the parameter and set default values.""" - if self.default and self.choices and self.default not in self.choices: - raise ValueError(f"Default value '{self.default}' not in choices {self.choices}") - validate_pipeline_parameter_code(self.code) - if self.required is None: - self.required = True - if self.multiple is None: - self.multiple = False - - @dataclass class Argument: """Argument of a decorator.""" name: string types: list[typing.Any] = field(default_factory=list) - - -@dataclass -class PipelineSpecs: - """Specification of a pipeline.""" - - code: string - name: string - timeout: int = None - parameters: list[PipelineParameterSpecs] = field(default_factory=list) + default_value: typing.Any = None def import_pipeline(pipeline_dir_path: str): @@ -124,10 +90,10 @@ def _get_decorator_arg_value(decorator, arg: Argument, index: int): try: return decorator.args[index].value except IndexError: - return None + return arg.default_value -def _get_decorator_spec(decorator, args: tuple[Argument], key=None): +def _get_decorator_spec(decorator, args: tuple[Argument]): d = {"name": decorator.func.id, "args": {}} for i, arg in enumerate(args): @@ -136,8 +102,8 @@ def _get_decorator_spec(decorator, args: tuple[Argument], key=None): return d -def get_pipeline_metadata(pipeline_path: Path) -> PipelineSpecs: - """Return the pipeline metadata from the pipeline code. +def get_pipeline(pipeline_path: Path) -> Pipeline: + """Return the pipeline with metadata and parameters from the pipeline code. Args: pipeline_path (Path): Path to the pipeline directory @@ -150,7 +116,7 @@ def get_pipeline_metadata(pipeline_path: Path) -> PipelineSpecs: Returns ------- - typing.Tuple[PipelineSpecs, typing.List[PipelineParameterSpecs]]: A tuple containing the pipeline specs and the list of parameters specs. + Pipeline: The pipeline object with parameters and metadata. """ tree = ast.parse(open(Path(pipeline_path) / "pipeline.py").read()) pipeline = None @@ -170,7 +136,7 @@ def get_pipeline_metadata(pipeline_path: Path) -> PipelineSpecs: Argument("timeout", [ast.Constant]), ), ) - pipeline = PipelineSpecs(**pipeline_decorator_spec["args"]) + pipelines_parameters = [] for parameter_decorator in _get_decorators_by_name(node, "parameter"): param_decorator_spec = _get_decorator_spec( parameter_decorator, @@ -180,19 +146,20 @@ def get_pipeline_metadata(pipeline_path: Path) -> PipelineSpecs: Argument("name", [ast.Constant]), Argument("choices", [ast.List]), Argument("help", [ast.Constant]), - Argument("default", [ast.Constant]), - Argument("required", [ast.Constant]), - Argument("multiple", [ast.Constant]), + Argument("default", [ast.Constant, ast.List]), + Argument("required", [ast.Constant], default_value=True), + Argument("multiple", [ast.Constant], default_value=False), ), ) + parameter_args = param_decorator_spec["args"] try: - args = param_decorator_spec["args"] - inst = TYPES_BY_PYTHON_TYPE[args["type"]]() - args["type"] = inst.spec_type - - pipeline.parameters.append(PipelineParameterSpecs(**args)) + type_class = TYPES_BY_PYTHON_TYPE[parameter_args.pop("type")]() except KeyError: - raise InvalidParameterError(f"Invalid parameter type {args['type']}") + raise ValueError(f"Unsupported parameter type: {parameter_args['type']}") + parameter = Parameter(type=type_class.expected_type, **parameter_args) + pipelines_parameters.append(parameter) + + pipeline = Pipeline(parameters=pipelines_parameters, function=None, **pipeline_decorator_spec["args"]) if pipeline is None: raise PipelineNotFound("No function with openhexa.sdk pipeline decorator found.") diff --git a/openhexa/utils/stringcase.py b/openhexa/utils/stringcase.py index e79a274..f2acdbe 100644 --- a/openhexa/utils/stringcase.py +++ b/openhexa/utils/stringcase.py @@ -2,6 +2,7 @@ Coming from https://github.com/okunishinishi/python-stringcase """ + import re diff --git a/tests/test_ast.py b/tests/test_ast.py index 6216f88..27c91aa 100644 --- a/tests/test_ast.py +++ b/tests/test_ast.py @@ -1,11 +1,10 @@ """Tests related to the parsing of the pipeline code.""" import tempfile -from dataclasses import asdict from unittest import TestCase -from openhexa.sdk.pipelines.exceptions import InvalidParameterError, PipelineNotFound -from openhexa.sdk.pipelines.runtime import get_pipeline_metadata +from openhexa.sdk.pipelines.exceptions import PipelineNotFound +from openhexa.sdk.pipelines.runtime import get_pipeline class AstTest(TestCase): @@ -18,7 +17,7 @@ def test_pipeline_not_found(self): f.write("print('hello')") with self.assertRaises(PipelineNotFound): - get_pipeline_metadata(tmpdirname) + get_pipeline(tmpdirname) def test_pipeline_no_parameters(self): """The file contains a @pipeline decorator but no parameters.""" @@ -35,9 +34,17 @@ def test_pipeline_no_parameters(self): ] ) ) - pipeline = get_pipeline_metadata(tmpdirname) + pipeline = get_pipeline(tmpdirname) self.assertEqual( - asdict(pipeline), {"code": "test", "name": "Test pipeline", "parameters": [], "timeout": None} + pipeline.to_dict(), + { + "code": "test", + "name": "Test pipeline", + "function": None, + "tasks": [], + "parameters": [], + "timeout": None, + }, ) def test_pipeline_with_args(self): @@ -55,9 +62,17 @@ def test_pipeline_with_args(self): ] ) ) - pipeline = get_pipeline_metadata(tmpdirname) + pipeline = get_pipeline(tmpdirname) self.assertEqual( - asdict(pipeline), {"code": "test", "name": "Test pipeline", "parameters": [], "timeout": None} + pipeline.to_dict(), + { + "code": "test", + "function": None, + "tasks": [], + "name": "Test pipeline", + "parameters": [], + "timeout": None, + }, ) def test_pipeline_with_invalid_parameter_args(self): @@ -77,7 +92,7 @@ def test_pipeline_with_invalid_parameter_args(self): ) ) with self.assertRaises(ValueError): - get_pipeline_metadata(tmpdirname) + get_pipeline(tmpdirname) def test_pipeline_with_invalid_pipeline_args(self): """The file contains a @pipeline decorator with invalid value.""" @@ -97,7 +112,7 @@ def test_pipeline_with_invalid_pipeline_args(self): ) ) with self.assertRaises(ValueError): - get_pipeline_metadata(tmpdirname) + get_pipeline(tmpdirname) def test_pipeline_with_int_param(self): """The file contains a @pipeline decorator and a @parameter decorator with an int.""" @@ -116,12 +131,14 @@ def test_pipeline_with_int_param(self): ] ) ) - pipeline = get_pipeline_metadata(tmpdirname) + pipeline = get_pipeline(tmpdirname) self.assertEqual( - asdict(pipeline), + pipeline.to_dict(), { "code": "test", "name": "Test pipeline", + "function": None, + "tasks": [], "parameters": [ { "choices": None, @@ -147,7 +164,7 @@ def test_pipeline_with_multiple_param(self): [ "from openhexa.sdk.pipelines import pipeline, parameter", "", - "@parameter('test_param', name='Test Param', type=int, default=42, help='Param help', multiple=True)", + "@parameter('test_param', name='Test Param', type=int, default=[42], help='Param help', multiple=True)", "@pipeline('test', 'Test pipeline')", "def test_pipeline():", " pass", @@ -155,12 +172,14 @@ def test_pipeline_with_multiple_param(self): ] ) ) - pipeline = get_pipeline_metadata(tmpdirname) + pipeline = get_pipeline(tmpdirname) self.assertEqual( - asdict(pipeline), + pipeline.to_dict(), { "code": "test", "name": "Test pipeline", + "function": None, + "tasks": [], "parameters": [ { "choices": None, @@ -168,7 +187,7 @@ def test_pipeline_with_multiple_param(self): "code": "test_param", "type": "int", "name": "Test Param", - "default": 42, + "default": [42], "help": "Param help", "required": True, } @@ -195,12 +214,14 @@ def test_pipeline_with_dataset(self): ] ) ) - pipeline = get_pipeline_metadata(tmpdirname) + pipeline = get_pipeline(tmpdirname) self.assertEqual( - asdict(pipeline), + pipeline.to_dict(), { "code": "test", + "function": None, "name": "Test pipeline", + "tasks": [], "parameters": [ { "choices": None, @@ -234,12 +255,14 @@ def test_pipeline_with_choices(self): ] ) ) - pipeline = get_pipeline_metadata(tmpdirname) + pipeline = get_pipeline(tmpdirname) self.assertEqual( - asdict(pipeline), + pipeline.to_dict(), { "code": "test", "name": "Test pipeline", + "function": None, + "tasks": [], "parameters": [ { "choices": ["a", "b"], @@ -271,9 +294,17 @@ def test_pipeline_with_timeout(self): ] ) ) - pipeline = get_pipeline_metadata(tmpdirname) + pipeline = get_pipeline(tmpdirname) self.assertEqual( - asdict(pipeline), {"code": "test", "name": "Test pipeline", "parameters": [], "timeout": 42} + pipeline.to_dict(), + { + "code": "test", + "name": "Test pipeline", + "parameters": [], + "timeout": 42, + "function": None, + "tasks": [], + }, ) def test_pipeline_with_bool(self): @@ -293,12 +324,14 @@ def test_pipeline_with_bool(self): ] ) ) - pipeline = get_pipeline_metadata(tmpdirname) + pipeline = get_pipeline(tmpdirname) self.assertEqual( - asdict(pipeline), + pipeline.to_dict(), { "code": "test", "name": "Test pipeline", + "function": None, + "tasks": [], "parameters": [ { "choices": None, @@ -333,12 +366,14 @@ def test_pipeline_with_multiple_parameters(self): ] ) ) - pipeline = get_pipeline_metadata(tmpdirname) + pipeline = get_pipeline(tmpdirname) self.assertEqual( - asdict(pipeline), + pipeline.to_dict(), { "code": "test", "name": "Test pipeline", + "function": None, + "tasks": [], "parameters": [ { "choices": None, @@ -382,5 +417,5 @@ def test_pipeline_with_unsupported_parameter(self): ] ) ) - with self.assertRaises(InvalidParameterError): - get_pipeline_metadata(tmpdirname) + with self.assertRaises(KeyError): + get_pipeline(tmpdirname) diff --git a/tests/test_parameter.py b/tests/test_parameter.py index afc3d09..8887f46 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -307,7 +307,7 @@ def test_parameter_validate_multiple(): assert parameter_3.validate([]) == [] # choices - parameter_4 = Parameter("arg4", type=str, choices=["ab", "cd"], multiple=True) + parameter_4 = Parameter("arg4", type=str, default=["ab", "ef"], choices=["ab", "cd", "ef"], multiple=True) assert parameter_4.validate(["ab"]) == ["ab"] assert parameter_4.validate(["ab", "cd"]) == ["ab", "cd"] with pytest.raises(ParameterValueError):