fix loss and add tables
EMCarrami committed Oct 30, 2023
1 parent 93ccde6 commit c2ffc14
Showing 3 changed files with 15 additions and 9 deletions.
6 changes: 3 additions & 3 deletions cprt/data/cprt_datamodule.py
@@ -132,14 +132,14 @@ def collate_fn(self, batch: List[Tuple[str, str]]) -> CprtData:
             truncation=True,
             max_length=1024,
         )
-        labels = tokenized_info["input_ids"][:, 1:].contiguous()
+        labels = tokenized_info["input_ids"].clone()
         labels[:, : self.placeholder_length] = -100
         for i, pad_idx in enumerate((1 - tokenized_info["attention_mask"]).sum(1)):
             if pad_idx > 0:
                 labels[i, -pad_idx:] = -100
         return CprtData(
-            info=tokenized_info["input_ids"][:, :-1].contiguous(),
-            info_mask=tokenized_info["attention_mask"][:, :-1].contiguous(),
+            info=tokenized_info["input_ids"],
+            info_mask=tokenized_info["attention_mask"],
             protein=self.protein_tokenizer(protein_sequences, padding=True, return_tensors="pt")["input_ids"],
             labels=labels,
         )
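
The collate change keeps labels aligned with the unshifted input_ids and relies on -100 to mask positions that should not contribute to the loss; Hugging Face causal LM heads such as GPT2LMHeadModel shift logits and labels internally when labels are supplied, so no manual [:, 1:] / [:, :-1] shift is needed in the datamodule. A minimal, self-contained sketch of the masking idea (not the repository's code; build_labels and the toy tensors are illustrative):

import torch

def build_labels(input_ids: torch.Tensor, attention_mask: torch.Tensor, placeholder_length: int) -> torch.Tensor:
    """Return labels aligned with input_ids, with placeholder and padding masked to -100."""
    labels = input_ids.clone()
    labels[:, :placeholder_length] = -100   # never score the placeholder prefix
    labels[attention_mask == 0] = -100      # never score padding tokens
    return labels

# Example: batch of 2 sequences, placeholder of length 2, right padding.
ids = torch.tensor([[11, 12, 13, 14, 15], [21, 22, 23, 0, 0]])
mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])
print(build_labels(ids, mask, placeholder_length=2))
# tensor([[-100, -100,   13,   14,   15],
#         [-100, -100,   23, -100, -100]])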
17 changes: 11 additions & 6 deletions cprt/model/cprt_model.py
@@ -7,6 +7,7 @@
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

+import wandb
 from cprt.data.cprt_datamodule import CprtData
 from cprt.model.helper_modules import CrossAttentionDecoderLayer, TruncatedESM2

@@ -35,12 +36,14 @@ def __init__(
         self._add_cross_attention_to_llm()
         self._modify_generation_input_to_llm()

-        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=-100)

         self.train_perplexity = Perplexity(ignore_index=-100)
         self.val_perplexity = Perplexity(ignore_index=-100)
         self.val_rouge_scores = ROUGEScore()

+        self.text_table = wandb.Table(  # type: ignore[no-untyped-call]
+            columns=["epoch", "input_text", "generated_text"]
+        )

     def _add_cross_attention_to_llm(self) -> None:
         """Add Cross-Attention layers to all decoder blocks."""
         protein_emb_size = cast(int, self.esm.embed_tokens.embedding_dim)
@@ -106,8 +109,8 @@ def forward(

     def training_step(self, batch: CprtData, batch_idx: int) -> Tensor:
         """Take a train step."""
-        out = self(protein_ids=batch.protein, info_ids=batch.info, attention_mask=batch.info_mask)
-        loss: Tensor = self.criterion(out["logits"].transpose(1, 2), batch.labels)
+        out = self(protein_ids=batch.protein, info_ids=batch.info, attention_mask=batch.info_mask, labels=batch.labels)
+        loss: Tensor = out["loss"]
         self.log("loss/train_loss", loss.item(), prog_bar=True)
         self.train_perplexity.update(out["logits"].detach(), batch.labels)
         if batch_idx % self.trainer.val_check_interval == 0 and batch_idx != 0:
@@ -124,14 +127,16 @@ def validation_step(self, batch: CprtData, batch_idx: int) -> None:
             attention_mask=batch.info_mask,
             labels=batch.labels,
         )
-        loss: Tensor = self.criterion(out["logits"].transpose(1, 2), batch.labels)
+        loss: Tensor = out["loss"]
         self.log("loss/val_loss", loss.item(), prog_bar=True)
-        self.log("loss/val_gpt2_loss", out["loss"].item(), prog_bar=True)
         self.val_perplexity.update(out["logits"], batch.labels)
         input_text = self.text_tokenizer.batch_decode(batch.info, skip_special_tokens=True)
         generated_text = self.text_tokenizer.batch_decode(torch.argmax(out["logits"], dim=-1), skip_special_tokens=True)
         self.val_rouge_scores.update(generated_text, input_text)
         if batch_idx == 0:
+            for i, g in zip(input_text, generated_text):
+                self.text_table.add_data(self.current_epoch, i, g)  # type: ignore[no-untyped-call]
+            wandb.log({"val_generation": self.text_table})
             print(generated_text)
         torch.cuda.empty_cache()

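The training and validation steps now pass labels straight into the forward call and read the loss from the model output instead of applying a separate CrossEntropyLoss. A minimal sketch of that behaviour, assuming the underlying language model is the GPT2LMHeadModel imported above (the checkpoint name and example text are illustrative, not taken from this repo):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

enc = tokenizer("a short test sentence", return_tensors="pt")
labels = enc["input_ids"].clone()   # unshifted; the LM head shifts logits/labels internally
out = model(**enc, labels=labels)   # loss uses CrossEntropyLoss with ignore_index=-100
print(out.loss, out.logits.shape)   # scalar loss plus [batch, seq_len, vocab] logits

Because the head applies its own shift and ignores -100 positions, this matches the -100 masking now produced by collate_fn.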
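The wandb.Table created in __init__ accumulates one row per generated sample and is re-logged during validation, so all generations across epochs end up under a single key. A hedged sketch of the same pattern outside Lightning (the project name and offline mode are assumptions for the example; the column names match the commit):

import wandb

run = wandb.init(project="demo", mode="offline")  # offline so the sketch runs anywhere
table = wandb.Table(columns=["epoch", "input_text", "generated_text"])
table.add_data(0, "input example", "generated example")  # one row per sample
wandb.log({"val_generation": table})
run.finish()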
1 change: 1 addition & 0 deletions requirements/requirements.txt
@@ -1,3 +1,4 @@
gdown
nltk
pandas
lightning~=2.0
