
Commit

small change
EMCarrami committed Oct 30, 2023
1 parent 1f96901 commit b7ea086
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions cprt/model/cprt_model.py
@@ -112,7 +112,7 @@ def training_step(self, batch: CprtData, batch_idx: int) -> Tensor:
         out = self(protein_ids=batch.protein, info_ids=batch.info, attention_mask=batch.info_mask, labels=batch.labels)
         loss: Tensor = out["loss"]
         self.log("loss/train_loss", loss.item(), prog_bar=True)
-        self.train_perplexity.update(out["logits"].detach(), batch.labels)
+        self.train_perplexity.update(out["logits"].detach()[:, :-1], batch.labels[:, 1:])
         if batch_idx % self.trainer.val_check_interval == 0 and batch_idx != 0:
             self.log("metrics/train_perplexity", self.train_perplexity.compute(), on_step=True)
             self.train_perplexity.reset()
@@ -129,7 +129,7 @@ def validation_step(self, batch: CprtData, batch_idx: int) -> None:
         )
         loss: Tensor = out["loss"]
         self.log("loss/val_loss", loss.item(), prog_bar=True)
-        self.val_perplexity.update(out["logits"], batch.labels)
+        self.val_perplexity.update(out["logits"][:, :-1], batch.labels[:, 1:])
         input_text = self.text_tokenizer.batch_decode(batch.info, skip_special_tokens=True)
         generated_text = self.text_tokenizer.batch_decode(torch.argmax(out["logits"], dim=-1), skip_special_tokens=True)
         self.val_rouge_scores.update(generated_text, input_text)
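
Both hunks make the same fix: the perplexity update is shifted so that each logit is scored against the next token rather than the token at its own position. Below is a minimal sketch of that alignment, assuming the train/val perplexity metrics behave like torchmetrics.text.Perplexity; the concrete metric class is not visible in this diff, so treat that as an assumption.

```python
import torch
from torchmetrics.text import Perplexity

batch_size, seq_len, vocab_size = 2, 8, 50
logits = torch.randn(batch_size, seq_len, vocab_size)          # one score vector per input position
labels = torch.randint(0, vocab_size, (batch_size, seq_len))   # token ids used as targets

perplexity = Perplexity()

# A causal LM's logits at position t predict the token at position t + 1.
# Dropping the last logit and the first label pairs each prediction with the
# token it is actually trying to predict; the unshifted call compared each
# prediction with the token that produced it.
perplexity.update(logits[:, :-1], labels[:, 1:])
print(perplexity.compute())
```

Only the training_step call keeps .detach(), since those logits still carry gradients; validation runs without gradient tracking, so there only the shift changes.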
