diff --git a/rul_datasets/adaption.py b/rul_datasets/adaption.py index 9425ddf..1751fe4 100644 --- a/rul_datasets/adaption.py +++ b/rul_datasets/adaption.py @@ -31,6 +31,8 @@ class DomainAdaptionDataModule(pl.LightningDataModule): >>> source = rul_datasets.RulDataModule(fd1, 32) >>> target = rul_datasets.RulDataModule(fd2, 32) >>> dm = rul_datasets.DomainAdaptionDataModule(source, target) + >>> dm.prepare_data() + >>> dm.setup() >>> train_1_2 = dm.train_dataloader() >>> val_1, val_2 = dm.val_dataloader() >>> test_1, test_2 = dm.test_dataloader() @@ -231,9 +233,11 @@ class LatentAlignDataModule(DomainAdaptionDataModule): >>> import rul_datasets >>> fd1 = rul_datasets.CmapssReader(fd=1, window_size=20) >>> fd2 = rul_datasets.CmapssReader(fd=2, percent_broken=0.8) - >>> source = rul_datasets.RulDataModule(fd1, 32) - >>> target = rul_datasets.RulDataModule(fd2, 32) - >>> dm = rul_datasets.LatentAlignDataModule(source, target) + >>> src = rul_datasets.RulDataModule(fd1, 32) + >>> trg = rul_datasets.RulDataModule(fd2, 32) + >>> dm = rul_datasets.LatentAlignDataModule(src, trg, split_by_max_rul=125) + >>> dm.prepare_data() + >>> dm.setup() >>> train_1_2 = dm.train_dataloader() >>> val_1, val_2 = dm.val_dataloader() >>> test_1, test_2 = dm.test_dataloader() diff --git a/rul_datasets/baseline.py b/rul_datasets/baseline.py index 1f73956..e5f1411 100644 --- a/rul_datasets/baseline.py +++ b/rul_datasets/baseline.py @@ -25,6 +25,8 @@ class BaselineDataModule(pl.LightningDataModule): >>> cmapss = rul_datasets.reader.CmapssReader(fd=1) >>> dm = rul_datasets.RulDataModule(cmapss, batch_size=32) >>> baseline_dm = rul_datasets.BaselineDataModule(dm) + >>> baseline_dm.prepare_data() + >>> baseline_dm.setup() >>> train_fd1 = baseline_dm.train_dataloader() >>> val_fd1 = baseline_dm.val_dataloader() >>> test_fd1, test_fd2, test_fd3, test_fd4 = baseline_dm.test_dataloader() diff --git a/rul_datasets/reader/abstract.py b/rul_datasets/reader/abstract.py index 71ca0cc..313415e 100644 --- a/rul_datasets/reader/abstract.py +++ b/rul_datasets/reader/abstract.py @@ -23,6 +23,11 @@ class AbstractReader(metaclass=abc.ABCMeta): Examples: >>> import rul_datasets >>> class MyReader(rul_datasets.reader.AbstractReader): + ... @property + ... def dataset_name(self): + ... return "my_dataset" + ... + ... @property ... def fds(self): ... return [1] ... @@ -41,7 +46,7 @@ class AbstractReader(metaclass=abc.ABCMeta): >>> my_reader = MyReader(fd=1) >>> features, targets = my_reader.load_split("dev") >>> features[0].shape - torch.Size([100, 2, 30]) + (100, 2, 30) """ fd: int diff --git a/rul_datasets/ssl.py b/rul_datasets/ssl.py index 68192b5..4b8ad55 100644 --- a/rul_datasets/ssl.py +++ b/rul_datasets/ssl.py @@ -22,6 +22,8 @@ class SemiSupervisedDataModule(pl.LightningDataModule): >>> labeled = rul_datasets.RulDataModule(fd1, 32) >>> unlabeled = rul_datasets.RulDataModule(fd1_complement, 32) >>> dm = rul_datasets.SemiSupervisedDataModule(labeled, unlabeled) + >>> dm.prepare_data() + >>> dm.setup() >>> train_ssl = dm.train_dataloader() >>> val = dm.val_dataloader() >>> test = dm.test_dataloader()