diff --git a/examples/pipelines/logistic_stats/pipeline.py b/examples/pipelines/logistic_stats/pipeline.py index b882161..0faf346 100644 --- a/examples/pipelines/logistic_stats/pipeline.py +++ b/examples/pipelines/logistic_stats/pipeline.py @@ -113,9 +113,9 @@ def model(dhis2_data: dict[str, typing.Any], gadm_data, worldpop_data): population_df = pd.Series(data=[stat["sum"] for stat in stats], index=administrative_areas.index) population_df = pd.DataFrame({"District": administrative_areas["NAME_2"], "Population": population_df}) corrected_population_df = population_df.copy() - corrected_population_df.loc[ - corrected_population_df["District"].str.startswith("Western"), "District" - ] = "Western Area" + corrected_population_df.loc[corrected_population_df["District"].str.startswith("Western"), "District"] = ( + "Western Area" + ) corrected_population_df = corrected_population_df.groupby("District").sum() diff --git a/openhexa/sdk/pipelines/parameter.py b/openhexa/sdk/pipelines/parameter.py index 203e217..a913b37 100644 --- a/openhexa/sdk/pipelines/parameter.py +++ b/openhexa/sdk/pipelines/parameter.py @@ -487,7 +487,12 @@ def _validate_default(self, default: typing.Any, multiple: bool): except ParameterValueError: raise InvalidParameterError(f"The default value for {self.code} is not valid.") - if self.choices is not None and default not in self.choices: + if self.choices is not None: + if isinstance(default, list): + if not all(d in self.choices for d in default): + raise InvalidParameterError( + f"The default list of values for {self.code} is not included in the provided choices." + ) raise InvalidParameterError(f"The default value for {self.code} is not included in the provided choices.") def parameter_spec(self) -> dict[str, typing.Any]: diff --git a/openhexa/sdk/pipelines/runtime.py b/openhexa/sdk/pipelines/runtime.py index 9b7d54e..ee70a3c 100644 --- a/openhexa/sdk/pipelines/runtime.py +++ b/openhexa/sdk/pipelines/runtime.py @@ -36,8 +36,12 @@ class PipelineParameterSpecs: def __post_init__(self): """Validate the parameter and set default values.""" - if self.default and self.choices and self.default not in self.choices: - raise ValueError(f"Default value '{self.default}' not in choices {self.choices}") + if self.default and self.choices: + if isinstance(self.default, list): + if not all(d in self.choices for d in self.default): + raise ValueError(f"Default list of values {self.default} not in choices {self.choices}") + elif self.default not in self.choices: + raise ValueError(f"Default value '{self.default}' not in choices {self.choices}") validate_pipeline_parameter_code(self.code) if self.required is None: self.required = True @@ -180,7 +184,7 @@ def get_pipeline_metadata(pipeline_path: Path) -> PipelineSpecs: Argument("name", [ast.Constant]), Argument("choices", [ast.List]), Argument("help", [ast.Constant]), - Argument("default", [ast.Constant]), + Argument("default", [ast.Constant, ast.List]), Argument("required", [ast.Constant]), Argument("multiple", [ast.Constant]), ), diff --git a/openhexa/utils/stringcase.py b/openhexa/utils/stringcase.py index e79a274..f2acdbe 100644 --- a/openhexa/utils/stringcase.py +++ b/openhexa/utils/stringcase.py @@ -2,6 +2,7 @@ Coming from https://github.com/okunishinishi/python-stringcase """ + import re