From bcc97c534bb8c0d3f98312ea3004d6790acef4c6 Mon Sep 17 00:00:00 2001 From: KieranFishwick Date: Tue, 29 Oct 2024 10:58:41 +0000 Subject: [PATCH] bug fixes --- .../distribute/gravity_model/multi_area.py | 29 ++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/src/caf/distribute/gravity_model/multi_area.py b/src/caf/distribute/gravity_model/multi_area.py index 3ae2ac8..351881b 100644 --- a/src/caf/distribute/gravity_model/multi_area.py +++ b/src/caf/distribute/gravity_model/multi_area.py @@ -93,16 +93,17 @@ class MultiDistInput(BaseConfig): furness_tolerance: float = 1e-6 furness_jac: float = False + @dataclass class GMCalibParams: - furness_jac: bool = False, - diff_step: float = 1e-8, - ftol: float = 1e-4, - xtol: float = 1e-4, - furness_tol=DEFAULT_FURNESS_TOL, - grav_max_iters: int = 100, - failure_tol: float = 0, - default_retry: bool = True, + furness_jac: bool = False + diff_step: float = 1e-8 + ftol: float = 1e-4 + xtol: float = 1e-4 + furness_tol: float = DEFAULT_FURNESS_TOL + grav_max_iters: int = 100 + failure_tol: float = 0 + default_retry: bool = True @dataclass @@ -429,7 +430,7 @@ def __init__( num_zeros = (data == 0).sum() # casting bool as 1, 0 LOG.info( - "There are %s 0s in %s (%s %)", num_zeros, name, (num_zeros / data.size) * 100 + "There are %s 0s in %s (%s percent)", num_zeros, name, (num_zeros / data.size) * 100 ) zero_in_both = np.stack([row_targets == 0, col_targets == 0], axis=1).all(axis=1).sum() @@ -498,7 +499,7 @@ def calibrate( self, distributions: MultiCostDistribution, running_log_path: Path, - gm_params: GMCalibParams, + gm_params: GMCalibParams, verbose: int = 0, **kwargs, ) -> dict[str | int, GravityModelCalibrateResults]: @@ -518,7 +519,7 @@ def calibrate( running_log_path: os.PathLike, path to a csv to log the model iterations and results gm_params: GMCalibParams - defines the detailed parameters, see `GMCalibParams` documentation for more info + defines the detailed parameters, see `GMCalibParams` documentation for more info *args, **kwargs, Returns @@ -533,7 +534,7 @@ def calibrate( `scipy.optimize.least_squares()` `caf.distribute.gravity_model.multi_area.GMCalibParams` """ - + self._validate_running_log(running_log_path) self._initialise_internal_params() @@ -608,7 +609,9 @@ def calibrate( best_convergence = self.achieved_convergence best_params = result.x - if (not all(self.achieved_convergence) >= gm_params.failure_tol) and gm_params.default_retry: + if ( + not all(self.achieved_convergence) >= gm_params.failure_tol + ) and gm_params.default_retry: LOG.info( "%sachieved a convergence of %s, " "however the failure tolerance is set to %s. Trying again with "