diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 08a6c71..9384f28 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -23,7 +23,7 @@ jobs:
       - name: Setup Python
         uses: actions/setup-python@v5
         with:
-          python-version: 3.9
+          python-version: 3.11
           cache: "pip"
           cache-dependency-path: pyproject.toml
@@ -47,13 +47,13 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.9", "3.10", "3.11", "3.12"]
+        python-version: ["3.10", "3.11", "3.12"]
         pastas-version: [
           "git+https://github.com/pastas/pastas.git@v0.22.0",
           "git+https://github.com/pastas/pastas.git@v0.23.1",
           "git+https://github.com/pastas/pastas.git@v1.0.1",
-          "git+https://github.com/pastas/pastas.git@v1.5.0",
+          "git+https://github.com/pastas/pastas.git@v1.6.0",
           "git+https://github.com/pastas/pastas.git@dev",
         ]
         exclude:
@@ -78,7 +78,6 @@ jobs:
         if: ${{ matrix.python-version != '3.12'}}
         run: |
           pip install --upgrade pip
-          pip install numpy
           pip install ${{ matrix.pastas-version }}
           pip install -e .[test]
@@ -86,8 +85,6 @@ jobs:
         if: ${{ matrix.python-version == '3.12'}}
         run: |
           pip install --upgrade pip
-          # TODO: remove numpy pin when numba or ? doesn't crash on NaN being deprecated
-          pip install "numpy<2.0"
           pip install ${{ matrix.pastas-version }}
           pip install -e .[test_py312]
diff --git a/docs/pstore.rst b/docs/pstore.rst
index ca0d65b..36d9c19 100644
--- a/docs/pstore.rst
+++ b/docs/pstore.rst
@@ -33,4 +33,4 @@ Bulk operations are also provided for:
 * Optimizing pastas Models and storing the results::

     # solve models and store result in database
-    pstore.solve_models(ignore_solver_errors=True, store_result=True)
+    pstore.solve_models(ignore_solve_errors=True)
diff --git a/examples/notebooks/ex01_introduction_to_pastastore_databases.ipynb b/examples/notebooks/ex01_introduction_to_pastastore_databases.ipynb
index 4fee571..0957bcd 100644
--- a/examples/notebooks/ex01_introduction_to_pastastore_databases.ipynb
+++ b/examples/notebooks/ex01_introduction_to_pastastore_databases.ipynb
@@ -1363,7 +1363,7 @@
     }
    ],
    "source": [
-    "pstore.solve_models(store_result=True, report=False)"
+    "pstore.solve_models(report=False)"
    ]
   },
   {
diff --git a/pastastore/base.py b/pastastore/base.py
index 1e71486..21a2f72 100644
--- a/pastastore/base.py
+++ b/pastastore/base.py
@@ -56,6 +56,11 @@ def __repr__(self):
             f"{self.n_models} models"
         )

+    @property
+    def empty(self):
+        """Check if the database is empty."""
+        return not any([self.n_oseries > 0, self.n_stresses > 0, self.n_models > 0])
+
     @abstractmethod
     def _get_library(self, libname: str):
         """Get library handle.
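The new ``empty`` property on the base connector (mirrored on ``PastaStore`` further below) enables a quick sanity check before running bulk operations. A minimal sketch, assuming the current ``PastaStore(connector)`` call signature; the connector name is arbitrary::

    import pastastore as pst

    # an in-memory connector starts out with no oseries, stresses or models
    conn = pst.DictConnector("my_connector")
    pstore = pst.PastaStore(conn)

    assert pstore.empty  # True: n_oseries, n_stresses and n_models are all 0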
diff --git a/pastastore/extensions/hpd.py b/pastastore/extensions/hpd.py
index 4617416..77a282f 100644
--- a/pastastore/extensions/hpd.py
+++ b/pastastore/extensions/hpd.py
@@ -12,7 +12,7 @@
 import hydropandas as hpd
 import numpy as np
-from hydropandas.io.knmi import get_stations
+from hydropandas.io.knmi import _check_latest_measurement_date_de_bilt, get_stations
 from pandas import DataFrame, Series, Timedelta, Timestamp
 from pastas.timeseries_utils import timestep_weighted_resample
 from tqdm.auto import tqdm
@@ -179,7 +179,7 @@ def add_observation(
         action_msg = "added to"

     if libname == "oseries":
-        self._store.upsert_oseries(o.squeeze(), name, metadata=metadata)
+        self._store.upsert_oseries(o.squeeze(axis=1), name, metadata=metadata)
         logger.info(
             "%sobservation '%s' %s oseries library.", source, name, action_msg
         )
@@ -187,7 +187,7 @@
         if kind is None:
             raise ValueError("`kind` must be specified for stresses!")
         self._store.upsert_stress(
-            (o * unit_multiplier).squeeze(), name, kind, metadata=metadata
+            (o * unit_multiplier).squeeze(axis=1), name, kind, metadata=metadata
         )
         logger.info(
             "%sstress '%s' (kind='%s') %s stresses library.",
@@ -394,24 +394,32 @@ def update_knmi_meteo(
     tmintmax = self._store.get_tmin_tmax("stresses", names=names)

     if tmax is not None:
-        if tmintmax["tmax"].min() > Timestamp(tmax):
-            logger.info(f"All KNMI stresses are up to date to {tmax}.")
+        if tmintmax["tmax"].min() >= Timestamp(tmax):
+            logger.info(f"All KNMI stresses are up to date through {tmax}.")
             return

-    # NOTE: this check is very flaky (15 august 2024), perhaps I annoyed the
-    # KNMI server... Trying to skip this check and just attempt downloading data.
-    # maxtmax_rd = _check_latest_measurement_date_de_bilt("RD")
-    # maxtmax_ev24 = _check_latest_measurement_date_de_bilt("EV24")
-    maxtmax = Timestamp.today() - Timedelta(days=1)
+    try:
+        maxtmax_rd = _check_latest_measurement_date_de_bilt("RD")
+        maxtmax_ev24 = _check_latest_measurement_date_de_bilt("EV24")
+    except Exception as e:
+        # fall back to 28 days (4 weeks) prior to today
+        logger.warning(
+            "Could not check latest measurement date in De Bilt: %s" % str(e)
+        )
+        maxtmax_rd = maxtmax_ev24 = Timestamp.today() - Timedelta(days=28)
+        logger.info(
+            "Using 28 days (4 weeks) prior to today as maxtmax: %s."
+            % str(maxtmax_rd)
+        )

     for name in tqdm(names, desc="Updating KNMI meteo stresses"):
         meteo_var = self._store.stresses.loc[name, "meteo_var"]
-        # if meteo_var == "RD":
-        #     maxtmax = maxtmax_rd
-        # elif meteo_var == "EV24":
-        #     maxtmax = maxtmax_ev24
-        # else:
-        #     maxtmax = maxtmax_rd
+        if meteo_var == "RD":
+            maxtmax = maxtmax_rd
+        elif meteo_var == "EV24":
+            maxtmax = maxtmax_ev24
+        else:
+            maxtmax = maxtmax_rd

         # 1 days extra to ensure computation of daily totals using
         # timestep_weighted_resample
@@ -421,7 +429,7 @@
             itmin = tmin - Timedelta(days=1)

         # ensure 2 observations at least
-        if itmin >= (maxtmax + Timedelta(days=1)):
+        if itmin >= (maxtmax - Timedelta(days=1)):
             logger.debug("KNMI %s is already up to date." % name)
             continue
@@ -430,20 +438,29 @@
         else:
             itmax = Timestamp(tmax)

+        # fix for duplicate station entry in metadata:
+        stress_station = (
+            self._store.stresses.at[name, "station"]
+            if "station" in self._store.stresses.columns
+            else None
+        )
+        if stress_station is not None and not isinstance(
+            stress_station, (int, np.integer)
+        ):
+            stress_station = stress_station.squeeze().unique().item()
+
         unit = self._store.stresses.loc[name, "unit"]
         kind = self._store.stresses.loc[name, "kind"]
-        if "station" in self._store.stresses.columns and ~np.isnan(
-            self._store.stresses.loc[name, "station"]
-        ):
-            stn = self._store.stresses.loc[name, "station"]
+        if stress_station is not None:
+            stn = stress_station
         else:
             stns = get_stations(meteo_var)
             stn_name = name.split("_")[-1].lower()
             mask = stns["name"].str.lower().str.replace(" ", "-") == stn_name
             if not mask.any():
                 logger.warning(
-                    f"Station '%s' not found in list of KNMI {meteo_var} stations."
-                    % stn_name
+                    "Station '%s' not found in list of KNMI %s stations."
+                    % (stn_name, meteo_var)
                 )
                 continue
             stn = stns.loc[mask].index[0]
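The try/except added above prefers the actual latest measurement date reported for De Bilt and only falls back to a conservative 28-day lag when the KNMI server cannot be queried. The same pattern in isolation (``fetch_latest`` is a hypothetical stand-in for ``_check_latest_measurement_date_de_bilt``)::

    from pandas import Timedelta, Timestamp

    def latest_safe_tmax(fetch_latest, fallback_days=28):
        """Return the latest measurement date, or a conservative fallback."""
        try:
            return fetch_latest()
        except Exception:
            # server unreachable or response format changed: assume data is
            # available up to `fallback_days` before today
            return Timestamp.today() - Timedelta(days=fallback_days)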
diff --git a/pastastore/store.py b/pastastore/store.py
index 97c71a3..b2a2954 100644
--- a/pastastore/store.py
+++ b/pastastore/store.py
@@ -4,6 +4,8 @@
 import logging
 import os
 import warnings
+from concurrent.futures import ProcessPoolExecutor
+from functools import partial
 from typing import Dict, List, Literal, Optional, Tuple, Union

 import numpy as np
@@ -12,6 +14,7 @@
 from packaging.version import parse as parse_version
 from pastas.io.pas import pastas_hook
 from tqdm.auto import tqdm
+from tqdm.contrib.concurrent import process_map

 from pastastore.base import BaseConnector
 from pastastore.connectors import DictConnector
@@ -78,6 +81,11 @@ def __init__(
         self.plots = Plots(self)
         self.yaml = PastastoreYAML(self)

+    @property
+    def empty(self) -> bool:
+        """Check if the PastaStore is empty."""
+        return self.conn.empty
+
     def _register_connector_methods(self):
         """Register connector methods (internal method)."""
         methods = [
@@ -1175,18 +1183,19 @@ def add_stressmodel(

     def solve_models(
         self,
-        mls: Optional[Union[ps.Model, list, str]] = None,
+        modelnames: Union[List[str], str, None] = None,
         report: bool = False,
         ignore_solve_errors: bool = False,
-        store_result: bool = True,
         progressbar: bool = True,
+        parallel: bool = False,
+        max_workers: Optional[int] = None,
         **kwargs,
     ) -> None:
         """Solves the models in the store.

         Parameters
         ----------
-        mls : list of str, optional
+        modelnames : list of str, optional
             list of model names, if None all models in the pastastore
             are solved.
         report : boolean, optional
@@ -1196,43 +1205,105 @@
             if True, errors emerging from the solve method are ignored,
             default is False which will raise an exception when a model
             cannot be optimized
-        store_result : bool, optional
-            if True save optimized models, default is True
         progressbar : bool, optional
-            show progressbar, default is True
-        **kwargs :
+            show progressbar, default is True.
+        parallel : bool, optional
+            if True, solve models in parallel using ProcessPoolExecutor
+        max_workers : int, optional
+            maximum number of workers to use in parallel solving, default is
+            None, which uses the number of cores available on the machine
+        **kwargs : dictionary
             arguments are passed to the solve method.
+
+        Notes
+        -----
+        Parallel solving is platform-dependent and may not always work. The
+        current implementation works well on Linux. On Windows, parallel
+        solving does not work when called directly from Jupyter Notebooks or
+        IPython. To use parallel solving on Windows, call it from a Python
+        script using the following pattern::
+
+            from multiprocessing import freeze_support
+
+            if __name__ == "__main__":
+                freeze_support()
+                pstore.solve_models(parallel=True)
         """
-        if mls is None:
-            mls = self.conn.model_names
-        elif isinstance(mls, ps.Model):
-            mls = [mls.name]
+        if "mls" in kwargs:
+            modelnames = kwargs.pop("mls")
+            logger.warning("Argument `mls` is deprecated, use `modelnames` instead.")

-        desc = "Solving models"
-        for ml_name in tqdm(mls, desc=desc) if progressbar else mls:
-            ml = self.conn.get_models(ml_name)
+        modelnames = self.conn._parse_names(modelnames, libname="models")

-            m_kwargs = {}
-            for key, value in kwargs.items():
-                if isinstance(value, pd.Series):
-                    m_kwargs[key] = value.loc[ml_name]
-                else:
-                    m_kwargs[key] = value
-            # Convert timestamps
-            for tstamp in ["tmin", "tmax"]:
-                if tstamp in m_kwargs:
-                    m_kwargs[tstamp] = pd.Timestamp(m_kwargs[tstamp])
+        solve_model = partial(
+            self._solve_model,
+            report=report,
+            ignore_solve_errors=ignore_solve_errors,
+            **kwargs,
+        )
+        if parallel and self.conn.conn_type != "pas":
+            parallel = False
+            logger.error(
+                "Parallel solving is only supported for PasConnector databases. "
+                "Setting parallel to False."
+            )

-            try:
-                ml.solve(report=report, **m_kwargs)
-                if store_result:
-                    self.conn.add_model(ml, overwrite=True)
-            except Exception as e:
-                if ignore_solve_errors:
-                    warning = "solve error ignored for -> {}".format(ml.name)
-                    ps.logger.warning(warning)
-                else:
-                    raise e
+        if parallel and progressbar:
+            process_map(solve_model, modelnames, max_workers=max_workers)
+        elif parallel and not progressbar:
+            with ProcessPoolExecutor(max_workers=max_workers) as executor:
+                executor.map(solve_model, modelnames)
+        else:
+            for ml_name in (
+                tqdm(modelnames, desc="Solving models") if progressbar else modelnames
+            ):
+                solve_model(ml_name=ml_name)
+
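Example calls against the reworked signature (``pstore`` and the model names are assumed for illustration; parallel solving requires a ``PasConnector`` database)::

    # solve a subset of models sequentially
    pstore.solve_models(modelnames=["model_a", "model_b"], report=False)

    # solve all models in parallel, ignoring models that fail to optimize
    pstore.solve_models(parallel=True, max_workers=4, ignore_solve_errors=True)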
+    def _solve_model(
+        self,
+        ml_name: str,
+        report: bool = False,
+        ignore_solve_errors: bool = False,
+        **kwargs,
+    ) -> None:
+        """Solve a model in the store (internal method).
+
+        Parameters
+        ----------
+        ml_name : str
+            name of a model in the pastastore
+        report : boolean, optional
+            determines if a report is printed when the model is solved,
+            default is False
+        ignore_solve_errors : boolean, optional
+            if True, errors emerging from the solve method are ignored,
+            default is False which will raise an exception when a model
+            cannot be optimized
+        **kwargs : dictionary
+            arguments are passed to the solve method.
+        """
+        ml = self.conn.get_models(ml_name)
+        m_kwargs = {}
+        for key, value in kwargs.items():
+            if isinstance(value, pd.Series):
+                m_kwargs[key] = value.loc[ml.name]
+            else:
+                m_kwargs[key] = value
+        # Convert timestamps
+        for tstamp in ["tmin", "tmax"]:
+            if tstamp in m_kwargs:
+                m_kwargs[tstamp] = pd.Timestamp(m_kwargs[tstamp])
+
+        try:
+            ml.solve(report=report, **m_kwargs)
+        except Exception as e:
+            if ignore_solve_errors:
+                warning = "Solve error ignored for '%s': %s" % (ml.name, e)
+                logger.warning(warning)
+            else:
+                raise e
+
+        self.conn.add_model(ml, overwrite=True)

     def model_results(
         self,
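Because ``_solve_model`` looks up ``pd.Series`` keyword arguments by model name, per-model solve settings can be passed to ``solve_models`` as Series indexed by model name. A sketch with made-up names and dates::

    import pandas as pd

    # one tmin per model; scalar kwargs would apply to all models
    tmins = pd.Series({"model_a": "2000-01-01", "model_b": "2005-01-01"})
    pstore.solve_models(modelnames=["model_a", "model_b"], tmin=tmins)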
diff --git a/pastastore/styling.py b/pastastore/styling.py
index eed2afa..b7b5c62 100644
--- a/pastastore/styling.py
+++ b/pastastore/styling.py
@@ -1,8 +1,8 @@
 """Module containing dataframe styling functions."""

-import matplotlib as mpl
 import matplotlib.pyplot as plt
 import numpy as np
+from matplotlib.colors import rgb2hex


 def float_styler(val, norm, cmap=None):
@@ -26,12 +26,12 @@
     -----
     Given some dataframe

-    >>> df.map(float_styler, subset=["some column"], norm=norm, cmap=cmap)
+    >>> df.style.map(float_styler, subset=["some column"], norm=norm, cmap=cmap)
     """
     if cmap is None:
         cmap = plt.get_cmap("RdYlBu")
     bg = cmap(norm(val))
-    color = mpl.colors.rgb2hex(bg)
+    color = rgb2hex(bg)
     c = "White" if np.mean(bg[:3]) < 0.4 else "Black"
     return f"background-color: {color}; color: {c}"
@@ -53,15 +53,48 @@
     -----
     Given some dataframe

-    >>> df.map(boolean_styler, subset=["some column"])
+    >>> df.style.map(boolean_styler, subset=["some column"])
     """
     if b:
         return (
-            f"background-color: {mpl.colors.rgb2hex((231/255, 255/255, 239/255))}; "
+            f"background-color: {rgb2hex((231/255, 255/255, 239/255))}; "
             "color: darkgreen"
         )
     else:
         return (
-            f"background-color: {mpl.colors.rgb2hex((255/255, 238/255, 238/255))}; "
+            f"background-color: {rgb2hex((255/255, 238/255, 238/255))}; "
             "color: darkred"
         )
+
+
+def boolean_row_styler(row, column):
+    """Styler function to color rows based on the value in a column.
+
+    Parameters
+    ----------
+    row : pd.Series
+        row in dataframe
+    column : str
+        column name to get boolean value for styling
+
+    Returns
+    -------
+    tuple of str
+        css for styling each cell in the dataframe row
+
+    Usage
+    -----
+    Given some dataframe
+
+    >>> df.style.apply(boolean_row_styler, column="boolean_column", axis=1)
+    """
+    if row[column]:
+        return (
+            f"background-color: {rgb2hex((231/255, 255/255, 239/255))}; "
+            "color: darkgreen",
+        ) * row.size
+    else:
+        return (
+            f"background-color: {rgb2hex((255/255, 238/255, 238/255))}; "
+            "color: darkred",
+        ) * row.size
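A self-contained example of the three stylers (column names are made up; ``Styler.map`` requires pandas >= 2.1)::

    import pandas as pd
    from matplotlib.colors import Normalize

    from pastastore.styling import boolean_row_styler, boolean_styler, float_styler

    df = pd.DataFrame({"evp": [0.31, 0.92], "converged": [False, True]})
    norm = Normalize(vmin=0.0, vmax=1.0)

    # color individual cells
    styled = (
        df.style.map(float_styler, subset=["evp"], norm=norm)
        .map(boolean_styler, subset=["converged"])
    )
    # or color entire rows based on the "converged" column
    styled_rows = df.style.apply(boolean_row_styler, column="converged", axis=1)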
diff --git a/pastastore/version.py b/pastastore/version.py
index d871b96..58656a5 100644
--- a/pastastore/version.py
+++ b/pastastore/version.py
@@ -9,7 +9,7 @@
 PASTAS_LEQ_022 = PASTAS_VERSION <= parse_version("0.22.0")
 PASTAS_GEQ_150 = PASTAS_VERSION >= parse_version("1.5.0")

-__version__ = "1.6.1"
+__version__ = "1.7.0"


 def show_versions(optional=False) -> None:
diff --git a/pyproject.toml b/pyproject.toml
index 175e0e4..f5b7a69 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -61,7 +61,7 @@ test = [
     "codacy-coverage",
 ]
 test_py312 = [
-    "pastastore[lint,optional]",
+    "pastastore[lint,optional]",  # no arcticdb
     "hydropandas[full]",
     "coverage",
     "codecov",
diff --git a/tests/test_003_pastastore.py b/tests/test_003_pastastore.py
index 3fc3b35..8021b50 100644
--- a/tests/test_003_pastastore.py
+++ b/tests/test_003_pastastore.py
@@ -195,13 +195,19 @@ def test_iter_models(request, pstore):
 def test_solve_models_and_get_stats(request, pstore):
     depends(request, [f"test_create_models[{pstore.type}]"])
     _ = pstore.solve_models(
-        ignore_solve_errors=False, progressbar=False, store_result=True
+        ignore_solve_errors=False, progressbar=False, parallel=False
     )
     stats = pstore.get_statistics(["evp", "aic"], progressbar=False)
     assert stats.index.size == 2


 @pytest.mark.dependency
+def test_solve_models_parallel(request, pstore):
+    depends(request, [f"test_create_models[{pstore.type}]"])
+    _ = pstore.solve_models(ignore_solve_errors=False, progressbar=False, parallel=True)
+
+
+@pytest.mark.dependency
 def test_apply(request, pstore):
     depends(request, [f"test_solve_models_and_get_stats[{pstore.type}]"])
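Finally, a Windows-safe script that exercises the new parallel path end to end, mirroring the Notes section of ``solve_models`` (the connector name and database path are hypothetical, and the ``PasConnector(name, path)`` signature is assumed)::

    from multiprocessing import freeze_support

    import pastastore as pst

    if __name__ == "__main__":
        freeze_support()  # no-op on Linux, required on Windows
        conn = pst.PasConnector("my_db", "./pastas_db")
        pstore = pst.PastaStore(conn)
        pstore.solve_models(parallel=True, ignore_solve_errors=True)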