Skip to content

Commit

Permalink
Add arguments to scPRINT and enable tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lazappi committed Jan 21, 2025
1 parent 642b31e commit f45bd88
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 10 deletions.
15 changes: 14 additions & 1 deletion src/methods/scprint/config.vsh.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__merge__: /src/api/base_method.yaml
__merge__: /src/api/comp_method.yaml

name: scprint
label: scPRINT
Expand Down Expand Up @@ -38,6 +38,11 @@ info:
model_name: "medium"
scprint_small:
model_name: "small"
test_setup:
run:
model_name: small
batch_size: 16
max_len: 100

arguments:
- name: "--model_name"
Expand All @@ -49,6 +54,14 @@ arguments:
type: file
description: Path to the scPRINT model.
required: false
- name: --batch_size
type: integer
description: The size of the batches to be used in the DataLoader.
default: 64
- name: --max_len
type: integer
description: The maximum length of the gene sequence.
default: 4000

resources:
- type: python_script
Expand Down
20 changes: 11 additions & 9 deletions src/methods/scprint/script.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import anndata as ad
from scdataloader import Preprocessor
import os
import sys
from huggingface_hub import hf_hub_download
from scprint.tasks import Embedder
from scprint import scPrint

import anndata as ad
import scprint
import torch
import os
from huggingface_hub import hf_hub_download
from scdataloader import Preprocessor
from scprint import scPrint
from scprint.tasks import Embedder

## VIASH START
par = {
Expand All @@ -19,8 +20,8 @@
## VIASH END

sys.path.append(meta["resources_dir"])
from read_anndata_partial import read_anndata
from exit_codes import exit_non_applicable
from read_anndata_partial import read_anndata

print(f"====== scPRINT version {scprint.__version__} ======", flush=True)

Expand All @@ -41,7 +42,7 @@

print("\n>>> Preprocessing data...", flush=True)
preprocessor = Preprocessor(
min_valid_genes_id=min(0.9 * adata.n_vars, 10000), # 90% of features up to 10,000
min_valid_genes_id=min(0.9 * adata.n_vars, 10000), # 90% of features up to 10,000
# Turn off cell filtering to return results for all cells
filter_cell_by_counts=False,
min_nnz_genes=False,
Expand Down Expand Up @@ -77,7 +78,8 @@
print(f"Using {n_cores_available} worker cores")
embedder = Embedder(
how="random expr",
max_len=4000,
batch_size=par["batch_size"],
max_len=par["max_len"],
add_zero_genes=0,
num_workers=n_cores_available,
doclass=False,
Expand Down

0 comments on commit f45bd88

Please sign in to comment.