Skip to content

Commit

Permalink
moved _run method to bias correction interface ``AbstractBiasCorr…
Browse files Browse the repository at this point in the history
…ection``
  • Loading branch information
bnb32 committed Jan 5, 2025
1 parent 7f92049 commit b22f421
Show file tree
Hide file tree
Showing 7 changed files with 211 additions and 225 deletions.
167 changes: 167 additions & 0 deletions sup3r/bias/abstract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
"""Bias correction class interface."""

import logging
from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor, as_completed

import numpy as np

from sup3r.preprocessing import DataHandler

logger = logging.getLogger(__name__)


class AbstractBiasCorrection(ABC):
"""Minimal interface for bias correction classes"""

@abstractmethod
def _get_run_kwargs(self, **kwargs_extras):
"""Get dictionary of kwarg dictionaries to use for calls to
``_run_single``. Each key-value pair is a bias_gid with the associated
``_run_single`` arguments for that gid"""

def _run_in_parallel(self, task_kwargs, max_workers=None):
"""
Execute a list of tasks in parallel using ``ProcessPoolExecutor``.
Parameters
----------
task_kwargs : dictionary
A dictionary of keyword argument dictionaries for a single call to
``task_function``.
max_workers : int, optional
The maximum number of workers to use. If None, it uses all
available.
Returns
-------
results : dictionary
A dictionary of results from the executed tasks with the same keys
as ``task_kwargs``.
"""

results = {}
with ProcessPoolExecutor(max_workers=max_workers) as exe:
futures = {
exe.submit(self._run_single, **kwargs): bias_gid
for bias_gid, kwargs in task_kwargs.items()
}
for future in as_completed(futures):
bias_gid = futures[future]
results[bias_gid] = future.result()
return results

def _run(
self,
out,
max_workers=None,
fill_extend=True,
smooth_extend=0,
smooth_interior=0,
**kwargs_extras,
):
"""Run correction factor calculations for every site in the bias
dataset
Parameters
----------
out : dict
Dictionary of arrays to fill with bias correction factors.
max_workers : int
Number of workers to run in parallel. 1 is serial and None is all
available.
daily_reduction : None | str
Option to do a reduction of the hourly+ source base data to daily
data. Can be None (no reduction, keep source time frequency), "avg"
(daily average), "max" (daily max), "min" (daily min),
"sum" (daily sum/total)
fill_extend : bool
Flag to fill data past distance_upper_bound using spatial nearest
neighbor. If False, the extended domain will be left as NaN.
smooth_extend : float
Option to smooth the scalar/adder data outside of the spatial
domain set by the distance_upper_bound input. This alleviates the
weird seams far from the domain of interest. This value is the
standard deviation for the gaussian_filter kernel
smooth_interior : float
Option to smooth the scalar/adder data within the valid spatial
domain. This can reduce the affect of extreme values within
aggregations over large number of pixels.
kwargs_extras: dict
Additional kwargs that get sent to ``_run_single`` e.g.
daily_reduction='avg', zero_rate_threshold=1.157e-7
Returns
-------
out : dict
Dictionary of values defining the mean/std of the bias + base data
and correction factors to correct the biased data like: bias_data *
scalar + adder. Each value is of shape (lat, lon, time).
"""
self.bad_bias_gids = []

task_kwargs = self._get_run_kwargs(**kwargs_extras)
# sup3r DataHandler opening base files will load all data in parallel
# during the init and should not be passed in parallel to workers
if isinstance(self.base_dh, DataHandler):
max_workers = 1

if max_workers == 1:
logger.debug('Running serial calculation.')
results = {
bias_gid: self._run_single(**kwargs, base_dh_inst=self.base_dh)
for bias_gid, kwargs in task_kwargs.items()
}
else:
logger.info(
'Running parallel calculation with %s workers.', max_workers
)
results = self._run_in_parallel(
task_kwargs, max_workers=max_workers
)
for i, (bias_gid, single_out) in enumerate(results.items()):
raster_loc = np.where(self.bias_gid_raster == bias_gid)
for key, arr in single_out.items():
out[key][raster_loc] = arr
logger.info(
'Completed bias calculations for %s out of %s sites',
i + 1,
len(results),
)

logger.info('Finished calculating bias correction factors.')

return self.fill_and_smooth(
out, fill_extend, smooth_extend, smooth_interior
)

@abstractmethod
def run(
self,
fp_out=None,
max_workers=None,
daily_reduction='avg',
fill_extend=True,
smooth_extend=0,
smooth_interior=0,
):
"""Run correction factor calculations for every site in the bias
dataset"""

@classmethod
@abstractmethod
def _run_single(
cls,
bias_data,
base_fps,
bias_feature,
base_dset,
base_gid,
base_handler,
daily_reduction,
bias_ti,
decimals,
base_dh_inst=None,
match_zero_rate=False,
):
"""Find the bias correction factors at a single site"""
114 changes: 0 additions & 114 deletions sup3r/bias/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
from sup3r.utilities import VERSION_RECORD, ModuleName
from sup3r.utilities.cli import BaseCLI

