diff --git a/h2o_hf/utils_hh/modify_gptneox.py b/h2o_hf/utils_hh/modify_gptneox.py index 96cb495..dc2df97 100644 --- a/h2o_hf/utils_hh/modify_gptneox.py +++ b/h2o_hf/utils_hh/modify_gptneox.py @@ -222,13 +222,14 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): else: # activate historical best self.cache_budget - self.recent_budget tokens. # self.previous_scores # (k-Cache - 1) + attn_mask[:, :] = 0 selected_set = self.previous_scores if not self.heavy_budget == 0: _, keep_topk = selected_set.topk(k=self.heavy_budget, dim=-1, largest=True) attn_mask = attn_mask.scatter(-1, keep_topk, 1) - self.attention_masks_next = attn_mask.unsqueeze(0).unsqueeze(2) + self.attention_masks_next = attn_mask.clone().unsqueeze(0).unsqueeze(2) score_mask = attn_mask[:,:-1] score_mask[:, -self.recent_budget:] = 1 self.previous_scores = self.previous_scores * score_mask diff --git a/h2o_hf/utils_hh/modify_llama.py b/h2o_hf/utils_hh/modify_llama.py index 8ec5533..ea965cd 100644 --- a/h2o_hf/utils_hh/modify_llama.py +++ b/h2o_hf/utils_hh/modify_llama.py @@ -149,13 +149,14 @@ def forward( else: # activate historical best self.cache_budget - self.recent_budget tokens. # self.previous_scores # (k-Cache - 1) + attn_mask[:, :] = 0 selected_set = self.previous_scores if not self.heavy_budget == 0: _, keep_topk = selected_set.topk(k=self.heavy_budget, dim=-1, largest=True) attn_mask = attn_mask.scatter(-1, keep_topk, 1) - self.attention_masks_next = attn_mask.unsqueeze(0).unsqueeze(2) + self.attention_masks_next = attn_mask.clone().unsqueeze(0).unsqueeze(2) score_mask = attn_mask[:,:-1] score_mask[:, -self.recent_budget:] = 1 diff --git a/h2o_hf/utils_hh/modify_opt.py b/h2o_hf/utils_hh/modify_opt.py index df455ee..29f70cd 100644 --- a/h2o_hf/utils_hh/modify_opt.py +++ b/h2o_hf/utils_hh/modify_opt.py @@ -180,13 +180,14 @@ def forward( else: # activate historical best self.cache_budget - self.recent_budget tokens. # self.previous_scores # (k-Cache - 1) + attn_mask[:, :] = 0 selected_set = self.previous_scores if not self.heavy_budget == 0: _, keep_topk = selected_set.topk(k=self.heavy_budget, dim=-1, largest=True) attn_mask = attn_mask.scatter(-1, keep_topk, 1) - self.attention_masks_next = attn_mask.unsqueeze(1) + self.attention_masks_next = attn_mask.clone().unsqueeze(1) score_mask = attn_mask[:,:-1] score_mask[:, -self.recent_budget:] = 1