From 0dba2495642d0b55df80e025bda039b7af6ca80a Mon Sep 17 00:00:00 2001 From: jkobject Date: Wed, 22 Jan 2025 15:28:42 +0100 Subject: [PATCH] allowing it to work in more contexts --- scdataloader/collator.py | 8 ++++---- scdataloader/preprocess.py | 15 ++++++++++----- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/scdataloader/collator.py b/scdataloader/collator.py index 51c9cec..e18a1ea 100644 --- a/scdataloader/collator.py +++ b/scdataloader/collator.py @@ -23,7 +23,7 @@ def __init__( class_names: list[str] = [], genelist: list[str] = [], downsample: Optional[float] = None, # don't use it for training! - save_output: bool = False, + save_output: Optional[str] = None, ): """ This class is responsible for collating data for the scPRINT model. It handles the @@ -59,7 +59,7 @@ def __init__( If [] all genes will be considered downsample (float, optional): Downsample the profile to a certain number of cells. Defaults to None. This is usually done by the scPRINT model during training but this option allows you to do it directly from the collator - save_output (bool, optional): If True, saves the output to a file. Defaults to False. + save_output (str, optional): Path of the file to which the output is appended. If None, nothing is saved. Defaults to None. 
This is mainly for debugging purposes """ self.organisms = organisms @@ -237,8 +237,8 @@ def __call__(self, batch) -> dict[str, Tensor]: ret.update({"dataset": Tensor(dataset).to(long)}) if self.downsample is not None: ret["x"] = downsample_profile(ret["x"], self.downsample) - if self.save_output: - with open("collator_output.txt", "a") as f: + if self.save_output is not None: + with open(self.save_output, "a") as f: np.savetxt(f, ret["x"].numpy()) return ret diff --git a/scdataloader/preprocess.py b/scdataloader/preprocess.py index a84ebf4..5bfa50b 100644 --- a/scdataloader/preprocess.py +++ b/scdataloader/preprocess.py @@ -59,6 +59,7 @@ def __init__( do_postp: bool = True, organisms: list[str] = ["NCBITaxon:9606", "NCBITaxon:10090"], use_raw: bool = True, + keepdata: bool = False, ) -> None: """ Initializes the preprocessor and configures the workflow steps. @@ -99,6 +100,8 @@ def __init__( This arg is used in the highly variable gene selection step. skip_validate (bool, optional): Determines whether to skip the validation step. Defaults to False. + keepdata (bool, optional): If True, keeps the layers, varm, obsm and obsp of the AnnData object + instead of deleting them during preprocessing. Defaults to False. 
""" self.filter_gene_by_counts = filter_gene_by_counts self.filter_cell_by_counts = filter_cell_by_counts @@ -124,6 +127,7 @@ def __init__( self.is_symbol = is_symbol self.do_postp = do_postp self.use_raw = use_raw + self.keepdata = keepdata def __call__(self, adata, dataset_id=None) -> AnnData: if adata[0].obs.organism_ontology_term_id.iloc[0] not in self.organisms: @@ -144,12 +148,13 @@ def __call__(self, adata, dataset_id=None) -> AnnData: print("X was not raw counts, using 'counts' layer") adata.X = adata.layers["counts"].copy() print("Dropping layers: ", adata.layers.keys()) - del adata.layers - if len(adata.varm.keys()) > 0: + if not self.keepdata: + del adata.layers + if len(adata.varm.keys()) > 0 and not self.keepdata: del adata.varm - if len(adata.obsm.keys()) > 0 and self.do_postp: + if len(adata.obsm.keys()) > 0 and self.do_postp and not self.keepdata: del adata.obsm - if len(adata.obsp.keys()) > 0 and self.do_postp: + if len(adata.obsp.keys()) > 0 and self.do_postp and not self.keepdata: del adata.obsp # check that it is a count print("checking raw counts") @@ -478,7 +483,7 @@ def __call__( try: if file.size > MAXFILESIZE: print( - f"dividing the dataset as it is too large: {file.size//1_000_000_000}Gb" + f"dividing the dataset as it is too large: {file.size // 1_000_000_000}Gb" ) num_blocks = int(np.ceil(file.size / (MAXFILESIZE / 2))) block_size = int(