import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from transformers.models.clap.modeling_clap import window_partition


try:
    import os, sys

    kernel_path = os.path.abspath(os.path.join('..'))
    sys.path.append(kernel_path)
    # Optional fused CUDA window-process kernels, as shipped with the official Swin Transformer
    # repo; if they are not available, the except branch falls back to the pure-PyTorch path
    from kernels.window_process.window_process import WindowProcess, WindowProcessReverse
except:
    WindowProcess = None
    WindowProcessReverse = None


class MLP(nn.Module):
    # Standard two-layer feed-forward block applied after window attention in each Swin block
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.drop(self.act(self.fc1(x)))
        x = self.drop(self.fc2(x))
        return x


# 2. Window self-attention
class WindowAttention(nn.Module):  # the number of heads grows with the stage so that the per-head dimension stays fixed
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # Relative position bias table: the x and y offsets each take 2M-1 values, so there are
        # (2M-1) * (2M-1) possible relative positions, with one bias per head for each of them
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )

        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # meshgrid builds the 2D grid coordinates, stack joins them
        coords_flatten = torch.flatten(coords, 1)  # flatten the 2D coordinates
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # broadcasting: (2, Wh*Ww, 1) - (2, 1, Wh*Ww) gives the raw relative offsets
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # final flattened relative position index
        self.register_buffer("relative_position_index", relative_position_index)  # a fixed buffer, not a learned parameter

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Truncated normal initialization discards values outside the truncation range, which helps training stability
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
        )
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
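

# A minimal usage sketch (an added example, not part of the original code): run WindowAttention over a
# batch of 7x7 windows and check that the token shape is preserved. Shapes follow the
# (num_windows*B, window_size*window_size, C) convention used by the forward pass above.
def _demo_window_attention():
    attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
    windows = torch.randn(8, 7 * 7, 96)        # 8 windows of 49 tokens, 96 channels each
    out = attn(windows)                        # the relative position bias is added inside forward
    print(out.shape)                           # torch.Size([8, 49, 96])
    print(attn.relative_position_index.shape)  # torch.Size([49, 49]), entries index the (2*7-1)**2 = 169-row bias table
    return out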


class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, mlp_ratio=4., qkv_bias=True,
                 qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 fused_window_process=False):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.input_resolution = input_resolution
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        # If the input resolution is no larger than the window, shifting is unnecessary and the window shrinks to fit
        if min(self.input_resolution) <= self.window_size:
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
            qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # 3. Attention mask for the shifted windows
        if self.shift_size > 0:
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))  # equivalent to cutting the feature map into three bands along each axis (see the diagram)

            cnt = 0
            # Label each region with a different id (see the diagram)
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            # Flatten the mask per window and subtract it from itself: non-zero entries are filled with -100,
            # meaning those token pairs come from different regions and should not attend to each other
            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)
        self.fused_window_process = fused_window_process

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
                x_windows = window_partition(shifted_x, self.window_size)
            else:
                x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
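

# A small standalone sketch (an added illustration, not part of the original code): rebuild the shifted-window
# mask above for a tiny 4x4 resolution with window_size=2 and shift_size=1, so the region labels
# and the -100 entries can be printed and inspected directly.
def _demo_shifted_window_mask(H=4, W=4, window_size=2, shift_size=1):
    img_mask = torch.zeros((1, H, W, 1))
    slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
    cnt = 0
    for h in slices:
        for w in slices:
            img_mask[:, h, w, :] = cnt  # label the 3 x 3 = 9 regions produced by the two cuts per axis
            cnt += 1
    mask_windows = window_partition(img_mask, window_size).view(-1, window_size * window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    return img_mask[0, :, :, 0], attn_mask  # (H, W) region labels and (num_windows, ws*ws, ws*ws) mask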


# 1. Patch embedding
class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)  # convert the input to a 2-tuple, e.g. (224, 224)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution  # needed later by flops()
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape  # e.g. (1, 3, 224, 224)
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H} * {W}) does not match model ({self.img_size[0]} * {self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # e.g. (1, 3136, 96): 56 * 56 patches, each a 96-dim embedding
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):  # count the floating point operations of the patch embedding
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
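

# A minimal usage sketch (an added example, not part of the original code): embed a dummy 224x224 image and
# check that it yields the 56 * 56 = 3136 patch tokens of dimension 96 noted in the comments above.
def _demo_patch_embed():
    patch_embed = PatchEmbed(img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=nn.LayerNorm)
    dummy_image = torch.randn(1, 3, 224, 224)
    tokens = patch_embed(dummy_image)
    print(tokens.shape)         # torch.Size([1, 3136, 96])
    print(patch_embed.flops())  # FLOPs of the projection (plus the optional norm) for one image
    return tokens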