Skip to content

Commit

Permalink
removed GravityModelResults sub classes since run and calibrate have …
Browse files Browse the repository at this point in the history
…the same parameters
  • Loading branch information
Kieran-Fishwick-TfN committed Jan 8, 2025
1 parent 64894c1 commit e062bb3
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 175 deletions.
2 changes: 0 additions & 2 deletions src/caf/distribute/gravity_model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# -*- coding: utf-8 -*-
# Dataclasses
from caf.distribute.gravity_model.core import GravityModelResults
from caf.distribute.gravity_model.core import GravityModelRunResults
from caf.distribute.gravity_model.core import GravityModelCalibrateResults
from caf.distribute.gravity_model.multi_area import (
MultiCostDistribution,
MultiDistInput,
Expand Down
150 changes: 6 additions & 144 deletions src/caf/distribute/gravity_model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,14 @@


# # # CLASSES # # #
class GravityModelResults(abc.ABC):
"""A collection of results from a run of the Gravity Model.
target_cost_distribution:
The cost distribution the gravity model was aiming for during its run.
cost_function:
The cost function used in the gravity model run.
cost_params:
The cost parameters used with the cost_function to achieve the results.
"""
@dataclasses.dataclass
class GravityModelResults:
"""A collection of results from the Gravity Model."""

cost_distribution: cost_utils.CostDistribution
"""The achieved cost distribution of the run."""
"""The achieved cost distribution of the results."""
target_cost_distribution: cost_utils.CostDistribution
"""The taregt cost distribution used to obtain the results."""
cost_convergence: float
"""The achieved cost convergence value of the run. If
`target_cost_distribution` is not set, then this should be 0.
Expand All @@ -56,59 +49,6 @@ class GravityModelResults(abc.ABC):
cost_params: dict[str | int, Any]
"""The final/used cost parameters used by the cost function."""

def __init__(
self,
cost_distribution: cost_utils.CostDistribution,
cost_convergence: float,
value_distribution: np.ndarray,
cost_function: cost_functions.CostFunction,
cost_params: dict[str | int, Any],
) -> None:

self.cost_distribution = cost_distribution
self.cost_convergence = cost_convergence
self.value_distribution = value_distribution
self.cost_function = cost_function
self.cost_params = cost_params

@abc.abstractmethod
def plot_distributions(self, truncate_last_bin: bool = False) -> figure.Figure:
"""Plot the distributions associated with the results.
Parameters
----------
truncate_last_bin : bool, optional
whether to truncate the graph to 1.2x the lower bin edge, by default False
"""

@property
@abc.abstractmethod
def summary(self) -> pd.Series:
"""Summary of the results parameters as a series."""


class GravityModelCalibrateResults(GravityModelResults):
"""A collection of results from a calibration of the Gravity Model."""

# Targets
target_cost_distribution: cost_utils.CostDistribution
"""The cost distribution the gravity model was aiming for during its run."""

def __init__(
self,
cost_distribution: cost_utils.CostDistribution,
cost_convergence: float,
value_distribution: np.ndarray,
target_cost_distribution: cost_utils.CostDistribution,
cost_function: cost_functions.CostFunction,
cost_params: dict[str | int, Any],
) -> None:

super().__init__(
cost_distribution, cost_convergence, value_distribution, cost_function, cost_params
)
self.target_cost_distribution = target_cost_distribution

def plot_distributions(self, truncate_last_bin: bool = False) -> figure.Figure:
"""Plot a comparison of the achieved and target distributions.
Expand Down Expand Up @@ -197,84 +137,6 @@ def summary(self) -> pd.Series:
return pd.Series(output_params)


class GravityModelRunResults(GravityModelResults):
"""A collection of results from a run of the Gravity Model."""

def __init__(
self,
cost_distribution: cost_utils.CostDistribution,
cost_convergence: float,
value_distribution: np.ndarray,
cost_function: cost_functions.CostFunction,
cost_params: dict[int | str, Any],
) -> None:
super().__init__(
cost_distribution, cost_convergence, value_distribution, cost_function, cost_params
)

@property
def summary(self) -> pd.Series:
"""Summary of the GM run parameters as a series.
Outputs the gravity model parameters used to generate the distribution.
Returns
-------
pd.DataFrame
a summary of the run
"""
return pd.Series(self.cost_params)

def plot_distributions(self, truncate_last_bin: bool = False) -> figure.Figure:
"""Plot a comparison of the achieved and target distributions.
This method returns a matplotlib figure which can be saved or plotted
as the user decides.
Parameters
----------
truncate_last_bin : bool, optional
whether to truncate the graph to 1.2x the lower bin edge, by default False
Returns
-------
figure.Figure
the plotted distributions
Raises
------
ValueError
when the target and achieved distributions have different binning
"""

fig, ax = plt.subplots(figsize=(10, 6))

max_bin_edge = self.cost_distribution.max_vals
min_bin_edge = self.cost_distribution.min_vals
bin_centres = (max_bin_edge + min_bin_edge) / 2

ax.bar(
bin_centres,
self.cost_distribution.band_share_vals,
width=max_bin_edge - min_bin_edge,
label="Achieved Distribution",
color="blue",
alpha=0.7,
)

if truncate_last_bin:
top_min_bin = min_bin_edge.max()
ax.set_xlim(0, top_min_bin[-1] * 1.2)
fig.text(0.8, 0.025, f"final bin edge cut from {max_bin_edge.max()}", ha="center")

ax.set_xlabel("Cost")
ax.set_ylabel("Trips")
ax.set_title("Distribution Achieved")
ax.legend()

return fig


class GravityModelBase(abc.ABC):
"""Base Class for gravity models.
Expand Down
19 changes: 10 additions & 9 deletions src/caf/distribute/gravity_model/multi_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@
# Local Imports
from caf.distribute import cost_functions, furness
from caf.distribute.gravity_model import core
from caf.distribute.gravity_model.core import (
GravityModelCalibrateResults,
GravityModelRunResults,
)
from caf.distribute.gravity_model.core import GravityModelResults

# # # CONSTANTS # # #
LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -585,7 +582,7 @@ def calibrate(
gm_params: GMCalibParams,
verbose: int = 0,
**kwargs,
) -> dict[str | int, GravityModelCalibrateResults]:
) -> dict[str | int, GravityModelResults]:
"""Find the optimal parameters for self.cost_function.
Optimal parameters are found using `scipy.optimize.least_squares`
Expand All @@ -608,7 +605,7 @@ def calibrate(
Returns
-------
dict[str | int, GravityModelCalibrateResults]:
dict[str | int, GravityModelResults]:
containings the achieved distributions for each tld category. To access
the combined distribution use self.achieved_distribution
Expand Down Expand Up @@ -723,7 +720,7 @@ def calibrate(
assert self.achieved_cost_dist is not None
results = {}
for i, dist in enumerate(distributions):
result_i = GravityModelCalibrateResults(
result_i = GravityModelResults(
cost_distribution=self.achieved_cost_dist[i],
cost_convergence=self.achieved_convergence[dist.name],
value_distribution=self.achieved_distribution[dist.zones],
Expand Down Expand Up @@ -877,7 +874,7 @@ def run(
distributions: MultiCostDistribution,
running_log_path: Path,
furness_tol: float = 1e-6,
) -> dict[int | str, GravityModelCalibrateResults]:
) -> dict[int | str, GravityModelResults]:
"""
Run the gravity_model without calibrating.
Expand All @@ -894,6 +891,10 @@ def run(
tolerance for difference in target and achieved value,
at which to stop furnessing, by default 1e-6
Returns
-------
dict[int | str, GravityModelResults]
The results of the gravity model run for each distribution
"""
params_len = len(distributions[0].function_params)
cost_args = []
Expand All @@ -912,7 +913,7 @@ def run(
assert self.achieved_cost_dist is not None
results = {}
for i, dist in enumerate(distributions):
result_i = GravityModelRunResults(
result_i = GravityModelResults(
cost_distribution=self.achieved_cost_dist[i],
cost_convergence=self.achieved_convergence[dist.name],
value_distribution=self.achieved_distribution[dist.zones],
Expand Down
22 changes: 9 additions & 13 deletions src/caf/distribute/gravity_model/single_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,7 @@

# Local Imports
from caf.distribute import cost_functions, furness
from caf.distribute.gravity_model import (
GravityModelCalibrateResults,
GravityModelRunResults,
core,
)
from caf.distribute.gravity_model import GravityModelResults, core

# # # CONSTANTS # # #
LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -222,7 +218,7 @@ def _calibrate(
default_retry: bool = True,
verbose: int = 0,
**kwargs,
) -> GravityModelCalibrateResults:
) -> GravityModelResults:
"""Find the optimal parameters for self.cost_function.
Optimal parameters are found using `scipy.optimize.least_squares`
Expand Down Expand Up @@ -400,7 +396,7 @@ def _calibrate(

# Populate internal arguments with optimal run results.
assert self.achieved_cost_dist is not None
return GravityModelCalibrateResults(
return GravityModelResults(
cost_distribution=self.achieved_cost_dist,
cost_convergence=self.achieved_convergence,
value_distribution=self.achieved_distribution,
Expand All @@ -415,7 +411,7 @@ def calibrate(
running_log_path: os.PathLike,
*args,
**kwargs,
) -> GravityModelCalibrateResults:
) -> GravityModelResults:
"""Find the optimal parameters for self.cost_function.
Optimal parameters are found using `scipy.optimize.least_squares`
Expand Down Expand Up @@ -519,7 +515,7 @@ def calibrate_with_perceived_factors(
*args,
failure_tol: float = 0.5,
**kwargs,
) -> GravityModelCalibrateResults:
) -> GravityModelResults:
"""Find the optimal parameters for self.cost_function.
Optimal parameters are found using `scipy.optimize.least_squares`
Expand Down Expand Up @@ -650,7 +646,7 @@ def run(
running_log_path: os.PathLike,
target_cost_distribution: Optional[cost_utils.CostDistribution] = None,
**kwargs,
) -> GravityModelRunResults:
) -> GravityModelResults:
"""Run the gravity model with set cost parameters.
This function will run a single iteration of the gravity model using
Expand Down Expand Up @@ -700,7 +696,7 @@ def run(
)

assert self.achieved_cost_dist is not None
return GravityModelRunResults(
return GravityModelResults(
cost_distribution=self.achieved_cost_dist,
cost_convergence=self.achieved_convergence,
value_distribution=self.achieved_distribution,
Expand All @@ -716,7 +712,7 @@ def run_with_perceived_factors(
target_cost_distribution: cost_utils.CostDistribution,
target_cost_convergence: float = 0.9,
**kwargs,
) -> GravityModelRunResults:
) -> GravityModelResults:
"""Run the gravity model with set cost parameters.
This function will run a single iteration of the gravity model using
Expand Down Expand Up @@ -794,7 +790,7 @@ def run_with_perceived_factors(
)

assert self.achieved_cost_dist is not None
return GravityModelRunResults(
return GravityModelResults(
cost_distribution=self.achieved_cost_dist,
cost_convergence=self.achieved_convergence,
value_distribution=self.achieved_distribution,
Expand Down
10 changes: 3 additions & 7 deletions tests/gravity_model/test_single_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
# Local Imports
from caf.distribute import cost_functions
from caf.distribute.gravity_model import (
GravityModelCalibrateResults,
GravityModelResults,
GravityModelRunResults,
SingleAreaGravityModelCalibrator,
)

Expand Down Expand Up @@ -104,7 +102,7 @@ def create_and_run_gravity_model(
furness_max_iters: int = 1000,
furness_tol: float = 1e-3,
use_perceived_factors: bool = False,
) -> GravityModelRunResults:
) -> GravityModelResults:
gm = SingleAreaGravityModelCalibrator(
row_targets=self.row_targets,
col_targets=self.col_targets,
Expand Down Expand Up @@ -142,7 +140,7 @@ def create_and_calibrate_gravity_model(
furness_max_iters: int = 1000,
furness_tol: float = 1e-3,
use_perceived_factors: bool = False,
) -> GravityModelCalibrateResults:
) -> GravityModelResults:
gm = SingleAreaGravityModelCalibrator(
row_targets=self.row_targets,
col_targets=self.col_targets,
Expand Down Expand Up @@ -246,9 +244,7 @@ def get_optimal_params(self) -> dict[str, Any]:
"""Get the optimal parameters from disk"""
return self.best_params

def assert_results(
self, gm_results: GravityModelRunResults | GravityModelCalibrateResults
) -> None:
def assert_results(self, gm_results: GravityModelResults) -> None:
"""Assert that all the results are as expected"""
# Check the scalar values
for key, val in self.best_params.items():
Expand Down

0 comments on commit e062bb3

Please sign in to comment.