import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from transformers.models.clap.modeling_clap import window_partition


try:
    import os, sys

    kernel_path = os.path.abspath(os.path.join('..'))
    sys.path.append(kernel_path)
    # Optional fused CUDA kernels (import path as used in the official Swin Transformer repo).
    from kernels.window_process.window_process import WindowProcess, WindowProcessReverse
except:
    # Fall back to the pure-PyTorch shift + window partition path.
    WindowProcess = None
    WindowProcessReverse = None


class MLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.drop(self.act(self.fc1(x)))
        x = self.drop(self.fc2(x))
        return x
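
# Quick shape check (illustrative): the MLP acts only on the channel dimension,
# so the number of tokens is preserved.
#   mlp = MLP(in_features=96, hidden_features=384)
#   mlp(torch.randn(1, 3136, 96)).shape   # -> torch.Size([1, 3136, 96])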


# 2. Window-based multi-head self-attention (W-MSA)
class WindowAttention(nn.Module):  # the number of heads grows with the stage so that the per-head dimension stays constant
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # Relative position bias table: each of the row/column offsets can take 2M-1 distinct values,
        # so there are (2M-1) * (2M-1) possible (dy, dx) combinations, one bias per head.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )

        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # meshgrid builds the 2-D grid; stack -> (2, Wh, Ww)
        coords_flatten = torch.flatten(coords, 1)  # flatten the spatial dims -> (2, Wh*Ww)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # broadcasting: (2, Wh*Ww, 1) - (2, 1, Wh*Ww) gives all pairwise offsets
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift row offsets so they start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1  # shift column offsets so they start from 0
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # final flattened relative position index
        self.register_buffer("relative_position_index", relative_position_index)  # fixed lookup table, not learned

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Truncated-normal initialization: values outside the truncation range are resampled,
        # which helps keep training stable.
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        B_, N, C = x.shape  # (num_windows * B, tokens per window, channels)
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each: (B_, num_heads, N, head_dim)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
        )
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)  # mask (nW, N, N) broadcast over batch and heads
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
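

# Quick sanity check (illustrative): 8 windows of 7x7 = 49 tokens, dim 96 split over
# 3 heads; the output keeps the (num_windows*B, N, C) shape.
#   attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
#   attn(torch.randn(8, 49, 96)).shape   # -> torch.Size([8, 49, 96])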


class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, mlp_ratio=4., qkv_bias=True,
                 qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 fused_window_process=False):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.input_resolution = input_resolution
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        # If the input resolution is no larger than the window, shrink the window and skip the shift.
        if min(self.input_resolution) <= self.window_size:
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
            qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # 3. Attention mask for shifted windows
        if self.shift_size > 0:
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))  # effectively three cuts along each axis (see the diagram)

            cnt = 0
            # Give each of the resulting regions its own label (see the diagram).
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            # Partition the mask into windows, flatten, and subtract it from itself: positions
            # where the labels differ are filled with -100, i.e. those token pairs come from
            # regions that are far apart in the original image and should not attend to each other.
            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
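
            # Worked example (illustrative): H = W = 8, window_size = 4, shift_size = 2.
            # The three slices split the rows/columns into regions of size 4, 2 and 2, giving
            # a 3x3 grid of labels 0..8. The bottom-right 4x4 window then mixes labels
            # {4, 5, 7, 8}, so every token pair with different labels in that window gets
            # -100 and is effectively removed from the softmax.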
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)
        self.fused_window_process = fused_window_process

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
                x_windows = window_partition(shifted_x, self.window_size)
            else:
                x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
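
        # In the reference Swin implementation the forward pass continues roughly as sketched
        # below for the shifted-window case (window_reverse is the inverse of window_partition,
        # assumed to be available or written by hand):
        #   x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        #   attn_windows = self.attn(x_windows, mask=self.attn_mask)      # W-MSA / SW-MSA
        #   attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        #   shifted_x = window_reverse(attn_windows, self.window_size, H, W)
        #   x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        #   x = shortcut + self.drop_path(x.view(B, H * W, C))            # residual 1
        #   x = x + self.drop_path(self.mlp(self.norm2(x)))               # residual 2 (MLP)
        #   return x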


# 1. Patch embedding
class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)  # turn a single int into a 2-tuple, e.g. 224 -> (224, 224)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution  # kept for flops()
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape  # e.g. 1 * 3 * 224 * 224
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H} * {W}) does not match model ({self.img_size[0]} * {self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # e.g. 1 * 3136 * 96
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):  # count floating-point operations
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
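
# Quick sanity check (illustrative): a 224x224 RGB image becomes 56*56 = 3136 patch
# tokens of dimension 96, matching the shapes noted in forward().
#   embed = PatchEmbed(img_size=224, patch_size=4, in_chans=3, embed_dim=96)
#   embed(torch.randn(1, 3, 224, 224)).shape   # -> torch.Size([1, 3136, 96])
#   embed.flops()                              # -> 56*56*96*3*16 = 14,450,688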