From 5a7471afb4bb47c1fbf7a138ae18fd99f8e6d884 Mon Sep 17 00:00:00 2001
From: Luke Zappia
Date: Tue, 21 Jan 2025 13:19:28 +0100
Subject: [PATCH] Retry with single core in geneformer

---
 src/methods/geneformer/script.py | 64 ++++++++++++++++++++++++++------
 1 file changed, 52 insertions(+), 12 deletions(-)

diff --git a/src/methods/geneformer/script.py b/src/methods/geneformer/script.py
index 521b8f5c..902a7735 100644
--- a/src/methods/geneformer/script.py
+++ b/src/methods/geneformer/script.py
@@ -23,22 +23,24 @@
 
 print(">>> Reading input...", flush=True)
 sys.path.append(meta["resources_dir"])
-from read_anndata_partial import read_anndata
 from exit_codes import exit_non_applicable
+from read_anndata_partial import read_anndata
 
 adata = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns")
 
 if adata.uns["dataset_organism"] != "homo_sapiens":
     exit_non_applicable(
         f"Geneformer can only be used with human data "
-        f"(dataset_organism == \"{adata.uns['dataset_organism']}\")"
+        f'(dataset_organism == "{adata.uns["dataset_organism"]}")'
     )
 
 # Set adata.var_names to gene IDs
 adata.var_names = adata.var["feature_id"]
 is_ensembl = all(var_name.startswith("ENSG") for var_name in adata.var_names)
 if not is_ensembl:
-    raise ValueError(f"Geneformer requires adata.var_names to contain ENSEMBL gene ids")
+    exit_non_applicable(
+        "Geneformer requires adata.var_names to contain ENSEMBL gene ids"
+    )
 
 print(f">>> Getting settings for model '{par['model']}'...", flush=True)
 model_split = par["model"].split("-")
@@ -97,18 +99,56 @@
 adata.write_h5ad(os.path.join(input_dir, "input.h5ad"))
 print(adata)
 
+
+def tryParallelFunction(fun, label):
+    """Call `fun` in parallel, falling back to a single processor on failure.
+
+    `fun` is first called as `fun(nproc=n_processors)`. If that raises a
+    RuntimeError whose message contains both "subprocess" and "died", the
+    failure is printed and `fun(nproc=1)` is called instead. Any other
+    RuntimeError is re-raised.
+
+    Args:
+        fun: Callable that accepts an `nproc` keyword argument.
+        label: Name of the step, used in the printed failure message.
+
+    Returns:
+        The value returned by `fun`, so callers can keep its result.
+    """
+    try:
+        return fun(nproc=n_processors)
+    except RuntimeError as e:
+        message = str(e)
+        # Only retry for errors like "One of the subprocesses has abruptly died"
+        if "subprocess" not in message or "died" not in message:
+            raise
+        print(f"{label} failed. Error message: {e}", flush=True)
+        print("Retrying with nproc=1", flush=True)
+        return fun(nproc=1)
+
+
 print(">>> Tokenizing data...", flush=True)
 special_token = model_details["dataset"] == "95M"
 print(f"Input size: {model_details['input_size']}, Special token: {special_token}")
-tokenizer = TranscriptomeTokenizer(
-    nproc=n_processors,
-    model_input_size=model_details["input_size"],
-    special_token=special_token,
-    gene_median_file=dictionary_files["gene_median"],
-    token_dictionary_file=dictionary_files["token"],
-    gene_mapping_file=dictionary_files["ensembl_mapping"],
-)
-tokenizer.tokenize_data(input_dir, tokenized_dir, "tokenized", file_format="h5ad")
+
+
+def tokenize_data(nproc):
+    """Tokenize the input with `nproc` processes and return the tokenizer."""
+    tokenizer = TranscriptomeTokenizer(
+        nproc=nproc,
+        model_input_size=model_details["input_size"],
+        special_token=special_token,
+        gene_median_file=dictionary_files["gene_median"],
+        token_dictionary_file=dictionary_files["token"],
+        gene_mapping_file=dictionary_files["ensembl_mapping"],
+    )
+
+    tokenizer.tokenize_data(input_dir, tokenized_dir, "tokenized", file_format="h5ad")
+
+    return tokenizer
+
+
+tokenizer = tryParallelFunction(tokenize_data, "Tokenizing data")
 
 print(f">>> Getting model files for model '{par['model']}'...", flush=True)
 model_files = {