Cross validation built into the framework #20544

Open
svechinsky opened this issue Jan 13, 2025 · 2 comments
Labels
feature Is an improvement or enhancement


@svechinsky

svechinsky commented Jan 13, 2025

Description & Motivation

Cross-validation is standard practice in many cases.
I personally use it to gain more confidence in models trained on borderline amounts of data.

There is currently no built-in way to do this with Lightning or the Lightning CLI.
What I'm currently doing is passing the fold index manually as an argument and letting the datamodule handle the fold creation.
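
A minimal sketch of the split logic such a datamodule might run in `setup()`, written in plain Python so it runs without Lightning or scikit-learn installed. `kfold_indices` is a hypothetical helper, not a Lightning API. It follows the same fold-size rule as scikit-learn's `KFold(shuffle=True)` but does not reproduce its exact permutation. The key assumption is that the seed stays fixed across folds, so every fold partitions the same shuffled ordering.

```python
import random
from typing import List, Tuple


def kfold_indices(
    n_samples: int, num_splits: int, fold: int, seed: int = 12345
) -> Tuple[List[int], List[int]]:
    """Return (train, val) indices for one fold of a shuffled K-fold split.

    Hypothetical helper: the same seed must be used for every fold so that
    the validation sets of all folds are disjoint and cover the whole dataset.
    """
    if not 0 <= fold < num_splits:
        raise ValueError(f"fold must be in [0, {num_splits}), got {fold}")
    if num_splits > n_samples:
        raise ValueError("num_splits cannot exceed n_samples")

    order = list(range(n_samples))
    random.Random(seed).shuffle(order)  # deterministic shuffle for a fixed seed

    # Same sizing rule as sklearn's KFold: the first n_samples % num_splits
    # folds get one extra sample.
    base, extra = divmod(n_samples, num_splits)
    start = fold * base + min(fold, extra)
    size = base + (1 if fold < extra else 0)

    val = sorted(order[start:start + size])
    val_set = set(val)
    train = [i for i in range(n_samples) if i not in val_set]
    return train, val
```

For example, `kfold_indices(10, 3, k)` for `k = 0, 1, 2` yields validation sets of sizes 4, 3 and 3 that together cover all ten indices. Inside a datamodule, the fold index would arrive as an init argument and the resulting index lists would be used to subset the full dataset.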

Pitch

In my opinion, the ideal scenario would be fold-enabled datamodules that generate different train/validation/test sets based on a fold index passed in by Lightning.

This would let the user keep full control over fold selection and splitting while treating the whole cross-validation as a single run.
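
To make the pitch concrete, here is a rough sketch of the division of labour it describes: the framework iterates over fold indices and aggregates scores, while the datamodule alone decides how to split for each index. Nothing like this exists in Lightning today; `FoldAwareDataModule`, `set_fold`, `run_cross_validation` and `train_one_fold` are all hypothetical names chosen for illustration.

```python
from dataclasses import dataclass, field
from statistics import mean, stdev
from typing import Callable, List, Protocol


class FoldAwareDataModule(Protocol):
    """Hypothetical interface: a datamodule that rebuilds its splits for a fold index."""

    num_folds: int

    def set_fold(self, fold: int) -> None:
        """Select the fold; the datamodule alone decides how the data is split."""
        ...


@dataclass
class CrossValidationResult:
    fold_scores: List[float] = field(default_factory=list)

    @property
    def mean_score(self) -> float:
        return mean(self.fold_scores)

    @property
    def std_score(self) -> float:
        # Sample standard deviation across folds; reported as 0.0 for a single fold.
        return stdev(self.fold_scores) if len(self.fold_scores) > 1 else 0.0


def run_cross_validation(
    datamodule: FoldAwareDataModule,
    train_one_fold: Callable[[FoldAwareDataModule], float],
) -> CrossValidationResult:
    """Drive all folds as one logical run and aggregate the per-fold scores.

    `train_one_fold` is a hypothetical callback that builds a fresh model and
    Trainer, fits on the datamodule and returns a validation score.
    """
    result = CrossValidationResult()
    for fold in range(datamodule.num_folds):
        datamodule.set_fold(fold)  # user-controlled splitting for this fold
        result.fold_scores.append(train_one_fold(datamodule))
    return result
```

Because the loop only calls `set_fold`, the user keeps full control over how folds are formed (stratified, grouped, time-based, and so on), while the caller sees a single result with a mean and spread across folds.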

Alternatives

No response

Additional context

No response

cc @lantiga @Borda

@svechinsky svechinsky added feature Is an improvement or enhancement needs triage Waiting to be triaged by maintainers labels Jan 13, 2025
@lantiga lantiga removed the needs triage Waiting to be triaged by maintainers label Jan 13, 2025
@lantiga
Collaborator

lantiga commented Jan 13, 2025

Hey @svechinsky, this is a nice way to do it: https://gist.github.com/ashleve/ac511f08c0d29e74566900fd3efbb3ec

from typing import Optional

import lightning as L
from torch_geometric.datasets import TUDataset
from torch_geometric.data import Dataset
from torch_geometric.loader import DataLoader  # torch_geometric.data.DataLoader is deprecated
from sklearn.model_selection import KFold


class ProteinsKFoldDataModule(L.LightningDataModule):
    def __init__(
            self,
            data_dir: str = "data/",
            k: int = 0,  # fold index, 0-based: must be in [0, num_splits)
            split_seed: int = 12345,  # split needs to be always the same for correct cross validation
            num_splits: int = 10,
            batch_size: int = 32,
            num_workers: int = 0,
            pin_memory: bool = False
        ):
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        self.save_hyperparameters(logger=False)

        # num_splits = 10 means our dataset will be split into 10 parts,
        # so we train on 90% of the data and validate on 10%
        assert 0 <= self.hparams.k < self.hparams.num_splits, "incorrect fold number"

        # data transformations
        self.transforms = None

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None

    @property
    def num_node_features(self) -> int:
        return 4

    @property
    def num_classes(self) -> int:
        return 2

    def setup(self, stage=None):
        # load and split only once, even if setup() is called more than once
        if self.data_train is None and self.data_val is None:
            dataset_full = TUDataset(self.hparams.data_dir, name="PROTEINS", use_node_attr=True, transform=self.transforms)

            # choose fold to train on; the fixed seed gives every fold the same shuffle
            kf = KFold(n_splits=self.hparams.num_splits, shuffle=True, random_state=self.hparams.split_seed)
            all_splits = list(kf.split(dataset_full))
            train_indexes, val_indexes = all_splits[self.hparams.k]
            train_indexes, val_indexes = train_indexes.tolist(), val_indexes.tolist()

            self.data_train, self.data_val = dataset_full[train_indexes], dataset_full[val_indexes]

    def train_dataloader(self):
        return DataLoader(dataset=self.data_train, batch_size=self.hparams.batch_size,
                          num_workers=self.hparams.num_workers, pin_memory=self.hparams.pin_memory, shuffle=True)

    def val_dataloader(self):
        return DataLoader(dataset=self.data_val, batch_size=self.hparams.batch_size,
                          num_workers=self.hparams.num_workers, pin_memory=self.hparams.pin_memory)

results = []
num_folds = 10
split_seed = 12345

for k in range(num_folds):
    # pass data_dir, batch_size, etc. here as needed
    datamodule = ProteinsKFoldDataModule(k=k, num_splits=num_folds, split_seed=split_seed)
    datamodule.prepare_data()
    datamodule.setup()

    # here we train a fresh model on the given split...
    model = ...
    ...
    trainer = L.Trainer(...)
    trainer.fit(model, datamodule)

    # assumes the model logs a validation metric named "val_acc"
    score = trainer.callback_metrics["val_acc"].item()
    results.append(score)

score = sum(results) / num_folds

Typically you'd instantiate a separate model and trainer for each fold, which is not too bad. We could add a version of this example to the docs.

Here's an example we already have with Fabric: https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples/fabric/kfold_cv

What do you think?

@svechinsky
Author

That's exactly what I ended up doing!
Since this is a common pattern, I think it would be great to have it as part of the framework,
especially since doing it this way precludes using the Lightning CLI.
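
One partial workaround in the meantime, sketched under assumptions: if a hypothetical `main.py` calls `LightningCLI` with a datamodule that takes the fold index as an init argument (like `k` and `num_splits` in the example above), the CLI can still be launched once per fold via `--data.k`. This yields K separate runs with no built-in aggregation, which is exactly the limitation raised here. `fold_commands` and `run_all_folds` are illustrative names only.

```python
import subprocess
import sys
from typing import List, Sequence


def fold_commands(script: str, num_folds: int, extra_args: Sequence[str] = ()) -> List[List[str]]:
    """Build one `fit` command per fold for a LightningCLI entry-point script."""
    commands = []
    for k in range(num_folds):
        commands.append([
            sys.executable, script, "fit",
            f"--data.k={k}",                      # fold index, an init arg of the datamodule
            f"--data.num_splits={num_folds}",
            f"--trainer.default_root_dir=logs/fold_{k}",  # keep each fold's logs and checkpoints apart
            *extra_args,
        ])
    return commands


def run_all_folds(script: str, num_folds: int, extra_args: Sequence[str] = ()) -> None:
    """Run the folds sequentially; stop at the first failing fold."""
    for cmd in fold_commands(script, num_folds, extra_args):
        subprocess.run(cmd, check=True)
```

Per-fold scores would then have to be collected from each run's logs and averaged by hand, which is why built-in support for treating the folds as one run would still be useful.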
