+
+ +

Documentation for Collator

+ + +
+ + + + +

+ scdataloader.collator.Collator + +

+ + +
+ + +

This class is responsible for collating data for the scPRINT model. It handles the organization and preparation of gene expression data from different organisms, allowing for various configurations such as maximum gene list length, normalization, and selection method for gene expression.

+

This Collator should work with scVI's dataloader as well!

+ + + + + + + + + + + + + + +
Parameters: +
    +
  • + organisms + (list) + – +
    +

    List of organisms to be considered for gene expression data. Cells from any other organism are dropped (which might lead to batches of different sizes!).

    +
    +
  • +
  • + how + (flag, default: + 'all' +) + – +
    +

    Method for selecting gene expression. Defaults to "all". One of ["most expr", "random expr", "all", "some"]: "most expr" selects the max_len most expressed genes; if fewer genes are expressed, random unexpressed genes are sampled to fill up. "random expr" uses a random set of max_len expressed genes; if fewer genes are expressed, random unexpressed genes are sampled to fill up. "all" uses all genes. "some" uses only the genes provided through the genelist param.

    +
    +
  • +
  • + org_to_id + (dict, default: + None +) + – +
    +

    Dictionary mapping organisms to their respective IDs.

    +
    +
  • +
  • + valid_genes + (list, default: + [] +) + – +
    +

    List of genes from the datasets to be considered. Defaults to []. Any other genes are dropped from the input expression data (useful when your model only works on some genes).

    +
    +
  • +
  • + max_len + (int, default: + 2000 +) + – +
    +

    Maximum number of genes to use (for random expr and most expr). Defaults to 2000.

    +
    +
  • +
  • + n_bins + (int, default: + 0 +) + – +
    +

    Number of bins for binning the data. Defaults to 0, meaning no binning of expression.

    +
    +
  • +
  • + add_zero_genes + (int, default: + 0 +) + – +
    +

    Number of additional unexpressed genes to add to the input data. Defaults to 0.

    +
    +
  • +
  • + logp1 + (bool, default: + False +) + – +
    +

    If True, logp1 normalization is applied. Defaults to False.

    +
    +
  • +
  • + norm_to + (str, default: + None +) + – +
    +

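    Value each cell's expression is normalized to: the expression is multiplied by norm_to and divided by the cell's total count. Defaults to None (no normalization).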

    +
    +
  • +
+
+
+ Source code in scdataloader/collator.py +
 9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
+78
+79
+80
+81
+82
+83
+84
+85
+86
def __init__(
    self,
    organisms: list,
    how="all",
    org_to_id: dict = None,
    valid_genes: list = [],
    max_len=2000,
    add_zero_genes=0,
    logp1=False,
    norm_to=None,
    n_bins=0,
    tp_name=None,
    organism_name="organism_ontology_term_id",
    class_names=[],
    genelist=[],
):
    """
    This class is responsible for collating data for the scPRINT model. It handles the
    organization and preparation of gene expression data from different organisms,
    allowing for various configurations such as maximum gene list length, normalization,
    and selection method for gene expression.

    This Collator should work with scVI's dataloader as well!

    Args:
        organisms (list): List of organisms to be considered for gene expression data.
            Cells from any other organism are dropped at collation time
            (might lead to batches of different sizes!).
        how (flag, optional): Method for selecting gene expression. Defaults to "all".
            one of ["most expr", "random expr", "all", "some"]:
            "most expr": selects the max_len most expressed genes,
            if fewer genes are expressed, will sample random unexpressed genes,
            "random expr": uses a random set of max_len expressed genes.
            if fewer genes are expressed, will sample random unexpressed genes
            "all": uses all genes
            "some": uses only the genes provided through the genelist param
        org_to_id (dict, optional): Dictionary mapping organisms to their respective IDs.
            When None, the organism names themselves are used as IDs. Defaults to None.
        valid_genes (list, optional): List of genes from the datasets, to be considered. Defaults to [].
            Any other genes are dropped from the input expression data
            (useful when your model only works on some genes).
        max_len (int, optional): Maximum number of genes to use (for random expr and most expr). Defaults to 2000.
        add_zero_genes (int, optional): Number of additional unexpressed genes to add to the input data. Defaults to 0.
        logp1 (bool, optional): If True, log2(1 + x) is applied to the expression. Defaults to False.
        norm_to (number, optional): If set, the expression of each cell is multiplied by this
            value and divided by the cell's total count. Defaults to None.
        n_bins (int, optional): Number of bins for binning the data. Defaults to 0, meaning no binning
            of expression. NOTE(review): binning is currently a no-op in __call__.
        tp_name (str, optional): Key of the batch elements whose value is returned under "tp".
            When None, 0 is used for every cell. Defaults to None.
        organism_name (str, optional): Key of the batch elements holding the organism ID.
            Defaults to "organism_ontology_term_id".
        class_names (list, optional): Keys of the batch elements returned under "class". Defaults to [].
        genelist (list, optional): Genes to keep when how is "some". Defaults to [].

    Raises:
        AssertionError: if how is "some" and genelist is empty.
    """
    self.organisms = organisms
    self.max_len = max_len
    self.n_bins = n_bins
    self.add_zero_genes = add_zero_genes
    self.logp1 = logp1
    self.norm_to = norm_to
    self.org_to_id = org_to_id
    self.how = how
    self.organism_ids = (
        set([org_to_id[k] for k in organisms])
        if org_to_id is not None
        else set(organisms)
    )
    if self.how == "some":
        assert len(genelist) > 0, "if how is some, genelist must be provided"
    self.organism_name = organism_name
    self.tp_name = tp_name
    # Copy so the instance never aliases the shared default list of the signature.
    self.class_names = list(class_names)

    self.start_idx = {}
    self.accepted_genes = {}
    self.genedf = load_genes(organisms)
    self.to_subset = {}
    has_valid_genes = len(valid_genes) > 0
    has_genelist = len(genelist) > 0
    for organism in set(self.genedf.organism):
        in_organism = self.genedf.organism == organism
        ogenedf = self.genedf[in_organism]
        org = org_to_id[organism] if org_to_id is not None else organism
        # position of the organism's first gene in the concatenated gene table
        self.start_idx[org] = np.where(in_organism)[0][0]
        if has_valid_genes:
            accepted = ogenedf.index.isin(valid_genes)
            self.accepted_genes[org] = accepted
            kept = ogenedf[accepted]
        else:
            # No valid_genes filter: every gene is kept. Filtering with the empty
            # list would yield an empty frame and a zero-length to_subset mask.
            kept = ogenedf
        if has_genelist:
            self.to_subset[org] = kept.index.isin(genelist)
+
+
+ + + +
+ + + + + + + + + + +
+ + + + +

+ __call__ + +

+ + +
+ +

__call__ is a special method in Python that is called when an instance of the class is called.

+ + + + + + + + + + + + + + +
Parameters: +
    +
  • + batch + (list[dict[str, array]]) + – +
    +

    List of dicts of arrays containing gene expression data. The list holds the different samples; each dict holds the different elements of one sample, with elem["x"]: gene expression, elem["organism_name"]: organism ontology term id, elem["tp_name"]: heat diff, elem["class_names.."]: other classes.

    +
    +
  • +
+
+ + + + + + + + + + + + + +
Returns: +
    +
  • + – +
    +

    dict[str, Tensor]: Dictionary of tensors containing the collated data, with keys "x", "genes", "class", "tp" and "depth".

    +
    +
  • +
+
+
+ Source code in scdataloader/collator.py +
 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
def __call__(self, batch):
    """
    Collate a list of cells into a dictionary of batched tensors.

    Args:
        batch (list[dict[str: array]]): List of dicts of arrays containing gene expression data.
            the list is for the different samples, each dict holds the different elements with
            elem["x"]: gene expression
            elem[self.organism_name]: organism ontology term id
            elem[self.tp_name]: heat diff (only read when tp_name is not None)
            elem[name] for each name in self.class_names: other classes

    Returns:
        dict[str, Tensor]: the collated data with keys
            "x": selected (and optionally normalized / log-transformed) expression,
            "genes": selected gene positions, offset by the organism's start index,
            "class": values of the class_names entries,
            "tp": values of the tp_name entry (0 when tp_name is None),
            "depth": total count of each cell, computed before any gene filtering.

    Raises:
        ValueError: if self.how is not one of the supported selection modes.
    """
    exprs = []
    total_count = []
    other_classes = []
    gene_locs = []
    tp = []
    nnz_loc = []
    for elem in batch:
        organism_id = elem[self.organism_name]
        # cells from organisms that were not requested are dropped from the batch
        if organism_id not in self.organism_ids:
            continue
        expr = np.array(elem["x"])
        total_count.append(expr.sum())
        if len(self.accepted_genes) > 0:
            expr = expr[self.accepted_genes[organism_id]]
        if self.how == "most expr":
            nnz_loc = np.where(expr > 0)[0]
            ma = min(self.max_len, len(nnz_loc))
            # Reverse first, then truncate: with ma == 0 the slice [-0:] would
            # select every gene instead of none.
            loc = np.argsort(expr)[::-1][:ma]
        elif self.how == "random expr":
            nnz_loc = np.where(expr > 0)[0]
            loc = nnz_loc[
                np.random.choice(
                    len(nnz_loc),
                    min(self.max_len, len(nnz_loc)),
                    replace=False,
                )
            ]
        elif self.how in ["all", "some"]:
            loc = np.arange(len(expr))
        else:
            raise ValueError(
                "how must be one of 'most expr', 'random expr', 'all' or 'some'"
            )
        if (
            (self.add_zero_genes > 0) or (self.max_len > len(nnz_loc))
        ) and self.how not in ["all", "some"]:
            # pad with random unexpressed genes: the requested extra zeros plus
            # whatever is missing to reach max_len
            zero_loc = np.where(expr == 0)[0]
            n_zeros = self.add_zero_genes + max(self.max_len - len(nnz_loc), 0)
            zero_loc = zero_loc[
                np.random.choice(len(zero_loc), n_zeros, replace=False)
            ]
            loc = np.concatenate((loc, zero_loc), axis=None)
        expr = expr[loc]
        loc = loc + self.start_idx[organism_id]
        if self.how == "some":
            expr = expr[self.to_subset[organism_id]]
            loc = loc[self.to_subset[organism_id]]
        exprs.append(expr)
        gene_locs.append(loc)

        if self.tp_name is not None:
            tp.append(elem[self.tp_name])
        else:
            tp.append(0)

        other_classes.append([elem[i] for i in self.class_names])

    expr = np.array(exprs)
    tp = np.array(tp)
    gene_locs = np.array(gene_locs)
    total_count = np.array(total_count)
    other_classes = np.array(other_classes)

    # normalize counts
    if self.norm_to is not None:
        expr = (expr * self.norm_to) / total_count[:, None]
    if self.logp1:
        expr = np.log2(1 + expr)

    # NOTE: binning of counts (self.n_bins) is not implemented; the value is ignored.

    return {
        "x": Tensor(expr),
        "genes": Tensor(gene_locs).int(),
        "class": Tensor(other_classes).int(),
        "tp": Tensor(tp),
        "depth": Tensor(total_count),
    }
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + + +

+ scdataloader.collator.AnnDataCollator + +

+ + +
+

+ Bases: Collator

+ + +

AnnDataCollator Collator to use if working with AnnData's experimental dataloader (it is very slow!!!)

+ +
+ Source code in scdataloader/collator.py +
207
+208
+209
+210
+211
def __init__(self, *args, **kwargs):
    """
    Build a Collator meant for AnnData's experimental dataloader (it is very slow!!!).

    Every positional and keyword argument is handed to the parent
    constructor unchanged.
    """
    # no AnnData-specific state is set up here
    super().__init__(*args, **kwargs)
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+