diff --git a/src/caf/distribute/gravity_model/multi_area.py b/src/caf/distribute/gravity_model/multi_area.py index 1e1fbe9..26bad53 100644 --- a/src/caf/distribute/gravity_model/multi_area.py +++ b/src/caf/distribute/gravity_model/multi_area.py @@ -96,7 +96,7 @@ class MultiDistInput(BaseConfig): @dataclass class GMCalibParams: - """Parameters required for the multi tld gravity mode calibrate method + """Parameters required for the multi tld gravity mode calibrate method. All of the arguements have defaults, i.e. you can create the default object with no arguements. HOWEVER, read the parameter section below, it is important to @@ -166,7 +166,7 @@ class GMCalibParams: @dataclass class MultiCostDistribution: - """Cost distributions to be used for the multi-cost distribution gravity model + """Cost distributions to be used for the multi-cost distribution gravity model. Parameters ---------- @@ -192,7 +192,7 @@ def from_pandas( lookup_cat_col: str = "category", lookup_zone_col: str = "zone_id", ) -> MultiCostDistribution: - """constructor using pandas dataframes + """Build class using pandas dataframes. Parameters ---------- @@ -266,7 +266,10 @@ def from_pandas( @classmethod def validate(cls, distributions: list[MGMCostDistribution]): - """Validates the distributions passed + """Checks the distributions passed. + + Raises an error if duplicate zones are found across different + distributions. Parameters ---------- @@ -298,17 +301,43 @@ def validate(cls, distributions: list[MGMCostDistribution]): raise ValueError("duplicate found in the distribution zone definition") def __iter__(self) -> Iterator[MGMCostDistribution]: + """Iterates through each distribution. + + Yields + ------ + Iterator[MGMCostDistribution] + iterator for the cost distributions. + """ yield from self.distributions - def __getitem__(self, x) -> MGMCostDistribution: + def __getitem__(self, x: int) -> MGMCostDistribution: + """Retrieves the xth distribution. + + Parameters + ---------- + x : int + index of the distribution to retreive + + Returns + ------- + MGMCostDistribution + the xth distrubtion. + """ return self.distributions[x] def __len__(self) -> int: - return len(self.distributions) + """The number of distrubtions. + Returns + ------- + int + The number of distrubtions. + """ + return len(self.distributions) def copy(self) -> MultiCostDistribution: - """ + """A wrapper around deepcopy. + Returns ------- MultiCostDistribution @@ -369,7 +398,7 @@ def from_pandas( lookup_cat_col: str = "category", lookup_zone_col: str = "zone_id", ) -> MGMCostDistribution: - """constructor that uses pandas dataframes and series + """Build using pandas dataframes and series. Parameters ---------- @@ -574,6 +603,7 @@ def calibrate( defines the detailed parameters, see `GMCalibParams` documentation for more info *args, **kwargs, + Returns ------- dict[str | int, GravityModelCalibrateResults]: @@ -832,7 +862,7 @@ def _gravity_function( self._loop_start_time = timing.current_milli_time() self.achieved_cost_dist: list[cost_utils.CostDistribution] = distributions - self.achieved_convergence: dict[str|int, float] = convergences + self.achieved_convergence: dict[str | int, float] = convergences self.achieved_distribution = matrix achieved_residuals = np.concatenate(residuals)