diff --git a/tests/gravity_model/test_multi_area.py b/tests/gravity_model/test_multi_area.py index 4e4b1a7..274f7da 100644 --- a/tests/gravity_model/test_multi_area.py +++ b/tests/gravity_model/test_multi_area.py @@ -16,6 +16,7 @@ from caf.distribute import utils from caf.distribute.gravity_model import GravityModelResults + @pytest.fixture(name="cost_from_code", scope="session") def fixture_code_costs(): np.random.seed(42) @@ -177,7 +178,7 @@ def _multi_tld(data_dir, mock_dir): @pytest.fixture(name="cal_no_furness", scope="session") -def fixture_cal_no_furness(data_dir, infilled, multi_tld, trip_ends, mock_dir): +def fixture_cal_no_furness(infilled, multi_tld, trip_ends, mock_dir): row_targets = trip_ends["origin"].values col_targets = trip_ends["destination"].values model = gm.MultiAreaGravityModelCalibrator( @@ -187,13 +188,15 @@ def fixture_cal_no_furness(data_dir, infilled, multi_tld, trip_ends, mock_dir): cost_function=cost_functions.BuiltInCostFunction.LOG_NORMAL.get_cost_function(), ) results = model.calibrate( - multi_tld, running_log_path=mock_dir / "temp_log.csv", gm_params=gm.GMCalibParams(furness_jac=False) + multi_tld, + running_log_path=mock_dir / "temp_log.csv", + gm_params=gm.GMCalibParams(furness_jac=False), ) return results @pytest.fixture(name="cal_furness", scope="session") -def fixture_cal_furness(self, data_dir, infilled, multi_tld, trip_ends, mock_dir): +def fixture_cal_furness(infilled, multi_tld, trip_ends, mock_dir): row_targets = trip_ends["origin"].values col_targets = trip_ends["destination"].values model = gm.MultiAreaGravityModelCalibrator( @@ -209,6 +212,7 @@ def fixture_cal_furness(self, data_dir, infilled, multi_tld, trip_ends, mock_dir ) return results + class TestUtils: # TODO(IS) only one test currently so leaving in this file def test_infill_costs(self, infilled_from_code, infilled_expected): @@ -236,13 +240,14 @@ def test_params(self, cal_results, area, request): assert 0 < sigma < 3 assert 0 < mu < 3 + class TestResults: - @pytest.mark.parametrize("results", ["cal_furness", "cal_no_furness"]) + @pytest.mark.parametrize("cal_results", ["cal_furness", "cal_no_furness"]) def test_results(self, cal_results, request): - """Test the the GravityModelResults object methods run as expected""" + """Test the the GravityModelResults object methods run as expected""" cal_results = request.getfixturevalue(cal_results) assert isinstance(cal_results, dict) for result in cal_results.values(): assert isinstance(result, GravityModelResults) assert isinstance(result.summary(), pd.Series) - assert isinstance(result.plot_distributions(), plt.Figure) \ No newline at end of file + assert isinstance(result.plot_distributions(), plt.Figure)