Skip to content

Commit

Permalink
bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Kieran-Fishwick-TfN committed Oct 29, 2024
1 parent d943bf0 commit bcc97c5
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions src/caf/distribute/gravity_model/multi_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,17 @@ class MultiDistInput(BaseConfig):
furness_tolerance: float = 1e-6
furness_jac: float = False


@dataclass
class GMCalibParams:
furness_jac: bool = False,
diff_step: float = 1e-8,
ftol: float = 1e-4,
xtol: float = 1e-4,
furness_tol=DEFAULT_FURNESS_TOL,
grav_max_iters: int = 100,
failure_tol: float = 0,
default_retry: bool = True,
furness_jac: bool = False
diff_step: float = 1e-8
ftol: float = 1e-4
xtol: float = 1e-4
furness_tol: float = DEFAULT_FURNESS_TOL
grav_max_iters: int = 100
failure_tol: float = 0
default_retry: bool = True


@dataclass
Expand Down Expand Up @@ -429,7 +430,7 @@ def __init__(
num_zeros = (data == 0).sum() # casting bool as 1, 0

LOG.info(
"There are %s 0s in %s (%s %)", num_zeros, name, (num_zeros / data.size) * 100
"There are %s 0s in %s (%s percent)", num_zeros, name, (num_zeros / data.size) * 100
)

zero_in_both = np.stack([row_targets == 0, col_targets == 0], axis=1).all(axis=1).sum()
Expand Down Expand Up @@ -498,7 +499,7 @@ def calibrate(
self,
distributions: MultiCostDistribution,
running_log_path: Path,
gm_params: GMCalibParams,
gm_params: GMCalibParams,
verbose: int = 0,
**kwargs,
) -> dict[str | int, GravityModelCalibrateResults]:
Expand All @@ -518,7 +519,7 @@ def calibrate(
running_log_path: os.PathLike,
path to a csv to log the model iterations and results
gm_params: GMCalibParams
defines the detailed parameters, see `GMCalibParams` documentation for more info
defines the detailed parameters, see `GMCalibParams` documentation for more info
*args,
**kwargs,
Returns
Expand All @@ -533,7 +534,7 @@ def calibrate(
`scipy.optimize.least_squares()`
`caf.distribute.gravity_model.multi_area.GMCalibParams`
"""

self._validate_running_log(running_log_path)
self._initialise_internal_params()

Expand Down Expand Up @@ -608,7 +609,9 @@ def calibrate(
best_convergence = self.achieved_convergence
best_params = result.x

if (not all(self.achieved_convergence) >= gm_params.failure_tol) and gm_params.default_retry:
if (
not all(self.achieved_convergence) >= gm_params.failure_tol
) and gm_params.default_retry:
LOG.info(
"%sachieved a convergence of %s, "
"however the failure tolerance is set to %s. Trying again with "
Expand Down

0 comments on commit bcc97c5

Please sign in to comment.