class SMOTETomekSamplingStrategy(str, Enum):
    """Sampling strategies available for SMOTETomek.

    Each member's value is the literal string handed to ``balance_SMOTETomek``.
    """

    minority = "minority"
    not_minority = "not minority"
    not_majority = "not majority"
    all = "all"
    auto = "auto"


# --- TRAINING DATA TOOLS ---


# BALANCE SMOTETOMEK
@app.command()
def balance_data_cli(
    input_rasters: INPUT_FILES_ARGUMENT,
    input_labels: INPUT_FILE_OPTION,
    output_raster: OUTPUT_FILE_OPTION,
    output_labels: OUTPUT_FILE_OPTION,
    sampling_strategy_literal: Annotated[SMOTETomekSamplingStrategy, typer.Option()] = SMOTETomekSamplingStrategy.auto,
    sampling_strategy_float: Optional[float] = None,
    random_state: Optional[int] = None,
):
    """Resample feature data using SMOTETomek.

    Parameter sampling_strategy_float will override sampling_strategy_literal if given.
    """
    # Imported lazily so that loading the CLI module does not pull in the ML stack.
    from eis_toolkit.prediction.machine_learning_general import prepare_data_for_ml
    from eis_toolkit.training_data_tools.class_balancing import balance_SMOTETomek

    X, y, profile, _ = prepare_data_for_ml(input_rasters, input_labels)
    typer.echo("Progress: 30%")

    # A float ratio, when supplied, takes precedence over the named strategy.
    # The enum member is unwrapped so the library receives a plain string.
    if sampling_strategy_float is not None:
        sampling_strategy = sampling_strategy_float
    else:
        sampling_strategy = sampling_strategy_literal.value

    X_res, y_res = balance_SMOTETomek(X, y, sampling_strategy, random_state)
    typer.echo("Progress: 80%")

    # NOTE(review): the resampled arrays are written as band 1 using the profile of the
    # input rasters. Resampling presumably changes the number of samples, so the arrays
    # may no longer match the profile's dimensions -- confirm the intended output format.
    with rasterio.open(output_raster, "w", **profile) as dst:
        dst.write(X_res, 1)

    with rasterio.open(output_labels, "w", **profile) as dst:
        dst.write(y_res, 1)
    typer.echo("Progress: 100%")
    # Adjacent f-strings are used instead of a backslash continuation, which would
    # embed the source indentation in the printed message.
    typer.echo(
        f"Balancing data completed, writing resampled feature data to {output_raster} "
        f"and corresponding labels to {output_labels}."
    )
+ """ + Balances the classes of input dataset using SMOTETomek resampling method. + + For more information about Imblearn SMOTETomek read the documentation here: + https://imbalanced-learn.org/stable/references/generated/imblearn.combine.SMOTETomek.html. Args: - X: The feature matrix (input data as a DataFrame). - y: The target labels corresponding to the feature matrix. + X: Input feature data to be sampled. + y: Target labels corresponding to the input features. sampling_strategy: Parameter controlling how to perform the resampling. If float, specifies the ratio of samples in minority class to samples of majority class, if str, specifies classes to be resampled ("minority", "not minority", "not majority", "all", "auto"), if dict, the keys should be targeted classes and values the desired number of samples for the class. Defaults to "auto", which will resample all classes except the majority class. - random_state: Parameter controlling randomization of the algorithm. Can be given a seed (number). - Defaults to None, which randomizes the seed. + random_state: Seed for random number generation. Defaults to None. Returns: - Resampled feature matrix and target labels. + Resampled feature data and target labels. Raises: NonMatchingParameterLengthsException: If X and y have different length. 
def test_invalid_sampling_strategy():
    """Check that an unsupported sampling strategy string raises BeartypeCallHintParamViolation."""
    unsupported_strategy = "invalid_strategy"
    with pytest.raises(BeartypeCallHintParamViolation):
        balance_SMOTETomek(X, y, sampling_strategy=unsupported_strategy)