From f45bd880266ae98a40266918d79c26c8ce41e204 Mon Sep 17 00:00:00 2001 From: Luke Zappia Date: Tue, 21 Jan 2025 12:10:23 +0100 Subject: [PATCH] Add arguments to scPRINT and enable tests --- src/methods/scprint/config.vsh.yaml | 15 ++++++++++++++- src/methods/scprint/script.py | 20 +++++++++++--------- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml index 7cc88f51..1f2cad83 100644 --- a/src/methods/scprint/config.vsh.yaml +++ b/src/methods/scprint/config.vsh.yaml @@ -1,4 +1,4 @@ -__merge__: /src/api/base_method.yaml +__merge__: /src/api/comp_method.yaml name: scprint label: scPRINT @@ -38,6 +38,11 @@ info: model_name: "medium" scprint_small: model_name: "small" + test_setup: + run: + model_name: small + batch_size: 16 + max_len: 100 arguments: - name: "--model_name" @@ -49,6 +54,14 @@ arguments: type: file description: Path to the scPRINT model. required: false + - name: --batch_size + type: integer + description: The size of the batches to be used in the DataLoader. + default: 64 + - name: --max_len + type: integer + description: The maximum length of the gene sequence. + default: 4000 resources: - type: python_script diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index e76d6f39..5f0c95e8 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -1,12 +1,13 @@ -import anndata as ad -from scdataloader import Preprocessor +import os import sys -from huggingface_hub import hf_hub_download -from scprint.tasks import Embedder -from scprint import scPrint + +import anndata as ad import scprint import torch -import os +from huggingface_hub import hf_hub_download +from scdataloader import Preprocessor +from scprint import scPrint +from scprint.tasks import Embedder ## VIASH START par = { @@ -19,8 +20,8 @@ ## VIASH END sys.path.append(meta["resources_dir"]) -from read_anndata_partial import read_anndata from exit_codes import exit_non_applicable +from read_anndata_partial import read_anndata print(f"====== scPRINT version {scprint.__version__} ======", flush=True) @@ -41,7 +42,7 @@ print("\n>>> Preprocessing data...", flush=True) preprocessor = Preprocessor( - min_valid_genes_id=min(0.9 * adata.n_vars, 10000), # 90% of features up to 10,000 + min_valid_genes_id=min(0.9 * adata.n_vars, 10000), # 90% of features up to 10,000 # Turn off cell filtering to return results for all cells filter_cell_by_counts=False, min_nnz_genes=False, @@ -77,7 +78,8 @@ print(f"Using {n_cores_available} worker cores") embedder = Embedder( how="random expr", - max_len=4000, + batch_size=par["batch_size"], + max_len=par["max_len"], add_zero_genes=0, num_workers=n_cores_available, doclass=False,