diff --git a/configs/data/components/prepare_chai_msas.yaml b/configs/data/components/prepare_chai_msas.yaml
index 417db06..6ad9aae 100644
--- a/configs/data/components/prepare_chai_msas.yaml
+++ b/configs/data/components/prepare_chai_msas.yaml
@@ -1,3 +1,4 @@
 dataset: posebusters_benchmark # the dataset to use - NOTE: must be one of (`posebusters_benchmark`, `astex_diverse`, `dockgen`, `casp15`)
 input_msa_dir: ${oc.env:PROJECT_ROOT}/data/${dataset}_set/${dataset}_msas # where the original MSA files are placed
 output_msa_dir: ${oc.env:PROJECT_ROOT}/data/${dataset}_set/${dataset}_chai_msas # where the processed MSA files should be stored
+skip_existing: True # whether to skip processing if the output file already exists
diff --git a/posebench/data/components/prepare_chai_msas.py b/posebench/data/components/prepare_chai_msas.py
index 07c1059..07e3f48 100644
--- a/posebench/data/components/prepare_chai_msas.py
+++ b/posebench/data/components/prepare_chai_msas.py
@@ -74,33 +74,40 @@ def main(cfg: DictConfig):
         if not msa_file.endswith(".npz"):
             continue
 
-        input_msa_path = os.path.join(cfg.input_msa_dir, msa_file)
-        input_msa = dict(np.load(input_msa_path))
         item = msa_file.split("_protein")[0].split("_lig")[0]
+        input_msa_path = os.path.join(cfg.input_msa_dir, msa_file)
 
-        for chain_index in range(input_msa["n"]):
-            output_msas = [
-                {
-                    "sequence": "".join(ID_TO_HHBLITS_AA[c] for c in seq),
-                    "source_database": "query" if seq_index == 0 else "uniref90",
-                    "pairing_key": f"sequence:{seq_index}"
-                    if input_msa[f"is_paired_{chain_index}"][seq_index].item() is True
-                    else "",
-                    "comment": "",
-                }
-                for seq_index, seq in enumerate(input_msa[f"msa_{chain_index}"])
-            ]
-            output_msa_df = pd.DataFrame(output_msas)
-
-            output_msa_path = os.path.join(
-                cfg.output_msa_dir, item + f"_chain_{chain_index}.aligned.pqt"
-            )
-
-            logger.info(
-                f"Converting chain MSA to DataFrame: {input_msa_path} -> {output_msa_path}"
-            )
-            output_msa_df.to_parquet(output_msa_path)
+        try:
+            input_msa = dict(np.load(input_msa_path))
+
+            for chain_index in range(input_msa["n"]):
+                output_msa_path = os.path.join(
+                    cfg.output_msa_dir, item + f"_chain_{chain_index}.aligned.pqt"
+                )
+                if os.path.exists(output_msa_path) and cfg.skip_existing:
+                    logger.info(f"MSA already exists: {output_msa_path}. Skipping...")
+                    continue
+
+                output_msas = [
+                    {
+                        "sequence": "".join(ID_TO_HHBLITS_AA[c] for c in seq),
+                        "source_database": "query" if seq_index == 0 else "uniref90",
+                        "pairing_key": f"sequence:{seq_index}"
+                        if input_msa[f"is_paired_{chain_index}"][seq_index].item() is True
+                        else "",
+                        "comment": "",
+                    }
+                    for seq_index, seq in enumerate(input_msa[f"msa_{chain_index}"])
+                ]
+                output_msa_df = pd.DataFrame(output_msas)
+
+                logger.info(
+                    f"Converting chain MSA to DataFrame: {input_msa_path} -> {output_msa_path}"
+                )
+                output_msa_df.to_parquet(output_msa_path)
+
+        except Exception as e:
+            logger.error(f"Failed to process MSA {input_msa_path} due to: {e}. Skipping...")
 
 
 if __name__ == "__main__":