Skip to content

Commit

Permalink
Merge pull request #34 from foreverpiano/main
Browse files Browse the repository at this point in the history
update attn_mask
  • Loading branch information
Kyriection authored Jun 18, 2024
2 parents 281ffef + d5962aa commit ac75c2a
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 3 deletions.
3 changes: 2 additions & 1 deletion h2o_hf/utils_hh/modify_gptneox.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,13 +222,14 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
else:
# activate historical best self.cache_budget - self.recent_budget tokens.
# self.previous_scores # (k-Cache - 1)
attn_mask[:, :] = 0
selected_set = self.previous_scores

if not self.heavy_budget == 0:
_, keep_topk = selected_set.topk(k=self.heavy_budget, dim=-1, largest=True)
attn_mask = attn_mask.scatter(-1, keep_topk, 1)

- self.attention_masks_next = attn_mask.unsqueeze(0).unsqueeze(2)
+ self.attention_masks_next = attn_mask.clone().unsqueeze(0).unsqueeze(2)
score_mask = attn_mask[:,:-1]
score_mask[:, -self.recent_budget:] = 1
self.previous_scores = self.previous_scores * score_mask
Expand Down
3 changes: 2 additions & 1 deletion h2o_hf/utils_hh/modify_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,14 @@ def forward(
else:
# activate historical best self.cache_budget - self.recent_budget tokens.
# self.previous_scores # (k-Cache - 1)
attn_mask[:, :] = 0
selected_set = self.previous_scores

if not self.heavy_budget == 0:
_, keep_topk = selected_set.topk(k=self.heavy_budget, dim=-1, largest=True)
attn_mask = attn_mask.scatter(-1, keep_topk, 1)

- self.attention_masks_next = attn_mask.unsqueeze(0).unsqueeze(2)
+ self.attention_masks_next = attn_mask.clone().unsqueeze(0).unsqueeze(2)

score_mask = attn_mask[:,:-1]
score_mask[:, -self.recent_budget:] = 1
Expand Down
3 changes: 2 additions & 1 deletion h2o_hf/utils_hh/modify_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,14 @@ def forward(
else:
# activate historical best self.cache_budget - self.recent_budget tokens.
# self.previous_scores # (k-Cache - 1)
attn_mask[:, :] = 0
selected_set = self.previous_scores

if not self.heavy_budget == 0:
_, keep_topk = selected_set.topk(k=self.heavy_budget, dim=-1, largest=True)
attn_mask = attn_mask.scatter(-1, keep_topk, 1)

- self.attention_masks_next = attn_mask.unsqueeze(1)
+ self.attention_masks_next = attn_mask.clone().unsqueeze(1)

score_mask = attn_mask[:,:-1]
score_mask[:, -self.recent_budget:] = 1
Expand Down

0 comments on commit ac75c2a

Please sign in to comment.