Commit 4063dca: forgot the dataloader
jkobject committed Jan 5, 2024
1 parent 67dbd21
Showing 7 changed files with 224 additions and 5 deletions.
3 changes: 3 additions & 0 deletions Makefile
@@ -85,7 +85,10 @@ release: ## Create a new tag for release.
@echo "creating git tag : $${TAG}"
@git tag $${TAG}
@git push -u origin HEAD --tags
@
@echo "Github Actions will detect the new tag and release the new version."
+ @mkdocs gh-deploy
+ @echo "docs published too"

.PHONY: docs
docs: ## Build the documentation.
4 changes: 4 additions & 0 deletions docs/dataloader.md
@@ -0,0 +1,4 @@
# Documentation for `DataLoader`

::: scdataloader.dataloader.DataLoader
    handler: python
1 change: 0 additions & 1 deletion docs/notebooks/1_download_and_preprocess.ipynb
@@ -63,7 +63,6 @@
"\n",
"import lamindb as ln\n",
"import lnschema_bionty as lb\n",
"import pandas as pd\n",
"\n",
"lb.settings.organism = \"human\"\n",
"\n",
13 changes: 9 additions & 4 deletions docs/notebooks/2_create_dataloader.ipynb
@@ -14,7 +14,11 @@
"outputs": [],
"source": [
"from scdataloader import Dataset\n",
"from scdataloader import dataLoader\n",
"from scdataloader import DataLoader\n",
"import pandas as pd\n",
"import lamindb as ln\n",
"import lnschema_bionty as lb\n",
"\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2"
@@ -128,6 +132,7 @@
"metadata": {},
"outputs": [],
"source": [
"# see scprint for this or contact me (@jkobject)\n",
"embeddings = embed(genedf=genedf,\n",
" organism=\"homo_sapiens\",\n",
" cache=True,\n",
@@ -189,7 +194,7 @@
"source": [
"# the dataloader can weight some rare samples more: \n",
"# one need to provide the labels on which to weight the samples:\n",
"labels_weighted_sampling = hierarchical_labels+[\n",
"labels_weighted_sampling = [\n",
" 'sex_ontology_term_id',\n",
" \"cell_type_ontology_term_id\",\n",
" #\"tissue_ontology_term_id\",\n",
@@ -246,7 +251,7 @@
"source": [
"#we then create a mapped dataset. This transforms a bunch of anndata from possibly various species, into a combined object that acts roughly as a single anndata dataset \n",
"# (WIP to get all the features of an anndata object) \n",
"mdataset = Dataset(dataset, genedf, gene_embedding=embeddings, organisms=['\"NCBITaxon:9606\"'], obs=all_labels, encode_obs=labels_weighted_sampling, map_hierarchy=hierarchical_labels, )\n",
"mdataset = Dataset(dataset, genedf, gene_embedding=embeddings, organisms=[\"NCBITaxon:9606\"], obs=all_labels, encode_obs=labels_weighted_sampling)\n",
"mdataset"
]
},
@@ -257,7 +262,7 @@
"outputs": [],
"source": [
"# now we make the dataloader\n",
"dataloader = BaseDataLoader(mdataset, label_to_weight=labels_weighted_sampling, batch_size=4, num_workers=1)\n",
"dataloader = DataLoader(mdataset, label_to_weight=labels_weighted_sampling, batch_size=4, num_workers=1)\n",
"len(dataloader)"
]
},
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -13,6 +13,7 @@ nav:
- dataset: dataset.md
- preprocess: preprocess.md
- utils: utils.md
+ - dataloader: dataloader.md
plugins:
- search
- mkdocstrings:
3 changes: 3 additions & 0 deletions scdataloader/__init__.py
@@ -0,0 +1,3 @@
from .data import Dataset
from .dataloader import DataLoader
from .preprocess import Preprocessor
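
With this, the package's main entry points import directly from the top level, e.g. (a trivial usage sketch):

from scdataloader import Dataset, DataLoader, Preprocessor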
204 changes: 204 additions & 0 deletions scdataloader/dataloader.py
@@ -0,0 +1,204 @@
import numpy as np
from torch.utils.data import DataLoader as TorchLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import WeightedRandomSampler
from scdataloader.mapped import MappedDataset
from typing import Union
import torch

# TODO: put in config
COARSE_TISSUE = {
"adipose tissue": "",
"bladder organ": "",
"blood": "",
"bone marrow": "",
"brain": "",
"breast": "",
"esophagus": "",
"eye": "",
"embryo": "",
"fallopian tube": "",
"gall bladder": "",
"heart": "",
"intestine": "",
"kidney": "",
"liver": "",
"lung": "",
"lymph node": "",
"musculature of body": "",
"nose": "",
"ovary": "",
"pancreas": "",
"placenta": "",
"skin of body": "",
"spinal cord": "",
"spleen": "",
"stomach": "",
"thymus": "",
"thyroid gland": "",
"tongue": "",
"uterus": "",
}

COARSE_ANCESTRY = {
"African": "",
"Chinese": "",
"East Asian": "",
"Eskimo": "",
"European": "",
"Greater Middle Eastern (Middle Eastern, North African or Persian)": "",
"Hispanic or Latin American": "",
"Native American": "",
"Oceanian": "",
"South Asian": "",
}

COARSE_DEVELOPMENT_STAGE = {
"Embryonic human": "",
"Fetal": "",
"Immature": "",
"Mature": "",
}

COARSE_ASSAY = {
"10x 3'": "",
"10x 5'": "",
"10x multiome": "",
"CEL-seq2": "",
"Drop-seq": "",
"GEXSCOPE technology": "",
"inDrop": "",
"microwell-seq": "",
"sci-Plex": "",
"sci-RNA-seq": "",
"Seq-Well": "",
"Slide-seq": "",
"Smart-seq": "",
"SPLiT-seq": "",
"TruDrop": "",
"Visium Spatial Gene Expression": "",
}
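
# note: the values above are intentionally left empty pending the TODO;
# presumably each coarse category will be mapped to its ontology term id or
# to the fine-grained labels it covers (an assumption, not stated in the code)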


class DataLoader(TorchLoader):
    """
    DataLoader for MappedDataset objects, with weighted random sampling
    and an optional train/validation split.
    """

    def __init__(
        self,
        mapped_dataset: MappedDataset,
        batch_size: int = 32,
        weight_scaler: int = 30,
        label_to_weight: list = [],
        validation_split: float = 0.2,
        num_workers: int = 4,
        collate_fn=default_collate,
        sampler=None,
        **kwargs,
    ):
        self.validation_split = validation_split
        self.dataset = mapped_dataset

        self.batch_idx = 0
        self.batch_size = batch_size
        self.n_samples = len(self.dataset)
        if sampler is None:
            self.sampler, self.valid_sampler = self._split_sampler(
                self.validation_split,
                weight_scaler=weight_scaler,
                label_to_weight=label_to_weight,
            )
        else:
            self.sampler = sampler
            self.valid_sampler = None

        self.init_kwargs = {
            "dataset": self.dataset,
            "batch_size": batch_size,
            "collate_fn": collate_fn,
            "num_workers": num_workers,
        }
        super().__init__(sampler=self.sampler, **self.init_kwargs, **kwargs)

    def _split_sampler(self, split, label_to_weight=[], weight_scaler: int = 30):
        idx_full = np.arange(self.n_samples)
        np.random.shuffle(idx_full)
        if len(label_to_weight) > 0:
            weights = self.dataset.get_label_weights(
                label_to_weight, scaler=weight_scaler
            )
        else:
            weights = np.ones(self.n_samples)
        if isinstance(split, int):
            assert (
                split < self.n_samples
            ), "validation set size is configured to be larger than entire dataset."
            len_valid = split
        else:
            len_valid = int(self.n_samples * split)
        if len_valid == 0:
            # no validation split: sample from the full dataset
            self.train_idx = idx_full
            train_sampler = WeightedRandomSampler(
                weights, len(self.train_idx), replacement=True
            )
            return train_sampler, None
        self.valid_idx = idx_full[0:len_valid]
        self.train_idx = np.delete(idx_full, np.arange(0, len_valid))
        # zero out each split's weights in the other split so the samplers stay disjoint
        valid_weights = weights.copy()
        valid_weights[self.train_idx] = 0
        # TODO: should we do weighted random sampling for the validation set?
        valid_sampler = WeightedRandomSampler(
            valid_weights, len_valid, replacement=True
        )
        train_weights = weights.copy()
        train_weights[self.valid_idx] = 0
        train_sampler = WeightedRandomSampler(
            train_weights, len(self.train_idx), replacement=True
        )
        # note: shuffle is left unset since it is mutually exclusive with a sampler

        return train_sampler, valid_sampler

    def get_valid_dataloader(self):
        if self.valid_sampler is None:
            raise ValueError("No validation set is configured.")
        return DataLoader(
            self.dataset, batch_size=self.batch_size, sampler=self.valid_sampler
        )


def weighted_random_mask_value(
    values: Union[torch.Tensor, np.ndarray],
    mask_ratio: float = 0.15,
    mask_value: int = -1,
    important_elements: Union[torch.Tensor, np.ndarray] = np.array([]),
    important_weight: int = 0,
    pad_value: int = 0,
) -> torch.Tensor:
    """
    Randomly mask a batch of data.

    Args:
        values (array-like):
            A batch of tokenized data, with shape (batch_size, n_features).
        mask_ratio (float): The ratio of genes to mask, defaults to 0.15.
        mask_value (int): The value to mask with, defaults to -1.
        important_elements (array-like): Column indices excluded from masking.
        important_weight (int): Placeholder for weighting important elements
            during mask selection; currently only the weight-0 (never mask)
            behavior is implemented.
        pad_value (int): The value of padding in the values, will be kept unchanged.

    Returns:
        torch.Tensor: A tensor of masked data.
    """
    if isinstance(values, torch.Tensor):
        # it is crucial to clone the tensor, otherwise we would mutate the original
        values = values.clone().detach().numpy()
    else:
        values = values.copy()

    # indices that must never be masked
    do_not_mask_idx = np.asarray(important_elements, dtype=int)
    for i in range(len(values)):
        row = values[i]
        non_padding_idx = np.nonzero(row - pad_value)[0]
        non_padding_idx = np.setdiff1d(non_padding_idx, do_not_mask_idx)
        n_mask = int(len(non_padding_idx) * mask_ratio)
        mask_idx = np.random.choice(non_padding_idx, n_mask, replace=False)
        row[mask_idx] = mask_value
    return torch.from_numpy(values).float()
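
For reference, a minimal usage sketch of the new module (not part of the commit; `mdataset` stands for a mapped Dataset built as in notebook 2 above):

from scdataloader import DataLoader
from scdataloader.dataloader import weighted_random_mask_value
import torch

# weighted loader with a 20% validation split, up-weighting rare cell types
loader = DataLoader(
    mdataset,
    label_to_weight=["cell_type_ontology_term_id"],
    batch_size=32,
    validation_split=0.2,
)
valid_loader = loader.get_valid_dataloader()

# mask ~15% of the non-padding entries of a toy batch (values 1-9, so no padding)
batch = torch.randint(1, 10, (4, 128))
masked = weighted_random_mask_value(batch, mask_ratio=0.15, mask_value=-1)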
