-
Notifications
You must be signed in to change notification settings - Fork 244
/
Copy path(ICCV 2021) RA.py
35 lines (23 loc) · 927 Bytes
/
(ICCV 2021) RA.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import numpy as np
import torch
from torch import nn
from torch.nn import init
# Paper link: https://arxiv.org/pdf/2108.02456
# Paper: Residual Attention: A Simple but Effective Method for Multi-Label Recognition
class ResidualAttention(nn.Module):
    """Residual Attention classification head (ICCV 2021).

    A 1x1 convolution projects a feature map to per-class spatial logits;
    the final score for each class is the spatial average of its logits
    plus ``la`` times the spatial maximum (class-specific residual attention).

    Args:
        channel: number of input feature channels.
        num_class: number of output classes.
        la: weight (lambda) applied to the max-pooled branch.
    """

    def __init__(self, channel=512, num_class=1000, la=0.2):
        super().__init__()
        self.la = la
        # 1x1 conv == a per-position linear classifier over the channels.
        self.fc = nn.Conv2d(
            in_channels=channel,
            out_channels=num_class,
            kernel_size=1,
            stride=1,
            bias=False,
        )

    def forward(self, x):
        # x is expected to be a (batch, channel, h, w) feature map.
        batch = x.shape[0]
        logits = self.fc(x)                                # (b, num_class, h, w)
        flat = logits.view(batch, logits.size(1), -1)      # (b, num_class, h*w)
        avg_pool = flat.mean(dim=2)                        # average over positions
        max_pool, _ = flat.max(dim=2)                      # strongest position
        return avg_pool + self.la * max_pool               # (b, num_class)
if __name__ == '__main__':
    # Smoke test: one forward pass on a dummy ResNet-style 7x7 feature map.
    dummy = torch.randn(50, 512, 7, 7)
    model = ResidualAttention(channel=512, num_class=1000, la=0.2)
    scores = model(dummy)
    print(scores.shape)