Skip to content

Commit

Permalink
Tests passing
Browse files Browse the repository at this point in the history
  • Loading branch information
isaac-tfn committed Jan 11, 2024
1 parent dfa04f6 commit 88658b0
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 13 deletions.
6 changes: 3 additions & 3 deletions src/caf/distribute/furness.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def doubly_constrained_furness(


@dataclass
class props_input:
class PropsInput:
"""
props: np.ndarray
This is essentially a cost matrix, but costs are replaced by the percentage
Expand Down Expand Up @@ -205,7 +205,7 @@ def cost_to_prop(costs: np.ndarray, bands: pd.DataFrame, val_col: str):


def triply_constrained_furness(
props: list[props_input],
props: list[PropsInput],
row_targets,
col_targets,
max_iters,
Expand All @@ -224,7 +224,7 @@ def triply_constrained_furness(
Parameters
----------
props: list[props_input]
props: list[PropsInput]
A list of info about cost bins. This is produced by cost_to_props
row_targets: np.ndarray
The targets for the rows (origins) in the matrix
Expand Down
14 changes: 11 additions & 3 deletions src/caf/distribute/gravity_model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
import pandas as pd
from matplotlib import pyplot as plt, figure
from scipy import optimize
from caf.toolkit import cost_utils, io, timing, BaseConfig
import seaborn as sns
from caf.toolkit import cost_utils, io, timing, BaseConfig


# Local Imports
from caf.distribute import cost_functions
Expand Down Expand Up @@ -60,14 +61,17 @@ class GravityModelResults:

@property
def achieved_rows(self):
"""Return the achieved row totals."""
return self.value_distribution.sum(axis=1)

@property
def achieved_cols(self):
"""Return the achieved column totals."""
return self.value_distribution.sum(axis=0)

@property
def matrix_total(self):
"""Return the total trips in the matrix."""
return self.value_distribution.sum()


Expand Down Expand Up @@ -105,13 +109,15 @@ class GravityModelCalibrateResults(GravityModelResults):
cost_function: cost_functions.CostFunction
cost_params: dict[str, Any]

class output_yaml(BaseConfig):
class OutputYaml(BaseConfig):
"""Class for outputting some data from this class."""
cost_params: dict[str, Any]
cost_function: str
matrix_total: float
cost_convergence: float

def save(self, out_dir: Path):
"""Save method for class"""
out_dir.mkdir(parents=False, exist_ok=True)
achieved = self.cost_distribution.df.copy()
achieved["achieved_normalised_demand"] = (
Expand All @@ -126,7 +132,7 @@ def save(self, out_dir: Path):
target.index = achieved.index
dists_out = achieved.join(target["target_normalised_demand"])
dists_out.to_csv(out_dir / "dist_comparison.csv", index=False)
yaml_output = self.output_yaml(
yaml_output = self.OutputYaml(
cost_params=self.cost_params,
cost_function=self.cost_function.name,
matrix_total=self.matrix_total,
Expand All @@ -136,6 +142,7 @@ def save(self, out_dir: Path):

@property
def target_mean_trip_length(self):
"""Return the mean trip length of the target distribution."""
temp = self.target_cost_distribution.df.copy()
temp["weighted"] = (
temp[self.target_cost_distribution.avg_col]
Expand All @@ -145,6 +152,7 @@ def target_mean_trip_length(self):

@property
def achieved_mean_trip_length(self):
"""Return the mean trip length of the achieved distribution."""
temp = self.cost_distribution.df.copy()
temp["weighted"] = (
temp[self.cost_distribution.avg_col] * temp[self.cost_distribution.trips_col]
Expand Down
7 changes: 4 additions & 3 deletions src/caf/distribute/gravity_model/multi_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class MultiDistInput(BaseConfig):

@property
def log_path(self):
"""Create path to a log file from out_path."""
return self.out_path / "log.csv"


Expand Down Expand Up @@ -344,8 +345,8 @@ def _calibrate(

def calibrate(
self,
update_params: bool = False,
*args,
update_params: bool = False,
**kwargs,
) -> GravityModelCalibrateResults:
"""Find the optimal parameters for self.cost_function.
Expand Down Expand Up @@ -637,7 +638,7 @@ def run(self, triply_constrain: bool = False, xamax: int = 2):
],
val_col="normalised",
)
props = furness.props_input(prop_cost, dist.zones, band_vals)
props = furness.PropsInput(prop_cost, dist.zones, band_vals)
props_list.append(props)
# triply contrained furness on seed matrix
# tol is higher as it is more difficult to converge when triply contrained
Expand All @@ -654,7 +655,7 @@ def run(self, triply_constrain: bool = False, xamax: int = 2):
for i, dist in enumerate(self.dists):
(
single_cost_distribution,
single_achieved_residuals,
_,
single_convergence,
) = core.cost_distribution_stats(
achieved_trip_distribution=new_mat[dist.zones],
Expand Down
8 changes: 4 additions & 4 deletions tests/gravity_model/test_multi_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def fixture_conf(data_dir, mock_dir):
lookup_cat_col="cat",
lookup_zone_col="zone",
init_params={"mu": 1, "sigma": 2},
log_path=mock_dir / "log.csv",
out_path=mock_dir,
furness_tolerance=0.1,
furness_jac=False,
)
Expand All @@ -140,7 +140,7 @@ def fixture_jac_furn(data_dir, mock_dir):
lookup_cat_col="cat",
lookup_zone_col="zone",
init_params={"mu": 1, "sigma": 2},
log_path=mock_dir / "log.csv",
out_path=mock_dir,
furness_tolerance=0.1,
furness_jac=True,
)
Expand All @@ -158,7 +158,7 @@ def fixture_cal_no_furness(data_dir, infilled, no_furness_jac_conf, trip_ends, m
cost_function=cost_functions.BuiltInCostFunction.LOG_NORMAL.get_cost_function(),
params=no_furness_jac_conf,
)
results = model.calibrate(running_log_path=mock_dir / "temp_log.csv")
results = model.calibrate()
return results


Expand All @@ -173,7 +173,7 @@ def fixture_cal_furness(data_dir, infilled, furness_jac_conf, trip_ends, mock_di
cost_function=cost_functions.BuiltInCostFunction.LOG_NORMAL.get_cost_function(),
params=furness_jac_conf,
)
results = model.calibrate(running_log_path=mock_dir / "temp_log.csv")
results = model.calibrate()
return results


Expand Down

0 comments on commit 88658b0

Please sign in to comment.