Skip to content

Commit

Permalink
change metrics to rouge
Browse files Browse the repository at this point in the history
  • Loading branch information
EMCarrami committed Oct 30, 2023
1 parent 96d83a2 commit 984ac04
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions cprt/model/cprt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from lightning import LightningModule
from torch import FloatTensor, LongTensor, Tensor, nn
from torchmetrics.text import BERTScore, BLEUScore, Perplexity
from torchmetrics.text import Perplexity, ROUGEScore
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

Expand Down Expand Up @@ -39,8 +39,7 @@ def __init__(

self.train_perplexity = Perplexity(ignore_index=-100)
self.val_perplexity = Perplexity(ignore_index=-100)
self.val_bleu_score = BLEUScore()
self.val_bert_scores = BERTScore()
self.val_rouge_scores = ROUGEScore()

def _add_cross_attention_to_llm(self) -> None:
"""Add Cross-Attention layers to all decoder blocks."""
Expand Down Expand Up @@ -131,18 +130,17 @@ def validation_step(self, batch: CprtData, batch_idx: int) -> None:
self.val_perplexity.update(out["logits"], batch.labels)
input_text = self.text_tokenizer.batch_decode(batch.info, skip_special_tokens=True)
generated_text = self.text_tokenizer.batch_decode(torch.argmax(out["logits"], dim=-1), skip_special_tokens=True)
self.val_bleu_score.update(generated_text, input_text)
# self.val_bert_scores.update(generated_text, input_text)
self.val_rouge_scores.update(generated_text, input_text)
if batch_idx == 0:
print(generated_text)
torch.cuda.empty_cache()

def on_validation_epoch_end(self) -> None:
self.log("metrics/val_perplexity", self.val_perplexity.compute())
self.val_perplexity.reset()
self.log("metrics/val_bleu", self.val_bleu_score.compute())
self.val_bleu_score.reset()
# bert_score: Dict[str, Tensor] = self.val_bert_scores.compute() # type: ignore[assignment]
# self.log_dict({f"metrics/val_bert_{k}": v.mean() for k, v in bert_score.items()})
# self.val_bert_scores.reset()
rouge_scores: Dict[str, Tensor] = self.val_rouge_scores.compute()
self.log_dict({f"metrics/val_{k}": v.mean() for k, v in rouge_scores.items()})
self.val_bert_scores.reset()
for idx, layer in enumerate(self.cprt_llm.transformer.h):
self.log(f"gates/layer_{idx}_attn_gate", layer.attn_gate.item())
self.log(f"gates/layer_{idx}_ff_gate", layer.ff_gate.item())
Expand Down

0 comments on commit 984ac04

Please sign in to comment.