diff --git a/src/caf/distribute/gravity_model/multi_area.py b/src/caf/distribute/gravity_model/multi_area.py index c2836c3..3ae2ac8 100644 --- a/src/caf/distribute/gravity_model/multi_area.py +++ b/src/caf/distribute/gravity_model/multi_area.py @@ -93,6 +93,17 @@ class MultiDistInput(BaseConfig): furness_tolerance: float = 1e-6 furness_jac: float = False +@dataclass +class GMCalibParams: + furness_jac: bool = False, + diff_step: float = 1e-8, + ftol: float = 1e-4, + xtol: float = 1e-4, + furness_tol=DEFAULT_FURNESS_TOL, + grav_max_iters: int = 100, + failure_tol: float = 0, + default_retry: bool = True, + @dataclass class MultiCostDistribution: @@ -483,25 +494,54 @@ def _create_seed_matrix(self, cost_distributions, cost_args, params_len): return base_mat # pylint: disable=too-many-locals - def _calibrate( + def calibrate( self, distributions: MultiCostDistribution, running_log_path: Path, - furness_jac: bool = False, - diff_step: float = 1e-8, - ftol: float = 1e-4, - xtol: float = 1e-4, - furness_tol=DEFAULT_FURNESS_TOL, - grav_max_iters: int = 100, - failure_tol: float = 0, - default_retry: bool = True, + gm_params: GMCalibParams, verbose: int = 0, **kwargs, ) -> dict[str | int, GravityModelCalibrateResults]: + """Find the optimal parameters for self.cost_function. + + Optimal parameters are found using `scipy.optimize.least_squares` + to fit the distributed row/col targets to `target_cost_distribution`. + + NOTE: The achieved distribution is found by accessing self.achieved + distribution of the object this method is called on. The output of + this method shows the distribution and results for each individual TLD. + + Parameters + ---------- + distributions: MultiCostDistribution + distributions to use for the calibrations + running_log_path: os.PathLike, + path to a csv to log the model iterations and results + gm_params: GMCalibParams + defines the detailed parameters, see `GMCalibParams` documentation for more info + *args, + **kwargs, + Returns + ------- + dict[str | int, GravityModelCalibrateResults]: + containings the achieved distributions for each tld category. To access + the combined distribution use self.achieved_distribution + + See Also + -------- + `caf.distribute.furness.doubly_constrained_furness()` + `scipy.optimize.least_squares()` + `caf.distribute.gravity_model.multi_area.GMCalibParams` + """ + + self._validate_running_log(running_log_path) + self._initialise_internal_params() + params_len = len(distributions[0].function_params) ordered_init_params = [] for dist in distributions: + self.cost_function.validate_params(dist.function_params) params = self._order_cost_params(dist.function_params) for val in params: ordered_init_params.append(val) @@ -534,10 +574,10 @@ def _calibrate( gravity_kwargs: dict[str, Any] = { "running_log_path": running_log_path, "cost_distributions": distributions, - "diff_step": diff_step, + "diff_step": gm_params.diff_step, "params_len": params_len, - "furness_jac": furness_jac, - "furness_tol": furness_tol, + "furness_jac": gm_params.furness_jac, + "furness_tol": gm_params.furness_tol, } optimise_cost_params = functools.partial( optimize.least_squares, @@ -549,9 +589,9 @@ def _calibrate( ), jac=self._jacobian_function, verbose=verbose, - ftol=ftol, - xtol=xtol, - max_nfev=grav_max_iters, + ftol=gm_params.ftol, + xtol=gm_params.xtol, + max_nfev=gm_params.grav_max_iters, kwargs=gravity_kwargs | kwargs, ) result = optimise_cost_params(x0=ordered_init_params) @@ -568,14 +608,14 @@ def _calibrate( best_convergence = self.achieved_convergence best_params = result.x - if (not all(self.achieved_convergence) >= failure_tol) and default_retry: + if (not all(self.achieved_convergence) >= gm_params.failure_tol) and gm_params.default_retry: LOG.info( "%sachieved a convergence of %s, " "however the failure tolerance is set to %s. Trying again with " "default cost parameters.", self.unique_id, self.achieved_convergence, - failure_tol, + gm_params.failure_tol, ) self._attempt_id += 1 ordered_init_params = self._order_cost_params(self.cost_function.default_params) @@ -610,52 +650,6 @@ def _calibrate( results[dist.name] = result_i return results - def calibrate( - self, - distributions: MultiCostDistribution, - running_log_path: os.PathLike, - *args, - **kwargs, - ) -> dict[str | int, GravityModelCalibrateResults]: - """Find the optimal parameters for self.cost_function. - - Optimal parameters are found using `scipy.optimize.least_squares` - to fit the distributed row/col targets to `target_cost_distribution`. - - NOTE: The achieved distribution is found by accessing self.achieved - distribution of the object this method is called on. The output of - this method shows the distribution and results for each individual TLD. - - Parameters - ---------- - distributions: MultiCostDistribution - distributions to use for the calibrations - running_log_path: os.PathLike, - path to a csv to log the model iterations and results - *args, - **kwargs, - Returns - ------- - dict[str | int, GravityModelCalibrateResults]: - An instance of GravityModelCalibrateResults containing the - results of this run. - - See Also - -------- - `caf.distribute.furness.doubly_constrained_furness()` - `scipy.optimize.least_squares()` - """ - for dist in distributions: - self.cost_function.validate_params(dist.function_params) - self._validate_running_log(running_log_path) - self._initialise_internal_params() - return self._calibrate( # type: ignore - distributions, - running_log_path, - *args, - **kwargs, - ) - def _jacobian_function( self, init_params: list[float],