Skip to content

Commit

Permalink
fix: fixes the default value to be of type list
Browse files Browse the repository at this point in the history
Closes HEXA-1037 : Default values for parameters set as "multiple" fails on run of the pipeline
  • Loading branch information
nazarfil committed Sep 24, 2024
1 parent 80a73be commit f8c7a4f
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 7 deletions.
6 changes: 3 additions & 3 deletions examples/pipelines/logistic_stats/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@ def model(dhis2_data: dict[str, typing.Any], gadm_data, worldpop_data):
population_df = pd.Series(data=[stat["sum"] for stat in stats], index=administrative_areas.index)
population_df = pd.DataFrame({"District": administrative_areas["NAME_2"], "Population": population_df})
corrected_population_df = population_df.copy()
corrected_population_df.loc[
corrected_population_df["District"].str.startswith("Western"), "District"
] = "Western Area"
corrected_population_df.loc[corrected_population_df["District"].str.startswith("Western"), "District"] = (
"Western Area"
)

corrected_population_df = corrected_population_df.groupby("District").sum()

Expand Down
7 changes: 6 additions & 1 deletion openhexa/sdk/pipelines/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,12 @@ def _validate_default(self, default: typing.Any, multiple: bool):
except ParameterValueError:
raise InvalidParameterError(f"The default value for {self.code} is not valid.")

if self.choices is not None and default not in self.choices:
if self.choices is not None:
if isinstance(default, list):
if not all(d in self.choices for d in default):
raise InvalidParameterError(
f"The default list of values for {self.code} is not included in the provided choices."
)
raise InvalidParameterError(f"The default value for {self.code} is not included in the provided choices.")

def parameter_spec(self) -> dict[str, typing.Any]:
Expand Down
10 changes: 7 additions & 3 deletions openhexa/sdk/pipelines/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,12 @@ class PipelineParameterSpecs:

def __post_init__(self):
"""Validate the parameter and set default values."""
if self.default and self.choices and self.default not in self.choices:
raise ValueError(f"Default value '{self.default}' not in choices {self.choices}")
if self.default and self.choices:
if isinstance(self.default, list):
if not all(d in self.choices for d in self.default):
raise ValueError(f"Default list of values {self.default} not in choices {self.choices}")
elif self.default not in self.choices:
raise ValueError(f"Default value '{self.default}' not in choices {self.choices}")
validate_pipeline_parameter_code(self.code)
if self.required is None:
self.required = True
Expand Down Expand Up @@ -180,7 +184,7 @@ def get_pipeline_metadata(pipeline_path: Path) -> PipelineSpecs:
Argument("name", [ast.Constant]),
Argument("choices", [ast.List]),
Argument("help", [ast.Constant]),
Argument("default", [ast.Constant]),
Argument("default", [ast.Constant, ast.List]),
Argument("required", [ast.Constant]),
Argument("multiple", [ast.Constant]),
),
Expand Down
1 change: 1 addition & 0 deletions openhexa/utils/stringcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Coming from https://github.com/okunishinishi/python-stringcase
"""

import re


Expand Down

0 comments on commit f8c7a4f

Please sign in to comment.