Skip to content

Commit

Permalink
make it work on multiple restart of pl
Browse files Browse the repository at this point in the history
  • Loading branch information
jkobject committed Jan 8, 2025
1 parent 6a3e468 commit 02347ea
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions scdataloader/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
"EFO:0030007", # ATACseq
# "EFO:0030062", # slide-seq
],
restart_num: int = 0,
metacell_mode: float = 0.0,
**kwargs,
):
Expand Down Expand Up @@ -88,7 +89,7 @@ def __init__(
metacell_mode (float, optional): The probability of using metacell mode. Defaults to 0.0.
clss_to_predict (list, optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
**kwargs: Additional keyword arguments passed to the pytorch DataLoader.
restart_num (int, optional): Index of the current restart when continuing a previous run. WARNING: must be set when resuming; a non-zero value is used to seed the sampler's random generator. Defaults to 0.
see @file data.py and @file collator.py for more details about some of the parameters
"""
if collection_name is not None:
Expand Down Expand Up @@ -172,6 +173,7 @@ def __init__(
self.train_weights = None
self.train_labels = None
self.nnz = None
self.restart_num = restart_num
self.test_datasets = []
self.test_idx = []
super().__init__()
Expand Down Expand Up @@ -338,6 +340,7 @@ def train_dataloader(self, **kwargs):
num_samples=int(self.n_samples * self.train_oversampling_per_epoch),
element_weights=self.nnz,
replacement=self.replacement,
restart_num=self.restart_num,
)
except ValueError as e:
raise ValueError(e + "have you run `datamodule.setup()`?")
Expand Down Expand Up @@ -387,6 +390,7 @@ class LabelWeightedSampler(Sampler[int]):
num_samples: int
nnz: Optional[Sequence[int]]
replacement: bool
restart_num: int
# Usage: pass one weight per class (e.g. np.ones(num_classes)) together with the labels of the dataset.
# This results in class-balanced sampling, no matter how imbalanced the labels are.

Expand All @@ -397,7 +401,7 @@ def __init__(
num_samples: int,
replacement: bool = True,
element_weights: Sequence[float] = None,
restart_num = 0,
restart_num=0,
) -> None:
"""
Expand All @@ -420,6 +424,7 @@ def __init__(
)
self.replacement = replacement
self.num_samples = num_samples
self.restart_num = restart_num
# list of tensor.
self.klass_indices = [
(self.labels == i_klass).nonzero().squeeze(1)
Expand All @@ -432,6 +437,9 @@ def __iter__(self):
self.label_weights,
num_samples=self.num_samples,
replacement=True,
generator=None
if self.restart_num == 0
else torch.Generator().manual_seed(self.restart_num),
)
sample_indices = torch.empty_like(sample_labels)

Expand All @@ -448,10 +456,17 @@ def __iter__(self):
if not self.replacement and len(klass_index) < len(left_inds)
else len(left_inds),
replacement=self.replacement,
generator=None
if self.restart_num == 0
else torch.Generator().manual_seed(self.restart_num),
)
elif self.replacement:
right_inds = torch.randint(
len(klass_index), size=(len(left_inds),), generator=None
len(klass_index),
size=(len(left_inds),),
generator=None
if self.restart_num == 0
else torch.Generator().manual_seed(self.restart_num),
)
else:
maxelem = (
Expand All @@ -463,6 +478,8 @@ def __iter__(self):
sample_indices[left_inds] = klass_index[right_inds]
# torch shuffle
sample_indices = sample_indices[torch.randperm(len(sample_indices))]
print(sample_indices.tolist()[:10], sample_labels[:10])
# raise Exception("stop")
yield from iter(sample_indices.tolist())

def __len__(self):
Expand Down

0 comments on commit 02347ea

Please sign in to comment.