
Commit

small fixes
EMCarrami committed Nov 11, 2023
1 parent 9eab60c commit 71fb64a
Showing 3 changed files with 39 additions and 36 deletions.
7 changes: 4 additions & 3 deletions configs/train_config.json
@@ -13,16 +13,17 @@
"language_model": "gpt2-xl",
"protein_model": "esm2_t12_35M_UR50D",
"protein_layer_to_use": -1,
"perceiver_latent_size": 100,
"perceiver_latent_size": 20,
"num_perceiver_layers": 1,
"enable_gradient_checkpointing": false
},
"wandb": {
"project": "Cprt"
"project": "Cprt-Experiments"
},
"trainer": {
"precision": "16-mixed",
"val_check_interval": 500,
"max_epochs": 2,
"max_epochs": 1,
"devices": 1
}
}
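
Note: the config change shrinks the perceiver latent size from 100 to 20, halves training to one epoch, and renames the wandb project. A minimal sketch of how run.py consumes these sections (the json loading shown here is an assumption; only the keys and values come from the diff):

```python
import json

# Sketch only: run.py unpacks these sections as keyword arguments,
# e.g. Cprt(**config["model"]) and Trainer(**config["trainer"]).
with open("configs/train_config.json") as f:
    config = json.load(f)

model_kwargs = config["model"]      # perceiver_latent_size is now 20
trainer_kwargs = config["trainer"]  # max_epochs is now 1
wandb_kwargs = config["wandb"]      # project is now "Cprt-Experiments"
```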
54 changes: 28 additions & 26 deletions cprt/model/cprt_model.py
@@ -1,9 +1,10 @@
 import gc
-from typing import Dict, Optional, Tuple, cast
+from typing import Dict, List, Optional, Tuple, cast
 
 import torch
 import wandb
 from lightning import LightningModule
+from lightning.pytorch.utilities import rank_zero_only
 from torch import FloatTensor, LongTensor, Tensor, nn
 from torchmetrics import MeanMetric
 from torchmetrics.text import Perplexity, ROUGEScore
@@ -111,16 +112,9 @@ def forward(

     def training_step(self, batch: CprtData, batch_idx: int) -> Tensor:
         """Take a train step."""
-        self.log("monitor/max_protein_length", batch.protein.size(1), prog_bar=True)
-        self.log("monitor/max_info_length", batch.info.size(1), prog_bar=True)
-
         out = self(protein_ids=batch.protein, info_ids=batch.info, attention_mask=batch.info_mask, labels=batch.labels)
         loss: Tensor = out["loss"]
-
-        self.train_mean_loss.update(loss.item())
-        if batch_idx % (self.trainer.val_check_interval // 10) == 0:
-            self.log("loss/train_loss", self.train_mean_loss.compute(), prog_bar=True)
-            self.train_mean_loss.reset()
+        self.log("loss/train_loss", loss.item(), prog_bar=True)
 
         self.train_perplexity.update(out["logits"].detach()[:, :-1].float(), batch.labels[:, 1:])
         if batch_idx % self.trainer.val_check_interval == 0 and batch_idx != 0:
@@ -145,22 +139,14 @@ def validation_step(self, batch: CprtData, batch_idx: int) -> None:
         input_text = self.text_tokenizer.batch_decode(batch.info, skip_special_tokens=True)
         generated_text = self.text_tokenizer.batch_decode(torch.argmax(out["logits"], dim=-1), skip_special_tokens=True)
         self.val_rouge_scores.update(generated_text, input_text)
-        self.last_val_batch = batch
+        if batch_idx == 0:
+            self.log_example_outputs(input_text, batch.protein)
         torch.cuda.empty_cache()
         gc.collect()
 
-    def on_validation_epoch_end(self) -> None:
-        self.log("metrics/val_perplexity", self.val_perplexity.compute())
-        self.val_perplexity.reset()
-        rouge_scores: Dict[str, Tensor] = self.val_rouge_scores.compute()
-        self.log_dict({f"metrics/val_{k}": v.mean() for k, v in rouge_scores.items()})
-        self.val_rouge_scores.reset()
-        for idx, layer in enumerate(self.cprt_llm.transformer.h):
-            self.log(f"gates/layer_{idx}_attn_gate", layer.cross_attn.attn_gate.item())
-            self.log(f"gates/layer_{idx}_ff_gate", layer.cross_attn.ff_gate.item())
-        # log example outputs
-        input_text = self.text_tokenizer.batch_decode(self.last_val_batch.info, skip_special_tokens=True)
-        for in_txt, protein in zip(input_text, self.last_val_batch.protein):
+    def log_example_outputs(self, input_text: List[str], protein: Tensor) -> None:
+        """Log example generated responses."""
+        for in_txt, protein in zip(input_text, protein):
             if "?" in in_txt:
                 question = in_txt.split("?")[0]
                 preds = self.cprt_llm.generate(
@@ -172,11 +158,27 @@ def on_validation_epoch_end(self) -> None:
                 response = self.text_tokenizer.decode(preds[0].cpu())
                 self.text_table.add_data(self.trainer.global_step, in_txt, response)  # type: ignore[no-untyped-call]
 
-    def configure_optimizers(self) -> torch.optim.Optimizer:
-        """Configure optimizer."""
-        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4, weight_decay=1e-2)
-        return optimizer
+    def on_validation_epoch_end(self) -> None:
+        """Log validation metrics."""
+        self.log("metrics/val_perplexity", self.val_perplexity.compute())
+        self.val_perplexity.reset()
+        rouge_scores: Dict[str, Tensor] = self.val_rouge_scores.compute()
+        self.log_dict({f"metrics/val_{k}": v.mean() for k, v in rouge_scores.items()})
+        self.val_rouge_scores.reset()
+        for idx, layer in enumerate(self.cprt_llm.transformer.h):
+            self.log(f"gates/layer_{idx}_attn_gate", layer.cross_attn.attn_gate.item())
+            self.log(f"gates/layer_{idx}_ff_gate", layer.cross_attn.ff_gate.item())
+
+    def on_fit_end(self) -> None:
+        """Log generation examples table."""
+        self.log_wandb_table()
+
+    @rank_zero_only  # type: ignore[misc]
+    def log_wandb_table(self) -> None:
+        """Log wandb table of example outputs."""
+        wandb.log({"val_generation": self.text_table})
+
     def configure_optimizers(self) -> torch.optim.Optimizer:
         """Configure optimizer."""
         optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4, weight_decay=1e-2)
         return optimizer
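
Note: this refactor splits the old on_validation_epoch_end into three hooks: log_example_outputs runs on the first validation batch only, on_validation_epoch_end keeps the metric and gate logging, and the wandb table is flushed once when training finishes. A minimal sketch of the new table-logging pattern, assuming a text_table built as in the diff (the column names here are a guess):

```python
import wandb
from lightning import LightningModule
from lightning.pytorch.utilities import rank_zero_only


class TableLoggingSketch(LightningModule):
    """Toy module isolating the on_fit_end + rank_zero_only pattern."""

    def __init__(self) -> None:
        super().__init__()
        # Column names are assumed; the diff only shows the three add_data fields.
        self.text_table = wandb.Table(columns=["global_step", "input", "response"])

    def on_fit_end(self) -> None:
        # Called once when fit() completes, rather than every validation epoch.
        self.log_wandb_table()

    @rank_zero_only  # only the rank-0 process logs, avoiding duplicate tables in DDP
    def log_wandb_table(self) -> None:
        wandb.log({"val_generation": self.text_table})
```

Flushing once at the end also avoids re-logging a steadily growing Table after every validation epoch.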
14 changes: 7 additions & 7 deletions run.py
@@ -22,29 +22,29 @@ def train_cprt(config: Dict[str, Any], log_to_wandb: bool = False) -> None:

     datamodule = creat_datamodule(**config["data"], datamodule_config=config["datamodule"], only_keep_questions=True)
     model = Cprt(**config["model"])
 
+    group_name = f"{config['model']['protein_model']}_{config['model']['language_model']}"
     checkpoint_callback = ModelCheckpoint(
         monitor="loss/val_loss",
         mode="min",
         save_top_k=1,
-        dirpath=f"{ROOT}/model_checkpoints/{config['model']['protein_model']}_{config['model']['language_model']}",
+        dirpath=f"{ROOT}/model_checkpoints/{group_name}_{config.get('seed', 'random')}",
         verbose=True,
     )
 
     config["trainer"]["log_every_n_steps"] = 1
     if log_to_wandb:
-        wandb.init(**config["wandb"])
-        trainer = Trainer(logger=WandbLogger(), callbacks=[checkpoint_callback], **config["trainer"])
+        config["wandb"]["group"] = group_name
+        wandb_logger = WandbLogger(**config["wandb"])
+        trainer = Trainer(logger=wandb_logger, callbacks=checkpoint_callback, **config["trainer"])
         trainer.logger.log_hyperparams(config)
         try:
             trainer.fit(model, datamodule)
         except Exception as e:
             print(f"An error occurred: {e}")
-            model.log_wandb_table()
+        finally:
+            model.log_wandb_table()
             wandb.finish()
     else:
-        trainer = Trainer(callbacks=[checkpoint_callback], **config["trainer"])
+        trainer = Trainer(callbacks=checkpoint_callback, **config["trainer"])
         trainer.fit(model, datamodule)


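Note: the same group_name now drives both the checkpoint directory and wandb run grouping, and moving model.log_wandb_table() into a finally block means the example-generation table is saved whether or not trainer.fit raises. A sketch of the grouping idea with hypothetical values (the group name format mirrors the diff's f"{protein_model}_{language_model}" construction):

```python
from lightning.pytorch.loggers import WandbLogger

# Hypothetical values for illustration; WandbLogger forwards extra
# keyword arguments such as group to wandb.init.
group_name = "esm2_t12_35M_UR50D_gpt2-xl"
wandb_logger = WandbLogger(project="Cprt-Experiments", group=group_name)
```

Runs sharing a group are collapsed together in the wandb UI, so repeats of the same protein-model/language-model pair (differing only by seed) appear as one experiment.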
