swin_transformer_model.WindowAttention.forward()   A
last analyzed

Complexity

Conditions 2

Size

Total Lines 28
Code Lines 21

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 21
nop 3
dl 0
loc 28
rs 9.376
c 0
b 0
f 0
1
import torch
2
import torch.nn as nn
3
import torch.utils.checkpoint as checkpoint
4
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
5
from transformers.models.clap.modeling_clap import window_partition
6
7
8
import os
import sys

# The fused window-process kernels are optional and nothing in this module
# imports them. Bind both names to None unconditionally so that later code
# can test for availability instead of hitting a NameError (previously the
# names were only assigned inside an `except` branch that is practically
# never taken).
WindowProcess = None
WindowProcessReverse = None

try:
    # Best effort: make the parent directory importable.
    kernel_path = os.path.abspath(os.path.join('..'))
    sys.path.append(kernel_path)
except OSError:
    # os.path.abspath() needs the current working directory, which may have
    # been removed; the path tweak is optional, so carry on without it.
    pass
18
19
20
21
class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when omitted.
        out_features: Size of the last dimension of the output; falls back to
            ``in_features`` when omitted.
        act_layer: Callable that builds the activation module (called with no
            arguments).
        drop: Dropout probability applied after the activation and after the
            second linear layer.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x=None):
        """Apply the perceptron to ``x`` along its last dimension.

        ``x`` defaults to ``None`` only so that the earlier zero-argument
        ``forward()`` call keeps working; in that case ``None`` is returned.
        """
        if x is None:
            return None
        x = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(x))
