Commit 696138b
refactoring (#5)
kssteven418 authored Dec 23, 2021
1 parent 451c6e3 commit 696138b
Showing 3 changed files with 3 additions and 14 deletions.
README.md (2 changes: 0 additions & 2 deletions)
@@ -38,7 +38,6 @@ Add the following lines in the configuration file `{CKPT}/config.json`.
 ```
 "prune_mode": "absolute_threshold",
 "final_token_threshold": 0.01,
-"scoring_mode": "mean",
 ```

 `final_token_threshold` determines the token threshold of the last layer, and the thresholds of the remaining layers will be linearly scaled.
@@ -116,7 +115,6 @@ Add the following lines in `{CKPT}/config.json`.
 ```
 "prune_mode": "absolute_threshold",
 "final_token_threshold": 0.01,
-"scoring_mode": "mean",
 ```

 Run the following command:
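The "linearly scaled" thresholds mentioned in the README hunk above come straight from the pruner construction in this commit (`final_token_threshold * module_num / num_hidden_layers` in `prune_modules.py`). A minimal sketch of that scaling, assuming a 12-layer encoder; `layer_threshold` is a hypothetical helper name, not part of the repository:

```python
# Sketch of the linear threshold scaling described in the README.
# Mirrors final_token_threshold * module_num / num_hidden_layers from
# prune_modules.py; the helper name and 12-layer setup are illustrative.
final_token_threshold = 0.01
num_hidden_layers = 12

def layer_threshold(module_num: int) -> float:
    return final_token_threshold * module_num / num_hidden_layers

for layer in range(num_hidden_layers):
    print(f"layer {layer}: base threshold {layer_threshold(layer):.6f}")
```

Whether the top layer's base threshold lands exactly on `final_token_threshold` depends on how the repository indexes `module_num`; the learned offset `keep_threshold` is added on top of this base at each layer.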
src/transformers/models/ltp/configuration_ltp.py (5 changes: 1 addition & 4 deletions)
@@ -115,7 +115,6 @@ def __init__(
         prune_mode=None,
         token_keep_rate=1,
         token_threshold=0,
-        scoring_mode='mean',
         final_token_threshold=0,
         **kwargs
     ):
@@ -152,7 +151,5 @@ def __init__(
             self.final_token_threshold = final_token_threshold
         elif self.prune_mode == 'absolute_threshold':
             self.prune_kwargs = {'final_token_threshold': final_token_threshold,
-                                 'num_hidden_layers': num_hidden_layers,
-                                 'scoring_mode': scoring_mode}
+                                 'num_hidden_layers': num_hidden_layers}
             self.final_token_threshold = final_token_threshold
-            self.scoring_mode = scoring_mode
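With this hunk, an absolute-threshold configuration no longer stores a scoring mode. A minimal usage sketch, assuming the surrounding class is named `LTPConfig` (inferred from the file name, not confirmed by the diff) and that the BERT-style default of 12 hidden layers applies:

```python
# Hypothetical usage of the refactored config; the class name LTPConfig and
# the num_hidden_layers=12 default are assumptions, not shown in this diff.
from transformers.models.ltp.configuration_ltp import LTPConfig

config = LTPConfig(prune_mode='absolute_threshold', final_token_threshold=0.01)

# prune_kwargs now carries only the threshold and the layer count:
print(config.prune_kwargs)
# e.g. {'final_token_threshold': 0.01, 'num_hidden_layers': 12}
```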
src/transformers/models/ltp/prune_modules.py (10 changes: 2 additions & 8 deletions)
@@ -117,14 +117,13 @@ class AbsoluteThresholdTokenPruner(AbstractTokenPruner):
     implements the layer-by-layer operations for threshold token pruning, where tokens are pruned if the importance
     score is strictly less than a given fraction of the maximum token importance score
     """
-    def __init__(self, module_num, final_token_threshold=None, num_hidden_layers=None, scoring_mode='mean', **kwargs):
+    def __init__(self, module_num, final_token_threshold=None, num_hidden_layers=None, **kwargs):
         super().__init__()
         self.keep_threshold_base = torch.tensor(final_token_threshold * module_num / num_hidden_layers, device='cuda')
         self.keep_threshold = nn.Parameter(
             torch.zeros_like(self.keep_threshold_base, device='cuda'),
             requires_grad=True,
         )
-        self.scoring_mode = scoring_mode
         self.module_num = module_num

         logger.info("Layer %d Threshold: %f" % (module_num, float(self.keep_threshold_base + self.keep_threshold)))
@@ -139,12 +138,7 @@ def update_attention_mask(self, attention_mask, attention_probs, sentence_length
         # compute the pruning scores by summing the attention probabilities over all heads
         attention_mask_index = (attention_mask < 0).permute(0, 1, 3, 2).repeat(1, attention_probs.shape[1], 1, sz)
         attention_probs[attention_mask_index] = 0
-        if self.scoring_mode == 'mean':
-            pruning_scores = attention_probs.view(batch_size, -1, sz).mean(dim=1)
-        elif self.scoring_mode == 'mean_norm':
-            seqlen = (attention_mask.squeeze(1).squeeze(1) == 0).sum(-1)
-            pruning_scores = attention_probs.sum(2) / seqlen.reshape(-1, 1, 1)  # BS, H, SL
-            pruning_scores = pruning_scores.mean(1)
+        pruning_scores = attention_probs.view(batch_size, -1, sz).mean(dim=1)

         new_attention_mask = torch.zeros(attention_mask.shape, device=attention_mask.device)
         new_attention_mask[pruning_scores.unsqueeze(1).unsqueeze(1) < max(1e-5, keep_threshold)] = -10000
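With the `scoring_mode` branch removed, only the former `'mean'` path survives: every token's score is its attention mass averaged over all heads and query positions. A self-contained sketch of that retained path, with random tensors and a scalar `keep_threshold` standing in for the layer's real inputs:

```python
import torch

# Illustrative shapes; real values come from the encoder layer.
batch_size, num_heads, sz = 2, 12, 16  # sz = sequence length

attention_probs = torch.rand(batch_size, num_heads, sz, sz)
attention_mask = torch.zeros(batch_size, 1, 1, sz)  # 0 = keep, -10000 = masked
keep_threshold = 0.004  # stands in for keep_threshold_base + learned offset

# Zero out attention coming from already-masked (padded) positions,
# matching the attention_mask_index step in the diff.
attention_mask_index = (attention_mask < 0).permute(0, 1, 3, 2).repeat(1, num_heads, 1, sz)
attention_probs[attention_mask_index] = 0

# Retained 'mean' scoring: average attention each token receives,
# pooled over all heads and query positions.
pruning_scores = attention_probs.view(batch_size, -1, sz).mean(dim=1)  # (B, sz)

# Tokens scoring below the floored threshold are masked out for later layers.
new_attention_mask = torch.zeros(attention_mask.shape)
new_attention_mask[pruning_scores.unsqueeze(1).unsqueeze(1) < max(1e-5, keep_threshold)] = -10000
```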
