Skip to content

Commit

Permalink
allowing it to work in more contexts
Browse files Browse the repository at this point in the history
  • Loading branch information
jkobject committed Jan 22, 2025
1 parent 252e758 commit 0dba249
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
8 changes: 4 additions & 4 deletions scdataloader/collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(
class_names: list[str] = [],
genelist: list[str] = [],
downsample: Optional[float] = None, # don't use it for training!
save_output: bool = False,
save_output: Optional[str] = None,
):
"""
This class is responsible for collating data for the scPRINT model. It handles the
Expand Down Expand Up @@ -59,7 +59,7 @@ def __init__(
If [] all genes will be considered
downsample (float, optional): Downsample the profile to a certain number of cells. Defaults to None.
This is usually done by the scPRINT model during training but this option allows you to do it directly from the collator
save_output (bool, optional): If True, saves the output to a file. Defaults to False.
save_output (str, optional): If not None, saves the output to a file. Defaults to None.
This is mainly for debugging purposes
"""
self.organisms = organisms
Expand Down Expand Up @@ -237,8 +237,8 @@ def __call__(self, batch) -> dict[str, Tensor]:
ret.update({"dataset": Tensor(dataset).to(long)})
if self.downsample is not None:
ret["x"] = downsample_profile(ret["x"], self.downsample)
if self.save_output:
with open("collator_output.txt", "a") as f:
if self.save_output is not None:
with open(self.save_output, "a") as f:
np.savetxt(f, ret["x"].numpy())
return ret

Expand Down
15 changes: 10 additions & 5 deletions scdataloader/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(
do_postp: bool = True,
organisms: list[str] = ["NCBITaxon:9606", "NCBITaxon:10090"],
use_raw: bool = True,
keepdata: bool = False,
) -> None:
"""
Initializes the preprocessor and configures the workflow steps.
Expand Down Expand Up @@ -99,6 +100,8 @@ def __init__(
This arg is used in the highly variable gene selection step.
skip_validate (bool, optional): Determines whether to skip the validation step.
Defaults to False.
keepdata (bool, optional): Determines whether to keep the data in the AnnData object.
Defaults to False.
"""
self.filter_gene_by_counts = filter_gene_by_counts
self.filter_cell_by_counts = filter_cell_by_counts
Expand All @@ -124,6 +127,7 @@ def __init__(
self.is_symbol = is_symbol
self.do_postp = do_postp
self.use_raw = use_raw
self.keepdata = keepdata

def __call__(self, adata, dataset_id=None) -> AnnData:
if adata[0].obs.organism_ontology_term_id.iloc[0] not in self.organisms:
Expand All @@ -144,12 +148,13 @@ def __call__(self, adata, dataset_id=None) -> AnnData:
print("X was not raw counts, using 'counts' layer")
adata.X = adata.layers["counts"].copy()
print("Dropping layers: ", adata.layers.keys())
del adata.layers
if len(adata.varm.keys()) > 0:
if not self.keepdata:
del adata.layers
if len(adata.varm.keys()) > 0 and not self.keepdata:
del adata.varm
if len(adata.obsm.keys()) > 0 and self.do_postp:
if len(adata.obsm.keys()) > 0 and self.do_postp and not self.keepdata:
del adata.obsm
if len(adata.obsp.keys()) > 0 and self.do_postp:
if len(adata.obsp.keys()) > 0 and self.do_postp and not self.keepdata:
del adata.obsp
# check that it is a count
print("checking raw counts")
Expand Down Expand Up @@ -478,7 +483,7 @@ def __call__(
try:
if file.size > MAXFILESIZE:
print(
f"dividing the dataset as it is too large: {file.size//1_000_000_000}Gb"
f"dividing the dataset as it is too large: {file.size // 1_000_000_000}Gb"
)
num_blocks = int(np.ceil(file.size / (MAXFILESIZE / 2)))
block_size = int(
Expand Down

0 comments on commit 0dba249

Please sign in to comment.