27
28
29
30
# 2. Window-based self-attention
class WindowAttention(nn.Module):
    """Multi-head self-attention within a window, with a learned relative position bias.

    The number of heads is meant to vary from stage to stage so that the
    number of channels handled by each head stays the same.

    Args:
        dim: Number of input (and output) channels.
        window_size: Window height and width, read as ``window_size[0]`` and
            ``window_size[1]``.
        num_heads: Number of attention heads; each head works on
            ``dim // num_heads`` channels.
        qkv_bias: Whether the fused query/key/value projection has a bias.
        qk_scale: Scale applied to the queries; when falsy,
            ``head_dim ** -0.5`` is used.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability on the projected output.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        win_h, win_w = window_size[0], window_size[1]

        # Learnable bias table with one row per possible (dy, dx) offset.
        # Each offset component spans 2 * size - 1 values, hence
        # (2 * win_h - 1) * (2 * win_w - 1) rows.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads)
        )

        # Row and column of every token in row-major window order. This is
        # built by broadcasting rather than torch.meshgrid, which warns when
        # called without an explicit `indexing` argument on recent PyTorch.
        rows = torch.arange(win_h).repeat_interleave(win_w)
        cols = torch.arange(win_w).repeat(win_h)
        # Pairwise offsets, shifted so that they start at zero.
        rel_rows = rows[:, None] - rows[None, :] + (win_h - 1)
        rel_cols = cols[:, None] - cols[None, :] + (win_w - 1)
        # Fold the 2-D offset into a single row index of the bias table.
        relative_position_index = rel_rows * (2 * win_w - 1) + rel_cols
        # A buffer, not a parameter: the index is fixed and never trained.
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Truncated-normal initialisation keeps the initial biases small and
        # free of outliers, which helps training stability.
        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """Run windowed self-attention.

        Args:
            x: Tensor of shape ``(B_, N, C)``. ``N`` has to match
                ``window_size[0] * window_size[1]`` for the position bias to
                be added.
            mask: Optional additive mask whose first dimension is the number
                of windows ``nW``; it is broadcast over the batch groups and
                heads, so it is expected to be ``(nW, N, N)``.

        Returns:
            Tensor of shape ``(B_, N, C)``.
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q * self.scale) @ k.transpose(-2, -1)

        # Look up the bias for every token pair and move heads to the front.
        tokens = self.window_size[0] * self.window_size[1]
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(tokens, tokens, -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # Group the batch by window so that each window gets its own mask.
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)

        attn = self.attn_drop(self.softmax(attn))

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
94
95
96
97
98
99
100
class SwinTransformerBlock(nn.Module):
    """Swin Transformer block: (shifted-)window attention plus an MLP.

    Args:
        dim: Number of channels.
        input_resolution: ``(H, W)`` of the token grid.
        num_heads: Number of attention heads.
        window_size: Side length of the square attention window.
        shift_size: Cyclic shift applied before partitioning; ``0`` disables
            shifting.
        mlp_ratio: The MLP hidden width is ``int(dim * mlp_ratio)``.
        qkv_bias: Forwarded to ``WindowAttention``.
        qk_scale: Forwarded to ``WindowAttention``.
        drop: Dropout probability for the attention projection and the MLP.
        attn_drop: Dropout probability on the attention weights.
        drop_path: ``DropPath`` rate; ``nn.Identity`` is used when it is not
            positive.
        act_layer: Activation constructor handed to the MLP.
        norm_layer: Normalisation constructor, called with ``dim``.
        fused_window_process: Use the optional ``WindowProcess`` kernel
            instead of ``torch.roll`` followed by ``window_partition``.
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, mlp_ratio=4., qkv_bias=True,
                 qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 fused_window_process=False):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.input_resolution = input_resolution
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        # If the input is not larger than the window, do not shift and shrink
        # the window to the input.
        if min(self.input_resolution) <= self.window_size:
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
            qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # 3. Attention mask for the shifted configuration
        attn_mask = self._build_attn_mask() if self.shift_size > 0 else None
        self.register_buffer("attn_mask", attn_mask)
        self.fused_window_process = fused_window_process

    def _build_attn_mask(self):
        """Build the additive mask that separates regions joined by the cyclic shift."""
        H, W = self.input_resolution
        img_mask = torch.zeros((1, H, W, 1))
        # Three bands per axis (two cuts each); rows and columns use the same
        # bands.
        bands = (slice(0, -self.window_size),
                 slice(-self.window_size, -self.shift_size),
                 slice(-self.shift_size, None))

        # Give every (row band, column band) region its own label.
        cnt = 0
        for h in bands:
            for w in bands:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        # Token pairs with different labels get -100 so that softmax
        # suppresses them; pairs from the same region get 0.
        mask_windows = window_partition(img_mask, self.window_size)
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        return attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

    def forward(self, x):
        """Normalise ``x`` and split the (shifted) feature map into windows.

        NOTE(review): this method is unfinished in this file. It only builds
        ``x_windows`` for the shifted case, has no branch for
        ``shift_size == 0`` and returns nothing.

        Args:
            x: Tensor of shape ``(B, H * W, C)``.

        Raises:
            RuntimeError: If ``fused_window_process`` is set but the optional
                ``WindowProcess`` kernel is not available.
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x  # not used yet; presumably for the residual connection — TODO confirm
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
                x_windows = window_partition(shifted_x, self.window_size)
            else:
                # The kernel is optional; look it up defensively so that a
                # missing or None binding gives a clear error.
                window_process = globals().get("WindowProcess")
                if window_process is None:
                    raise RuntimeError(
                        "fused_window_process=True requires the optional WindowProcess kernel, "
                        "which is not available"
                    )
                x_windows = window_process.apply(x, B, H, W, C, -self.shift_size, self.window_size)
0 ignored issues
show
introduced by
The variable WindowProcess does not seem to be defined for all execution paths.
Loading history...
172
173
174
175
176
# 1. Patch embedding
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and project each one to ``embed_dim`` channels.

    Args:
        img_size: Input image size; expanded to a 2-tuple with ``to_2tuple``.
        patch_size: Patch size; expanded to a 2-tuple with ``to_2tuple``.
        in_chans: Number of input image channels.
        embed_dim: Number of output channels per patch.
        norm_layer: Optional normalisation constructor, called with
            ``embed_dim``.
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)  # e.g. 224 -> (224, 224)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        # flops() reads this attribute, so it has to be stored on the instance.
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # Kernel size equal to stride: each patch is projected independently.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Tensor of shape ``(B, C, H, W)``, e.g. ``1 x 3 x 224 x 224``;
                ``H`` and ``W`` must equal ``img_size``.

        Returns:
            Tensor of shape ``(B, num_patches, embed_dim)``, e.g.
            ``1 x 3136 x 96`` for the default arguments.
        """
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1],\
            f"Input image size ({H} * {W}) does not match model ({self.img_size[0]} * {self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        """Count the floating-point operations of the projection, plus the norm if present."""
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops