Skip to content

Commit

Permalink
Wire in custom MSA support for Chai-1
Browse files Browse the repository at this point in the history
  • Loading branch information
amorehead committed Dec 17, 2024
1 parent 1bd4732 commit 124227b
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
1 change: 1 addition & 0 deletions configs/model/chai_inference.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
dataset: posebusters_benchmark # the dataset to use - NOTE: must be one of (`posebusters_benchmark`, `astex_diverse`, `dockgen`, `casp15`)
input_dir: ${oc.env:PROJECT_ROOT}/forks/chai-lab/prediction_inputs/${dataset} # the input directory with which to run inference
output_dir: ${oc.env:PROJECT_ROOT}/forks/chai-lab/prediction_outputs/${dataset}_${repeat_index} # the output directory to which to save the inference results
msa_dir: ${oc.env:PROJECT_ROOT}/data/${dataset}_set/${dataset}_chai_msas # the directory containing the `.aligned.pqt` MSA files prepared for Chai-1 via `posebench/data/components/prepare_chai_msas.py`; if not provided (e.g., set to `null`), Chai-1 will be run in single-sequence mode
cuda_device_index: 0 # the CUDA device to use for inference, or `null` to use CPU
repeat_index: 1 # the repeat index to use for inference
skip_existing: true # whether to skip running inference if the prediction for a target already exists
Expand Down
11 changes: 8 additions & 3 deletions forks/chai-lab/chai_lab/data/dataset/msas/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,16 @@ def get_msa_contexts(
pdb_ids = set(chain.entity_data.pdb_id for chain in chains)
assert len(pdb_ids) == 1, f"Found >1 pdb ids in chains: {pdb_ids=}"

pdb_id = pdb_ids.pop()

# MSAs are constructed based on sequence, so use the unique sequences present
# in input chains to determine the MSAs that need to be loaded

def get_msa_contexts_for_seq(seq) -> MSAContext:
def get_msa_contexts_for_seq(seq, chain_index) -> MSAContext:
path = msa_directory / expected_basename(seq)
if not path.is_file():
# Fall back to a custom per-chain MSA file named by PDB ID and chain index
path = msa_directory / f"{pdb_id}_chain_{chain_index}.aligned.pqt"
if not path.is_file():
logger.warning(f"No MSA found for sequence: {seq}")
[tokenized_seq] = tokenize_sequences_to_arrays([seq])[0]
Expand All @@ -59,10 +64,10 @@ def get_msa_contexts_for_seq(seq) -> MSAContext:
# For each chain, either fetch the corresponding MSA or create an empty MSA if it is missing
# + reindex to handle residues that are tokenized per-atom (this also crops if necessary)
msa_contexts = [
get_msa_contexts_for_seq(chain.entity_data.sequence)[
get_msa_contexts_for_seq(chain.entity_data.sequence, chain_index)[
:, chain.structure_context.token_residue_index
]
for chain in chains
for chain_index, chain in enumerate(chains)
]

# used later only for profile statistics
Expand Down
1 change: 1 addition & 0 deletions posebench/models/chai_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def run_chai_inference(fasta_file: str, cfg: DictConfig):
try:
run_inference(
fasta_file=Path(fasta_file),
msa_dir=Path(cfg.msa_dir) if cfg.msa_dir else None,
output_dir=Path(output_dir),
# 'default' setup
num_trunk_recycles=3,
Expand Down

0 comments on commit 124227b

Please sign in to comment.