diff --git a/cprt/model/cprt_model.py b/cprt/model/cprt_model.py index f02ef8e..bc30b89 100644 --- a/cprt/model/cprt_model.py +++ b/cprt/model/cprt_model.py @@ -3,7 +3,7 @@ import torch from lightning import LightningModule from torch import FloatTensor, LongTensor, Tensor, nn -from torchmetrics.text import BERTScore, BLEUScore, Perplexity +from torchmetrics.text import Perplexity, ROUGEScore from transformers import GPT2LMHeadModel, GPT2Tokenizer from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions @@ -39,8 +39,7 @@ def __init__( self.train_perplexity = Perplexity(ignore_index=-100) self.val_perplexity = Perplexity(ignore_index=-100) - self.val_bleu_score = BLEUScore() - self.val_bert_scores = BERTScore() + self.val_rouge_scores = ROUGEScore() def _add_cross_attention_to_llm(self) -> None: """Add Cross-Attention layers to all decoder blocks.""" @@ -131,18 +130,17 @@ def validation_step(self, batch: CprtData, batch_idx: int) -> None: self.val_perplexity.update(out["logits"], batch.labels) input_text = self.text_tokenizer.batch_decode(batch.info, skip_special_tokens=True) generated_text = self.text_tokenizer.batch_decode(torch.argmax(out["logits"], dim=-1), skip_special_tokens=True) - self.val_bleu_score.update(generated_text, input_text) - # self.val_bert_scores.update(generated_text, input_text) + self.val_rouge_scores.update(generated_text, input_text) + if batch_idx == 0: + print(generated_text) torch.cuda.empty_cache() def on_validation_epoch_end(self) -> None: self.log("metrics/val_perplexity", self.val_perplexity.compute()) self.val_perplexity.reset() - self.log("metrics/val_bleu", self.val_bleu_score.compute()) - self.val_bleu_score.reset() - # bert_score: Dict[str, Tensor] = self.val_bert_scores.compute() # type: ignore[assignment] - # self.log_dict({f"metrics/val_bert_{k}": v.mean() for k, v in bert_score.items()}) - # self.val_bert_scores.reset() + rouge_scores: Dict[str, Tensor] = self.val_rouge_scores.compute() + self.log_dict({f"metrics/val_{k}": v.mean() for k, v in rouge_scores.items()}) + self.val_bert_scores.reset() for idx, layer in enumerate(self.cprt_llm.transformer.h): self.log(f"gates/layer_{idx}_attn_gate", layer.attn_gate.item()) self.log(f"gates/layer_{idx}_ff_gate", layer.ff_gate.item())