-
Notifications
You must be signed in to change notification settings - Fork 244
/
Copy path(ACCV 2024) LIA.py
51 lines (44 loc) · 1.86 KB
/
(ACCV 2024) LIA.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
import torch.nn as nn
import torch.nn.functional as F
# Paper: PlainUSR: Chasing Faster ConvNet for Efficient Super-Resolution (ACCV 2024)
# Paper URL: https://openaccess.thecvf.com/content/ACCV2024/papers/Wang_PlainUSR_Chasing_Faster_ConvNet_for_Efficient_Super-Resolution_ACCV_2024_paper.pdf
class SoftPooling2D(torch.nn.Module):
    """Exponentially weighted average pooling (soft pooling).

    Each output element is sum(exp(x_i) * x_i) / sum(exp(x_i)) over the
    pooling window, computed here as avgpool(exp(x) * x) / avgpool(exp(x));
    the shared 1/N factor of the two average pools cancels in the ratio.
    """

    def __init__(self, kernel_size, stride=None, padding=0):
        super().__init__()
        # count_include_pad=False keeps zero padding out of both averages,
        # so the cancellation above stays exact at the borders.
        self.avgpool = torch.nn.AvgPool2d(
            kernel_size, stride, padding, count_include_pad=False
        )

    def forward(self, x):
        # Per-element importance weights; larger activations dominate the pool.
        weights = torch.exp(x)
        return self.avgpool(weights * x) / self.avgpool(weights)
class LocalAttention(nn.Module):
''' attention based on local importance'''
def __init__(self, channels, f=16):
super().__init__()
self.body = nn.Sequential(
# sample importance
nn.Conv2d(channels, f, 1),
SoftPooling2D(7, stride=3),
nn.Conv2d(f, f, kernel_size=3, stride=2, padding=1),
nn.Conv2d(f, channels, 3, padding=1),
# to heatmap
nn.Sigmoid(),
)
self.gate = nn.Sequential(
nn.Sigmoid(),
)
def forward(self, x):
''' forward '''
# interpolate the heat map
g = self.gate(x[:,:1].clone())
w = F.interpolate(self.body(x), (x.size(2), x.size(3)), mode='bilinear', align_corners=False)
return x * w * g #(w + g) #self.gate(x, w)
if __name__ == '__main__':
    # Smoke test: run LocalAttention on a random batch and print the
    # input/output shapes (the module is shape-preserving).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    block = LocalAttention(channels=32).to(device)
    # `x` instead of `input` (avoid shadowing the builtin); allocate the
    # tensor directly on the target device rather than CPU + copy.
    x = torch.rand(1, 32, 256, 256, device=device)
    output = block(x)
    print(x.shape)
    print(output.shape)