Skip to content

Commit

Permalink
bug fixes for test
Browse files Browse the repository at this point in the history
  • Loading branch information
Kieran-Fishwick-TfN committed Jan 8, 2025
1 parent f3ed371 commit 2d34008
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions tests/gravity_model/test_multi_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from caf.distribute import utils
from caf.distribute.gravity_model import GravityModelResults


@pytest.fixture(name="cost_from_code", scope="session")
def fixture_code_costs():
np.random.seed(42)
Expand Down Expand Up @@ -177,7 +178,7 @@ def _multi_tld(data_dir, mock_dir):


@pytest.fixture(name="cal_no_furness", scope="session")
def fixture_cal_no_furness(data_dir, infilled, multi_tld, trip_ends, mock_dir):
def fixture_cal_no_furness(infilled, multi_tld, trip_ends, mock_dir):
row_targets = trip_ends["origin"].values
col_targets = trip_ends["destination"].values
model = gm.MultiAreaGravityModelCalibrator(
Expand All @@ -187,13 +188,15 @@ def fixture_cal_no_furness(data_dir, infilled, multi_tld, trip_ends, mock_dir):
cost_function=cost_functions.BuiltInCostFunction.LOG_NORMAL.get_cost_function(),
)
results = model.calibrate(
multi_tld, running_log_path=mock_dir / "temp_log.csv", gm_params=gm.GMCalibParams(furness_jac=False)
multi_tld,
running_log_path=mock_dir / "temp_log.csv",
gm_params=gm.GMCalibParams(furness_jac=False),
)
return results


@pytest.fixture(name="cal_furness", scope="session")
def fixture_cal_furness(self, data_dir, infilled, multi_tld, trip_ends, mock_dir):
def fixture_cal_furness(infilled, multi_tld, trip_ends, mock_dir):
row_targets = trip_ends["origin"].values
col_targets = trip_ends["destination"].values
model = gm.MultiAreaGravityModelCalibrator(
Expand All @@ -209,6 +212,7 @@ def fixture_cal_furness(self, data_dir, infilled, multi_tld, trip_ends, mock_dir
)
return results


class TestUtils:
# TODO(IS) only one test currently so leaving in this file
def test_infill_costs(self, infilled_from_code, infilled_expected):
Expand Down Expand Up @@ -236,13 +240,14 @@ def test_params(self, cal_results, area, request):
assert 0 < sigma < 3
assert 0 < mu < 3


class TestResults:
@pytest.mark.parametrize("results", ["cal_furness", "cal_no_furness"])
@pytest.mark.parametrize("cal_results", ["cal_furness", "cal_no_furness"])
def test_results(self, cal_results, request):
"""Test the the GravityModelResults object methods run as expected"""
"""Test the the GravityModelResults object methods run as expected"""
cal_results = request.getfixturevalue(cal_results)
assert isinstance(cal_results, dict)
for result in cal_results.values():
assert isinstance(result, GravityModelResults)
assert isinstance(result.summary(), pd.Series)
assert isinstance(result.plot_distributions(), plt.Figure)
assert isinstance(result.plot_distributions(), plt.Figure)

0 comments on commit 2d34008

Please sign in to comment.