Skip to content

Commit

Permalink
Merge pull request #46 from Transport-for-the-North/type-hint-fix
Browse files Browse the repository at this point in the history
Type hint fix
  • Loading branch information
Kieran-Fishwick-TfN authored Oct 31, 2024
2 parents 0821db9 + d7c1d36 commit b53ca8a
Show file tree
Hide file tree
Showing 17 changed files with 695 additions and 246 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,6 @@ dmypy.json

# Pyre type checker
.pyre/

# VS code settings
.vscode/
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ lint = [
"mypy>=1.0.0",
"mypy_extensions>=1.0.0",
"pydocstyle[toml]>=6.1.1",
"pylint>=2.14.5",
"pylint>=3.2",
]

test = [
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ matplotlib>=3.8.2
# Ipf requirements
sparse>=0.13.0
numba>=0.60.0

2 changes: 1 addition & 1 deletion requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ isort>=5.12.0
mypy>=1.0.0
mypy_extensions>=1.0.0
pydocstyle[toml]>=6.1.1
pylint>=2.14.5
pylint==3.2

# Testing
pytest>=7.4.0
Expand Down
4 changes: 2 additions & 2 deletions src/caf/distribute/cost_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def get_cost_function(self) -> CostFunction:
if self == BuiltInCostFunction.TANNER:
return CostFunction(
name=self.name,
params={"alpha": (-5, 5), "beta": (-5, 5)},
default_params={"alpha": 1, "beta": 1},
params={"alpha": (-1, 1), "beta": (-1, 1)},
default_params={"alpha": 0.1, "beta": -0.1},
function=tanner,
)

Expand Down
5 changes: 4 additions & 1 deletion src/caf/distribute/gravity_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,7 @@

# Models
from caf.distribute.gravity_model.single_area import SingleAreaGravityModelCalibrator
from caf.distribute.gravity_model.multi_area import MultiAreaGravityModelCalibrator
from caf.distribute.gravity_model.multi_area import (
MultiAreaGravityModelCalibrator,
GMCalibParams,
)
83 changes: 74 additions & 9 deletions src/caf/distribute/gravity_model/core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
"""Core abstract functionality for gravity model classes to build on."""
from __future__ import annotations

# Built-Ins
import abc
import dataclasses
Expand Down Expand Up @@ -90,41 +92,91 @@ class GravityModelCalibrateResults(GravityModelResults):
# Targets
target_cost_distribution: cost_utils.CostDistribution
cost_function: cost_functions.CostFunction
cost_params: dict[str, Any]
cost_params: dict[str | int, Any]

def plot_distributions(self) -> figure.Figure:
"""
Plot a comparison of the achieved and target distributions.
def plot_distributions(self, truncate_last_bin: bool = False) -> figure.Figure:
"""Plot a comparison of the achieved and target distributions.
This method returns a matplotlib figure which can be saved or plotted
as the user decides.
Parameters
----------
truncate_last_bin : bool, optional
whether to truncate the graph to 1.2x the lower bin edge, by default False
Returns
-------
figure.Figure
the plotted distributions
Raises
------
ValueError
when the target and achieved distributions have different binning
"""

fig, ax = plt.subplots(figsize=(10, 6))

errors = []
for attr in ("max_vals", "min_vals", "avg_vals"):
if set(getattr(self.cost_distribution, attr)) != set(
getattr(self.target_cost_distribution, attr)
):
errors.append(attr)

if len(errors) > 0:
raise ValueError(
"To plot distributions, the target and achieved distributions"
" must have the same binning. The distributions have different "
+ " and ".join(errors)
)

max_bin_edge = self.cost_distribution.max_vals
min_bin_edge = self.cost_distribution.min_vals
bin_centres = self.cost_distribution.avg_vals

ax.bar(
self.cost_distribution.avg_vals,
bin_centres,
self.cost_distribution.band_share_vals,
width=self.cost_distribution.max_vals - self.cost_distribution.min_vals,
width=max_bin_edge - min_bin_edge,
label="Achieved Distribution",
color="blue",
alpha=0.7,
)
ax.bar(
self.cost_distribution.avg_vals,
bin_centres,
self.target_cost_distribution.band_share_vals,
width=self.target_cost_distribution.max_vals
- self.target_cost_distribution.min_vals,
width=max_bin_edge - min_bin_edge,
label="Target Distribution",
color="orange",
alpha=0.7,
)

if truncate_last_bin:
top_min_bin = min_bin_edge.max()
ax.set_xlim(0, top_min_bin[-1] * 1.2)
fig.text(0.8, 0.025, f"final bin edge cut from {max_bin_edge.max()}", ha="center")

ax.set_xlabel("Cost")
ax.set_ylabel("Trips")
ax.set_title("Distribution Comparison")
ax.legend()

return fig

@property
def summary(self) -> pd.Series:
"""Summary of the GM calibration parameters as a series.
Outputs the gravity model achieved parameters and the convergence.
Returns
-------
pd.DataFrame
a summary of the calibration
"""


@dataclasses.dataclass
class GravityModelRunResults(GravityModelResults):
Expand Down Expand Up @@ -162,6 +214,19 @@ class GravityModelRunResults(GravityModelResults):
cost_function: Optional[cost_functions.CostFunction] = None
cost_params: Optional[dict[str, Any]] = None

@property
def summary(self) -> pd.Series:
"""Summary of the GM run parameters as a series.
Outputs the gravity model parameters used to generate the distribution.
Returns
-------
pd.DataFrame
a summary of the run
"""
return pd.Series(self.cost_params)


class GravityModelBase(abc.ABC):
"""Base Class for gravity models.
Expand Down
Loading

0 comments on commit b53ca8a

Please sign in to comment.