diff --git a/cprt/model/cprt_model.py b/cprt/model/cprt_model.py index 9e0ee23..91a44b8 100644 --- a/cprt/model/cprt_model.py +++ b/cprt/model/cprt_model.py @@ -112,7 +112,7 @@ def training_step(self, batch: CprtData, batch_idx: int) -> Tensor: out = self(protein_ids=batch.protein, info_ids=batch.info, attention_mask=batch.info_mask, labels=batch.labels) loss: Tensor = out["loss"] self.log("loss/train_loss", loss.item(), prog_bar=True) - self.train_perplexity.update(out["logits"].detach(), batch.labels) + self.train_perplexity.update(out["logits"].detach()[:, :-1], batch.labels[:, 1:]) if batch_idx % self.trainer.val_check_interval == 0 and batch_idx != 0: self.log("metrics/train_perplexity", self.train_perplexity.compute(), on_step=True) self.train_perplexity.reset() @@ -129,7 +129,7 @@ def validation_step(self, batch: CprtData, batch_idx: int) -> None: ) loss: Tensor = out["loss"] self.log("loss/val_loss", loss.item(), prog_bar=True) - self.val_perplexity.update(out["logits"], batch.labels) + self.val_perplexity.update(out["logits"][:, :-1], batch.labels[:, 1:]) input_text = self.text_tokenizer.batch_decode(batch.info, skip_special_tokens=True) generated_text = self.text_tokenizer.batch_decode(torch.argmax(out["logits"], dim=-1), skip_special_tokens=True) self.val_rouge_scores.update(generated_text, input_text)