Skip to content

Commit

Permalink
created GMCalibParams as an input to calibrate and reformatted
Browse files Browse the repository at this point in the history
  • Loading branch information
Kieran-Fishwick-TfN committed Oct 29, 2024
1 parent 54984f4 commit d943bf0
Showing 1 changed file with 57 additions and 63 deletions.
120 changes: 57 additions & 63 deletions src/caf/distribute/gravity_model/multi_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,17 @@ class MultiDistInput(BaseConfig):
furness_tolerance: float = 1e-6
furness_jac: float = False

@dataclass
class GMCalibParams:
furness_jac: bool = False,
diff_step: float = 1e-8,
ftol: float = 1e-4,
xtol: float = 1e-4,
furness_tol=DEFAULT_FURNESS_TOL,
grav_max_iters: int = 100,
failure_tol: float = 0,
default_retry: bool = True,


@dataclass
class MultiCostDistribution:
Expand Down Expand Up @@ -483,25 +494,54 @@ def _create_seed_matrix(self, cost_distributions, cost_args, params_len):
return base_mat

# pylint: disable=too-many-locals
def _calibrate(
def calibrate(
self,
distributions: MultiCostDistribution,
running_log_path: Path,
furness_jac: bool = False,
diff_step: float = 1e-8,
ftol: float = 1e-4,
xtol: float = 1e-4,
furness_tol=DEFAULT_FURNESS_TOL,
grav_max_iters: int = 100,
failure_tol: float = 0,
default_retry: bool = True,
gm_params: GMCalibParams,
verbose: int = 0,
**kwargs,
) -> dict[str | int, GravityModelCalibrateResults]:
"""Find the optimal parameters for self.cost_function.
Optimal parameters are found using `scipy.optimize.least_squares`
to fit the distributed row/col targets to `target_cost_distribution`.
NOTE: The achieved distribution is found by accessing self.achieved
distribution of the object this method is called on. The output of
this method shows the distribution and results for each individual TLD.
Parameters
----------
distributions: MultiCostDistribution
distributions to use for the calibrations
running_log_path: os.PathLike,
path to a csv to log the model iterations and results
gm_params: GMCalibParams
defines the detailed parameters, see `GMCalibParams` documentation for more info
*args,
**kwargs,
Returns
-------
dict[str | int, GravityModelCalibrateResults]:
containings the achieved distributions for each tld category. To access
the combined distribution use self.achieved_distribution
See Also
--------
`caf.distribute.furness.doubly_constrained_furness()`
`scipy.optimize.least_squares()`
`caf.distribute.gravity_model.multi_area.GMCalibParams`
"""

self._validate_running_log(running_log_path)
self._initialise_internal_params()

params_len = len(distributions[0].function_params)
ordered_init_params = []

for dist in distributions:
self.cost_function.validate_params(dist.function_params)
params = self._order_cost_params(dist.function_params)
for val in params:
ordered_init_params.append(val)
Expand Down Expand Up @@ -534,10 +574,10 @@ def _calibrate(
gravity_kwargs: dict[str, Any] = {
"running_log_path": running_log_path,
"cost_distributions": distributions,
"diff_step": diff_step,
"diff_step": gm_params.diff_step,
"params_len": params_len,
"furness_jac": furness_jac,
"furness_tol": furness_tol,
"furness_jac": gm_params.furness_jac,
"furness_tol": gm_params.furness_tol,
}
optimise_cost_params = functools.partial(
optimize.least_squares,
Expand All @@ -549,9 +589,9 @@ def _calibrate(
),
jac=self._jacobian_function,
verbose=verbose,
ftol=ftol,
xtol=xtol,
max_nfev=grav_max_iters,
ftol=gm_params.ftol,
xtol=gm_params.xtol,
max_nfev=gm_params.grav_max_iters,
kwargs=gravity_kwargs | kwargs,
)
result = optimise_cost_params(x0=ordered_init_params)
Expand All @@ -568,14 +608,14 @@ def _calibrate(
best_convergence = self.achieved_convergence
best_params = result.x

if (not all(self.achieved_convergence) >= failure_tol) and default_retry:
if (not all(self.achieved_convergence) >= gm_params.failure_tol) and gm_params.default_retry:
LOG.info(
"%sachieved a convergence of %s, "
"however the failure tolerance is set to %s. Trying again with "
"default cost parameters.",
self.unique_id,
self.achieved_convergence,
failure_tol,
gm_params.failure_tol,
)
self._attempt_id += 1
ordered_init_params = self._order_cost_params(self.cost_function.default_params)
Expand Down Expand Up @@ -610,52 +650,6 @@ def _calibrate(
results[dist.name] = result_i
return results

def calibrate(
self,
distributions: MultiCostDistribution,
running_log_path: os.PathLike,
*args,
**kwargs,
) -> dict[str | int, GravityModelCalibrateResults]:
"""Find the optimal parameters for self.cost_function.
Optimal parameters are found using `scipy.optimize.least_squares`
to fit the distributed row/col targets to `target_cost_distribution`.
NOTE: The achieved distribution is found by accessing self.achieved
distribution of the object this method is called on. The output of
this method shows the distribution and results for each individual TLD.
Parameters
----------
distributions: MultiCostDistribution
distributions to use for the calibrations
running_log_path: os.PathLike,
path to a csv to log the model iterations and results
*args,
**kwargs,
Returns
-------
dict[str | int, GravityModelCalibrateResults]:
An instance of GravityModelCalibrateResults containing the
results of this run.
See Also
--------
`caf.distribute.furness.doubly_constrained_furness()`
`scipy.optimize.least_squares()`
"""
for dist in distributions:
self.cost_function.validate_params(dist.function_params)
self._validate_running_log(running_log_path)
self._initialise_internal_params()
return self._calibrate( # type: ignore
distributions,
running_log_path,
*args,
**kwargs,
)

def _jacobian_function(
self,
init_params: list[float],
Expand Down

0 comments on commit d943bf0

Please sign in to comment.