Skip to content

Commit

Permalink
Safeguard Chai MSA construction
Browse files: browse the repository at this point in the history
  • Loading branch information
amorehead committed Dec 17, 2024
1 parent 3acdd4f commit b07d072
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 25 deletions.
1 change: 1 addition & 0 deletions configs/data/components/prepare_chai_msas.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
dataset: posebusters_benchmark # the dataset to use - NOTE: must be one of (`posebusters_benchmark`, `astex_diverse`, `dockgen`, `casp15`)
input_msa_dir: ${oc.env:PROJECT_ROOT}/data/${dataset}_set/${dataset}_msas # where the original MSA files are placed
output_msa_dir: ${oc.env:PROJECT_ROOT}/data/${dataset}_set/${dataset}_chai_msas # where the processed MSA files should be stored
skip_existing: True # whether to skip converting a chain's MSA if its output file already exists
57 changes: 32 additions & 25 deletions posebench/data/components/prepare_chai_msas.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,33 +74,40 @@ def main(cfg: DictConfig):
if not msa_file.endswith(".npz"):
continue

input_msa_path = os.path.join(cfg.input_msa_dir, msa_file)
input_msa = dict(np.load(input_msa_path))

item = msa_file.split("_protein")[0].split("_lig")[0]
input_msa_path = os.path.join(cfg.input_msa_dir, msa_file)

for chain_index in range(input_msa["n"]):
output_msas = [
{
"sequence": "".join(ID_TO_HHBLITS_AA[c] for c in seq),
"source_database": "query" if seq_index == 0 else "uniref90",
"pairing_key": f"sequence:{seq_index}"
if input_msa[f"is_paired_{chain_index}"][seq_index].item() is True
else "",
"comment": "",
}
for seq_index, seq in enumerate(input_msa[f"msa_{chain_index}"])
]
output_msa_df = pd.DataFrame(output_msas)

output_msa_path = os.path.join(
cfg.output_msa_dir, item + f"_chain_{chain_index}.aligned.pqt"
)

logger.info(
f"Converting chain MSA to DataFrame: {input_msa_path} -> {output_msa_path}"
)
output_msa_df.to_parquet(output_msa_path)
try:
input_msa = dict(np.load(input_msa_path))

for chain_index in range(input_msa["n"]):
output_msa_path = os.path.join(
cfg.output_msa_dir, item + f"_chain_{chain_index}.aligned.pqt"
)
if os.path.exists(output_msa_path) and cfg.skip_existing:
logger.info(f"MSA already exists: {output_msa_path}. Skipping...")
continue

output_msas = [
{
"sequence": "".join(ID_TO_HHBLITS_AA[c] for c in seq),
"source_database": "query" if seq_index == 0 else "uniref90",
"pairing_key": f"sequence:{seq_index}"
if input_msa[f"is_paired_{chain_index}"][seq_index].item() is True
else "",
"comment": "",
}
for seq_index, seq in enumerate(input_msa[f"msa_{chain_index}"])
]
output_msa_df = pd.DataFrame(output_msas)

logger.info(
f"Converting chain MSA to DataFrame: {input_msa_path} -> {output_msa_path}"
)
output_msa_df.to_parquet(output_msa_path)

except Exception as e:
logger.error(f"Failed to process MSA {input_msa_path} due to: {e}. Skipping...")


if __name__ == "__main__":
Expand Down

0 comments on commit b07d072

Please sign in to comment.