Skip to content

Commit

Permalink
Merge pull request #137 from pastas/dev
Browse files Browse the repository at this point in the history
Release v1.7.0
  • Loading branch information
dbrakenhoff authored Oct 1, 2024
2 parents 45fc56a + ad6e66e commit ea13afe
Show file tree
Hide file tree
Showing 10 changed files with 200 additions and 74 deletions.
9 changes: 3 additions & 6 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: 3.9
python-version: 3.11
cache: "pip"
cache-dependency-path: pyproject.toml

Expand All @@ -47,13 +47,13 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12"]
pastas-version:
[
"git+https://github.com/pastas/[email protected]",
"git+https://github.com/pastas/[email protected]",
"git+https://github.com/pastas/[email protected]",
"git+https://github.com/pastas/pastas.git@v1.5.0",
"git+https://github.com/pastas/pastas.git@v1.6.0",
"git+https://github.com/pastas/pastas.git@dev",
]
exclude:
Expand All @@ -78,16 +78,13 @@ jobs:
if: ${{ matrix.python-version != '3.12'}}
run: |
pip install --upgrade pip
pip install numpy
pip install ${{ matrix.pastas-version }}
pip install -e .[test]
- name: Install dependencies == PY312
if: ${{ matrix.python-version == '3.12'}}
run: |
pip install --upgrade pip
# TODO: remove numpy pin when numba or ? doesn't crash on NaN being deprecated
pip install "numpy<2.0"
pip install ${{ matrix.pastas-version }}
pip install -e .[test_py312]
Expand Down
2 changes: 1 addition & 1 deletion docs/pstore.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ Bulk operations are also provided for:
* Optimizing pastas Models and storing the results::

# solve models and store result in database
pstore.solve_models(ignore_solver_errors=True, store_result=True)
pstore.solve_models(ignore_solver_errors=True)
Original file line number Diff line number Diff line change
Expand Up @@ -1363,7 +1363,7 @@
}
],
"source": [
"pstore.solve_models(store_result=True, report=False)"
"pstore.solve_models(report=False)"
]
},
{
Expand Down
5 changes: 5 additions & 0 deletions pastastore/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ def __repr__(self):
f"{self.n_models} models"
)

@property
def empty(self):
    """Return True when the database holds no oseries, stresses or models."""
    # The store counts as empty only if none of the three libraries
    # contains any entries.
    has_content = (
        self.n_oseries > 0 or self.n_stresses > 0 or self.n_models > 0
    )
    return not has_content

