Skip to content

Commit

Permalink
list the full dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
jkobject committed Jan 7, 2025
1 parent 7ebccde commit 6a3e468
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions scdataloader/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,8 @@ def __repr__(self):
f"\ttest datasets={str(self.test_datasets)},\n"
f"perc test: {str(len(self.test_idx) / self.n_samples)},\n"
f"\tclss_to_weight={self.clss_to_weight}\n"
+ (
"\twith train_dataset size of=("
+ str((self.train_weights != 0).sum())
+ ")\n)"
)
if self.train_weights is not None
+ ("\twith train_dataset size of=(" + str(len(self.idx_full)) + ")\n)")
if self.idx_full is not None
else ")"
)

Expand Down Expand Up @@ -401,12 +397,14 @@ def __init__(
num_samples: int,
replacement: bool = True,
element_weights: Sequence[float] = None,
restart_num = 0,
) -> None:
"""
:param label_weights: list(len=num_classes)[float], weights for each class.
:param labels: list(len=dataset_len)[int], labels of a dataset.
:param num_samples: number of samples.
:param restart_num: if we are continuing a previous run, we need to restart the sampler from the same point.
"""

super(LabelWeightedSampler, self).__init__(None)
Expand Down

0 comments on commit 6a3e468

Please sign in to comment.