small fixes
EMCarrami committed Nov 9, 2023
1 parent d4399da commit 9eab60c
Showing 2 changed files with 4 additions and 3 deletions.
4 changes: 2 additions & 2 deletions cprt/model/cprt_model.py
@@ -156,8 +156,8 @@ def on_validation_epoch_end(self) -> None:
         self.log_dict({f"metrics/val_{k}": v.mean() for k, v in rouge_scores.items()})
         self.val_rouge_scores.reset()
         for idx, layer in enumerate(self.cprt_llm.transformer.h):
-            self.log(f"gates/layer_{idx}_attn_gate", layer.attn_gate.item())
-            self.log(f"gates/layer_{idx}_ff_gate", layer.ff_gate.item())
+            self.log(f"gates/layer_{idx}_attn_gate", layer.cross_attn.attn_gate.item())
+            self.log(f"gates/layer_{idx}_ff_gate", layer.cross_attn.ff_gate.item())
         # log example outputs
         input_text = self.text_tokenizer.batch_decode(self.last_val_batch.info, skip_special_tokens=True)
         for in_txt, protein in zip(input_text, self.last_val_batch.protein):
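
For context, the corrected gate logging assumes each transformer block has been wrapped so that the learnable gates live on a nested cross-attention module rather than on the block itself. A minimal sketch of such a wrapper follows; the class names, gate initialisation, and module layout are assumptions for illustration, not taken from the repository.

import torch
from torch import nn


class GatedCrossAttention(nn.Module):
    """Hypothetical gated cross-attention block with learnable scalar gates."""

    def __init__(self, hidden_size: int, num_heads: int = 8) -> None:
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )
        # Scalar gates initialised to zero so the wrapped block starts as an identity map.
        self.attn_gate = nn.Parameter(torch.zeros(1))
        self.ff_gate = nn.Parameter(torch.zeros(1))

    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        attn_out, _ = self.attn(hidden_states, encoder_hidden_states, encoder_hidden_states)
        hidden_states = hidden_states + self.attn_gate.tanh() * attn_out
        return hidden_states + self.ff_gate.tanh() * self.ff(hidden_states)


class CrossAttentionBlockWrapper(nn.Module):
    """Hypothetical wrapper sitting in cprt_llm.transformer.h; the gates live on `cross_attn`."""

    def __init__(self, gpt2_block: nn.Module, hidden_size: int) -> None:
        super().__init__()
        self.block = gpt2_block
        self.cross_attn = GatedCrossAttention(hidden_size)

With a layout like this the gates are only reachable through the nested module, so the logging loop must read layer.cross_attn.attn_gate.item() and layer.cross_attn.ff_gate.item(), as in the fix above.
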
3 changes: 2 additions & 1 deletion cprt/model/helper_modules.py
@@ -48,6 +48,7 @@ def forward(
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
     ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
@@ -73,7 +74,7 @@ def forward(
             attention_mask=attention_mask,
             head_mask=head_mask,
             encoder_hidden_states=None,
-            encoder_attention_mask=None,
+            encoder_attention_mask=encoder_attention_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
         )
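
This second change stops discarding the new encoder_attention_mask argument and forwards it to the wrapped block, so padded positions in the encoder sequence can be excluded from cross-attention. Block-level attention masks in transformers-style models are usually additive (0.0 for positions to keep, a large negative value for padding); a small, hypothetical helper illustrating that conversion is sketched below — the function name and the boolean padding-mask input are assumptions, not part of the repository.

import torch


def build_encoder_attention_mask(pad_mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Convert a boolean padding mask (True = real token) into an additive attention mask.

    Kept positions map to 0.0; padded positions map to the most negative
    representable value, so softmax assigns them (near-)zero attention weight.
    """
    # (batch, seq_len) -> (batch, 1, 1, seq_len) so the mask broadcasts over heads and query positions.
    mask = pad_mask[:, None, None, :].to(dtype)
    return (1.0 - mask) * torch.finfo(dtype).min


# Example: a batch of two encoder sequences, the second padded after three tokens.
pad_mask = torch.tensor([[True, True, True, True], [True, True, True, False]])
encoder_attention_mask = build_encoder_attention_mask(pad_mask)

A mask built this way can then be passed as encoder_attention_mask to the forward call in place of the hard-coded None removed above.
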