from .utilities import run_in_parallel

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -779,115 +777,3 @@ def _reduce_base_data(
assert base_data.shape == daily_ti.shape, msg

return base_data, daily_ti

def _get_run_kwargs(self, **kwargs_extras):
"""Get dictionary of kwarg dictionaries to use for calls to
``_run_single``. Each key-value pair is a bias_gid with the associated
``_run_single`` arguments for that gid"""
task_kwargs = {}
for bias_gid in self.bias_meta.index:
_, base_gid = self.get_base_gid(bias_gid)

if not base_gid.any():
self.bad_bias_gids.append(bias_gid)
else:
bias_data = self.get_bias_data(bias_gid)
task_kwargs[bias_gid] = {
'bias_data': bias_data,
'base_fps': self.base_fps,
'bias_feature': self.bias_feature,
'base_dset': self.base_dset,
'base_gid': base_gid,
'base_handler': self.base_handler,
'bias_ti': self.bias_ti,
'decimals': self.decimals,
'match_zero_rate': self.match_zero_rate,
**kwargs_extras
}
return task_kwargs

def _run(
self,
max_workers=None,
fill_extend=True,
smooth_extend=0,
smooth_interior=0,
**kwargs_extras
):
"""Run correction factor calculations for every site in the bias
dataset
Parameters
----------
fp_out : str | None
Optional .h5 output file to write scalar and adder arrays.
max_workers : int
Number of workers to run in parallel. 1 is serial and None is all
available.
daily_reduction : None | str
Option to do a reduction of the hourly+ source base data to daily
data. Can be None (no reduction, keep source time frequency), "avg"
(daily average), "max" (daily max), "min" (daily min),
"sum" (daily sum/total)
fill_extend : bool
Flag to fill data past distance_upper_bound using spatial nearest
neighbor. If False, the extended domain will be left as NaN.
smooth_extend : float
Option to smooth the scalar/adder data outside of the spatial
domain set by the distance_upper_bound input. This alleviates the
weird seams far from the domain of interest. This value is the
standard deviation for the gaussian_filter kernel
smooth_interior : float
Option to smooth the scalar/adder data within the valid spatial
domain. This can reduce the affect of extreme values within
aggregations over large number of pixels.
kwargs_extras: dict
Additional kwargs that get sent to ``_run_single`` e.g.
daily_reduction='avg', zero_rate_threshold=1.157e-7
Returns
-------
out : dict
Dictionary of values defining the mean/std of the bias + base
data and the scalar + adder factors to correct the biased data
like: bias_data * scalar + adder. Each value is of shape
(lat, lon, time).
"""
self.bad_bias_gids = []

task_kwargs = self._get_run_kwargs(**kwargs_extras)
# sup3r DataHandler opening base files will load all data in parallel
# during the init and should not be passed in parallel to workers
if isinstance(self.base_dh, DataHandler):
max_workers = 1

if max_workers == 1:
logger.debug('Running serial calculation.')
results = {
bias_gid: self._run_single(**kwargs, base_dh_inst=self.base_dh)
for bias_gid, kwargs in task_kwargs.items()
}
else:
logger.info(
'Running parallel calculation with %s workers.', max_workers
)
results = run_in_parallel(
self._run_single, task_kwargs, max_workers=max_workers
)
for i, (bias_gid, single_out) in enumerate(results.items()):
raster_loc = np.where(self.bias_gid_raster == bias_gid)
for key, arr in single_out.items():
self.out[key][raster_loc] = arr
logger.info(
'Completed bias calculations for %s out of %s sites',
i + 1,
len(results),
)

logger.info('Finished calculating bias correction factors.')

self.out = self.fill_and_smooth(
self.out, fill_extend, smooth_extend, smooth_interior
)

return self.out
32 changes: 31 additions & 1 deletion sup3r/bias/bias_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@
import numpy as np
from scipy import stats

from .abstract import AbstractBiasCorrection
from .base import DataRetrievalBase
from .mixins import FillAndSmoothMixin

logger = logging.getLogger(__name__)


class LinearCorrection(FillAndSmoothMixin, DataRetrievalBase):
class LinearCorrection(
AbstractBiasCorrection, FillAndSmoothMixin, DataRetrievalBase
):
"""Calculate linear correction *scalar +adder factors to bias correct data
This calculation operates on single bias sites for the full time series of
Expand Down Expand Up @@ -159,6 +162,32 @@ def write_outputs(self, fp_out, out):
'Wrote scalar adder factors to file: {}'.format(fp_out)
)

def _get_run_kwargs(self, **kwargs_extras):
"""Get dictionary of kwarg dictionaries to use for calls to
``_run_single``. Each key-value pair is a bias_gid with the associated
``_run_single`` arguments for that gid"""
task_kwargs = {}
for bias_gid in self.bias_meta.index:
_, base_gid = self.get_base_gid(bias_gid)

if not base_gid.any():
self.bad_bias_gids.append(bias_gid)
else:
bias_data = self.get_bias_data(bias_gid)
task_kwargs[bias_gid] = {
'bias_data': bias_data,
'base_fps': self.base_fps,
'bias_feature': self.bias_feature,
'base_dset': self.base_dset,
'base_gid': base_gid,
'base_handler': self.base_handler,
'bias_ti': self.bias_ti,
'decimals': self.decimals,
'match_zero_rate': self.match_zero_rate,
**kwargs_extras,
}
return task_kwargs

def run(
self,
fp_out=None,
Expand Down Expand Up @@ -212,6 +241,7 @@ def run(
)
)
self.out = self._run(
out=self.out,
max_workers=max_workers,
daily_reduction=daily_reduction,
fill_extend=fill_extend,
Expand Down
Loading

0 comments on commit b22f421

Please sign in to comment.