@abstractmethod
def _get_library(self, libname: str):
"""Get library handle.
Expand Down
63 changes: 40 additions & 23 deletions pastastore/extensions/hpd.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import hydropandas as hpd
import numpy as np
from hydropandas.io.knmi import get_stations
from hydropandas.io.knmi import _check_latest_measurement_date_de_bilt, get_stations
from pandas import DataFrame, Series, Timedelta, Timestamp
from pastas.timeseries_utils import timestep_weighted_resample
from tqdm.auto import tqdm
Expand Down Expand Up @@ -179,15 +179,15 @@ def add_observation(
action_msg = "added to"

if libname == "oseries":
self._store.upsert_oseries(o.squeeze(), name, metadata=metadata)
self._store.upsert_oseries(o.squeeze(axis=1), name, metadata=metadata)
logger.info(
"%sobservation '%s' %s oseries library.", source, name, action_msg
)
elif libname == "stresses":
if kind is None:
raise ValueError("`kind` must be specified for stresses!")
self._store.upsert_stress(
(o * unit_multiplier).squeeze(), name, kind, metadata=metadata
(o * unit_multiplier).squeeze(axis=1), name, kind, metadata=metadata
)
logger.info(
"%sstress '%s' (kind='%s') %s stresses library.",
Expand Down Expand Up @@ -394,24 +394,32 @@ def update_knmi_meteo(
tmintmax = self._store.get_tmin_tmax("stresses", names=names)

if tmax is not None:
if tmintmax["tmax"].min() > Timestamp(tmax):
logger.info(f"All KNMI stresses are up to date to {tmax}.")
if tmintmax["tmax"].min() >= Timestamp(tmax):
logger.info(f"All KNMI stresses are up to date till {tmax}.")
return

# NOTE: this check is very flaky (15 august 2024), perhaps I annoyed the
# KNMI server... Trying to skip this check and just attempt downloading data.
# maxtmax_rd = _check_latest_measurement_date_de_bilt("RD")
# maxtmax_ev24 = _check_latest_measurement_date_de_bilt("EV24")
maxtmax = Timestamp.today() - Timedelta(days=1)
try:
maxtmax_rd = _check_latest_measurement_date_de_bilt("RD")
maxtmax_ev24 = _check_latest_measurement_date_de_bilt("EV24")
except Exception as e:
# otherwise use maxtmax 28 days (4 weeks) prior to today
logger.warning(
"Could not check latest measurement date in De Bilt: %s" % str(e)
)
maxtmax_rd = maxtmax_ev24 = Timestamp.today() - Timedelta(days=28)
logger.info(
"Using 28 days (4 weeks) prior to today as maxtmax: %s."
% str(maxtmax_rd)
)

for name in tqdm(names, desc="Updating KNMI meteo stresses"):
meteo_var = self._store.stresses.loc[name, "meteo_var"]
# if meteo_var == "RD":
# maxtmax = maxtmax_rd
# elif meteo_var == "EV24":
# maxtmax = maxtmax_ev24
# else:
# maxtmax = maxtmax_rd
if meteo_var == "RD":
maxtmax = maxtmax_rd
elif meteo_var == "EV24":
maxtmax = maxtmax_ev24
else:
maxtmax = maxtmax_rd

# 1 days extra to ensure computation of daily totals using
# timestep_weighted_resample
Expand All @@ -421,7 +429,7 @@ def update_knmi_meteo(
itmin = tmin - Timedelta(days=1)

# ensure 2 observations at least
if itmin >= (maxtmax + Timedelta(days=1)):
if itmin >= (maxtmax - Timedelta(days=1)):
logger.debug("KNMI %s is already up to date." % name)
continue

Expand All @@ -430,20 +438,29 @@ def update_knmi_meteo(
else:
itmax = Timestamp(tmax)

# fix for duplicate station entry in metadata:
stress_station = (
self._store.stresses.at[name, "station"]
if "station" in self._store.stresses.columns
else None
)
if stress_station is not None and not isinstance(
stress_station, (int, np.integer)
):
stress_station = stress_station.squeeze().unique().item()

unit = self._store.stresses.loc[name, "unit"]
kind = self._store.stresses.loc[name, "kind"]
if "station" in self._store.stresses.columns and ~np.isnan(
self._store.stresses.loc[name, "station"]
):
stn = self._store.stresses.loc[name, "station"]
if stress_station is not None:
stn = stress_station
else:
stns = get_stations(meteo_var)
stn_name = name.split("_")[-1].lower()
mask = stns["name"].str.lower().str.replace(" ", "-") == stn_name
if not mask.any():
logger.warning(
f"Station '%s' not found in list of KNMI {meteo_var} stations."
% stn_name
"Station '%s' not found in list of KNMI %s stations."
% (stn_name, meteo_var)
)
continue
stn = stns.loc[mask].index[0]
Expand Down
137 changes: 103 additions & 34 deletions pastastore/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import logging
import os
import warnings
from concurrent.futures import ProcessPoolExecutor
from functools import partial
from typing import Dict, List, Literal, Optional, Tuple, Union

import numpy as np
Expand All @@ -12,6 +14,7 @@
from packaging.version import parse as parse_version
from pastas.io.pas import pastas_hook
from tqdm.auto import tqdm
from tqdm.contrib.concurrent import process_map

from pastastore.base import BaseConnector
from pastastore.connectors import DictConnector
Expand Down Expand Up @@ -78,6 +81,11 @@ def __init__(
self.plots = Plots(self)
self.yaml = PastastoreYAML(self)

@property
def empty(self) -> bool:
    """Return True when the underlying connector database is empty.

    Delegates to the connector, which tracks the oseries, stresses and
    models library counts.
    """
    return self.conn.empty

def _register_connector_methods(self):
"""Register connector methods (internal method)."""
methods = [
Expand Down Expand Up @@ -1175,18 +1183,19 @@ def add_stressmodel(

def solve_models(
self,
mls: Optional[Union[ps.Model, list, str]] = None,
modelnames: Union[List[str], str, None] = None,
report: bool = False,
ignore_solve_errors: bool = False,
store_result: bool = True,
progressbar: bool = True,
parallel: bool = False,
max_workers: Optional[int] = None,
**kwargs,
) -> None:
"""Solves the models in the store.
Parameters
----------
mls : list of str, optional
modelnames : list of str, optional
list of model names, if None all models in the pastastore
are solved.
report : boolean, optional
Expand All @@ -1196,43 +1205,103 @@ def solve_models(
if True, errors emerging from the solve method are ignored,
default is False which will raise an exception when a model
cannot be optimized
store_result : bool, optional
if True save optimized models, default is True
progressbar : bool, optional
show progressbar, default is True
**kwargs :
show progressbar, default is True.
parallel: bool, optional
if True, solve models in parallel using ProcessPoolExecutor
max_workers: int, optional
maximum number of workers to use in parallel solving, default is
None which will use the number of cores available on the machine
**kwargs : dictionary
arguments are passed to the solve method.
Notes
-----
Users should be aware that parallel solving is platform dependent
and may not always work. The current implementation works well for Linux users.
For Windows users, parallel solving does not work when called directly from
Jupyter Notebooks or IPython. To use parallel solving on Windows, the following
code should be used in a Python file::
from multiprocessing import freeze_support
if __name__ == "__main__":
freeze_support()
pstore.solve_models(parallel=True)
"""
if mls is None:
mls = self.conn.model_names
elif isinstance(mls, ps.Model):
mls = [mls.name]
if "mls" in kwargs:
modelnames = kwargs.pop("mls")
logger.warning("Argument `mls` is deprecated, use `modelnames` instead.")

desc = "Solving models"
for ml_name in tqdm(mls, desc=desc) if progressbar else mls:
ml = self.conn.get_models(ml_name)
modelnames = self.conn._parse_names(modelnames, libname="models")

m_kwargs = {}
for key, value in kwargs.items():
if isinstance(value, pd.Series):
m_kwargs[key] = value.loc[ml_name]
else:
m_kwargs[key] = value
# Convert timestamps
for tstamp in ["tmin", "tmax"]:
if tstamp in m_kwargs:
m_kwargs[tstamp] = pd.Timestamp(m_kwargs[tstamp])
solve_model = partial(
self._solve_model,
report=report,
ignore_solve_errors=ignore_solve_errors,
**kwargs,
)
if self.conn.conn_type != "pas":
parallel = False
logger.error(
"Parallel solving only supported for PasConnector databases."
"Setting parallel to `False`"
)

try:
ml.solve(report=report, **m_kwargs)
if store_result:
self.conn.add_model(ml, overwrite=True)
except Exception as e:
if ignore_solve_errors:
warning = "solve error ignored for -> {}".format(ml.name)
ps.logger.warning(warning)
else:
raise e
if parallel and progressbar:
process_map(solve_model, modelnames, max_workers=max_workers)
elif parallel and not progressbar:
with ProcessPoolExecutor(max_workers=max_workers) as executor:
executor.map(solve_model, modelnames)
else:
for ml_name in (
tqdm(modelnames, desc="Solving models") if progressbar else modelnames
):
solve_model(ml_name=ml_name)

def _solve_model(
    self,
    ml_name: str,
    report: bool = False,
    ignore_solve_errors: bool = False,
    **kwargs,
) -> None:
    """Solve a single model in the store (internal method).

    Loads the model from the connector, solves it, and writes the
    optimized model back to the database, overwriting the stored copy.

    Parameters
    ----------
    ml_name : str
        name of the model in the pastastore to solve
    report : bool, optional
        if True, print a report when the model is solved,
        default is False
    ignore_solve_errors : bool, optional
        if True, errors emerging from the solve method are logged as a
        warning instead of raised, default is False
    **kwargs : dictionary
        arguments passed on to the model's solve method; pandas.Series
        values are treated as per-model settings and indexed by model
        name
    """
    ml = self.conn.get_models(ml_name)
    m_kwargs = {}
    for key, value in kwargs.items():
        if isinstance(value, pd.Series):
            # Series-valued kwargs hold one entry per model; select
            # this model's value by name.
            m_kwargs[key] = value.loc[ml.name]
        else:
            m_kwargs[key] = value
    # Convert timestamps so tmin/tmax may be passed as strings or other
    # datetime-like values.
    for tstamp in ["tmin", "tmax"]:
        if tstamp in m_kwargs:
            m_kwargs[tstamp] = pd.Timestamp(m_kwargs[tstamp])

    try:
        ml.solve(report=report, **m_kwargs)
    except Exception as e:
        if ignore_solve_errors:
            warning = "Solve error ignored for '%s': %s " % (ml.name, e)
            logger.warning(warning)
        else:
            raise e

    # Store the (solved) model, overwriting the existing database entry.
    # NOTE(review): this also runs when a solve error was ignored, so the
    # unsolved model overwrites the stored one — confirm this is intended.
    self.conn.add_model(ml, overwrite=True)

def model_results(
self,
Expand Down
Loading

0 comments on commit ea13afe

Please sign in to comment.