SwinTransformerBlock.__init__()   B

Complexity

Conditions: 6

Size

Total lines: 55
Code lines: 43

Duplication

Duplicated lines: 0
Duplication ratio: 0 %

Importance

Changes: 0

Metric                           Value
cc   (cyclomatic complexity)     6
eloc (effective lines of code)   43
nop  (number of parameters)      15
dl   (duplicated lines)          0
loc  (total lines of code)       55
rs                               7.9146
c                                0
b                                0
f                                0
How to fix

Long Method

Small methods make your code easier to understand, particularly when combined with a good name. And when a method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a sign that you should extract the commented part into a new method, using the comment as a starting point for the new method's name.

Commonly applied refactorings include Extract Method: move a coherent chunk of the method body into its own, well-named method, as sketched below.
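As an illustration against the code below (the helper name _build_attn_mask is our own, hypothetical choice, not part of the analyzed file), the shifted-window mask construction in SwinTransformerBlock.__init__() could be extracted like this:

    def _build_attn_mask(self):
        # Returns the SW-MSA attention mask, or None when no cyclic shift is used.
        if self.shift_size == 0:
            return None
        H, W = self.input_resolution
        img_mask = torch.zeros((1, H, W, 1))
        # Both axes are cut into the same three bands, so one tuple of slices suffices.
        bands = (slice(0, -self.window_size),
                 slice(-self.window_size, -self.shift_size),
                 slice(-self.shift_size, None))
        cnt = 0
        for h in bands:
            for w in bands:
                img_mask[:, h, w, :] = cnt
                cnt += 1
        mask_windows = window_partition(img_mask, self.window_size)
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        return attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

__init__ then shrinks to a single self.register_buffer("attn_mask", self._build_attn_mask()), which lowers both the cc of 6 and the eloc of 43 reported above.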

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent as you need more, or different, data.

There are several approaches to avoid long parameter lists: introduce a parameter object that groups related values, pass a whole object instead of several of its attributes (Preserve Whole Object), or compute a value inside the method instead of passing it in (Replace Parameter with Method Call).
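With nop at 15, SwinTransformerBlock.__init__() is a natural candidate for Introduce Parameter Object. A minimal sketch, assuming a hypothetical SwinBlockConfig dataclass (not part of the analyzed file):

from dataclasses import dataclass
from typing import Optional, Tuple, Type

import torch.nn as nn


@dataclass
class SwinBlockConfig:
    # Groups the constructor parameters of SwinTransformerBlock.
    dim: int
    input_resolution: Tuple[int, int]
    num_heads: int
    window_size: int = 7
    shift_size: int = 0
    mlp_ratio: float = 4.0
    qkv_bias: bool = True
    qk_scale: Optional[float] = None
    drop: float = 0.0
    attn_drop: float = 0.0
    drop_path: float = 0.0
    act_layer: Type[nn.Module] = nn.GELU
    norm_layer: Type[nn.Module] = nn.LayerNorm
    fused_window_process: bool = False

The constructor then takes a single self-documenting argument, def __init__(self, cfg: SwinBlockConfig), and related settings such as the three dropout rates stay grouped and consistent as new ones are added.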

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from transformers.models.clap.modeling_clap import window_partition, window_reverse


try:
    import os, sys

    kernel_path = os.path.abspath(os.path.join('..'))
    sys.path.append(kernel_path)
    # Optional fused CUDA kernels for the shift + partition pipeline (path as in
    # the official Swin repository layout); the pure-PyTorch path is used otherwise.
    from kernels.window_process.window_process import WindowProcess, WindowProcessReverse
except ImportError:
    WindowProcess = None
    WindowProcessReverse = None


class MLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        return self.drop(self.fc2(self.drop(self.act(self.fc1(x)))))


# 2. Window-based multi-head self-attention
class WindowAttention(nn.Module):  # the head count changes per stage so each head keeps the same dimensionality
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # Bias table for relative position encoding: the offset along each axis can
        # take 2M - 1 values, so all combinations need (2M - 1) * (2M - 1) entries.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )

        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # meshgrid builds the 2-D grid coordinates; stack joins them
        coords_flatten = torch.flatten(coords, 1)  # flatten the 2-D coordinates
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # broadcasting: (2, N, 1) - (2, 1, N) gives the raw relative offsets
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift both offsets so they start at 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1  # row-major encoding of the (dy, dx) pair
        relative_position_index = relative_coords.sum(-1)  # final scalar relative position index
        self.register_buffer("relative_position_index", relative_position_index)  # a fixed index, not a learned parameter

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Truncated-normal initialization discards values outside the truncation
        # bounds, which helps training stability.
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
        )
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

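# Shape walkthrough for WindowAttention.forward (illustrative values): with a
# 7x7 window and 3 heads, x of shape (B_, 49, 96) yields q, k and v of shape
# (B_, 3, 49, 32), attn of shape (B_, 3, 49, 49), and an output of shape
# (B_, 49, 96). When a mask is given, B_ is a multiple of the window count nW
# and one (49, 49) mask is broadcast to every window of each image.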

class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, mlp_ratio=4., qkv_bias=True,
                 qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 fused_window_process=False):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.input_resolution = input_resolution
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        # If the input resolution is no larger than the window, disable the shift and shrink the window.
        if min(self.input_resolution) <= self.window_size:
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
            qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # 3. Attention mask for the shifted windows
        if self.shift_size > 0:
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))  # splits each axis into three bands; see the diagram in the paper

            cnt = 0
            # Number the regions produced by the cuts (see the diagram).
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            # Flatten the mask per window and subtract pairwise; non-zero entries get -100
            # (those token pairs come from distant, unrelated regions and must not attend to each other).
            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)
        self.fused_window_process = fused_window_process

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # Cyclic shift, then partition into windows.
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
                x_windows = window_partition(shifted_x, self.window_size)
            else:
                x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
        else:
            shifted_x = x
            x_windows = window_partition(shifted_x, self.window_size)

        # Attention within windows, then merge the windows and undo the shift.
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        attn_windows = self.attn(x_windows, mask=self.attn_mask)
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)

        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = window_reverse(attn_windows, self.window_size, H, W)
                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
            else:
                x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
        else:
            x = window_reverse(attn_windows, self.window_size, H, W)
        x = x.view(B, H * W, C)

        # Residual connections around the attention and MLP sub-blocks.
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

# 1. Patch embedding
class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)  # turn the input into a 2-tuple, e.g. (224, 224)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape  # e.g. 1 * 3 * 224 * 224
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H} * {W}) does not match model ({self.img_size[0]} * {self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # e.g. 1 * 3136 * 96
        if self.norm is not None:
            x = self.norm(x)
        return x
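    # Note: the strided Conv2d above is equivalent to cutting the image into
    # non-overlapping patch_size x patch_size patches and applying one shared
    # linear projection to each flattened patch.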
    def flops(self):  # count the floating-point operations of the embedding
        Ho, Wo = self.patches_resolution
        # Projection: one (in_chans * patch_size^2 -> embed_dim) product per patch;
        # with the defaults this is 56 * 56 * 96 * 3 * (4 * 4) ≈ 14.5 MFLOPs.
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
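For orientation, a minimal smoke test of the pieces above, using the 224 x 224 image / patch 4 / dim 96 defaults (the test itself is ours, not part of the analyzed file):

if __name__ == "__main__":
    imgs = torch.randn(1, 3, 224, 224)
    patch_embed = PatchEmbed(img_size=224, patch_size=4, in_chans=3, embed_dim=96)
    tokens = patch_embed(imgs)  # (1, 3136, 96): 56 * 56 patches, 96 channels each
    block = SwinTransformerBlock(dim=96, input_resolution=(56, 56), num_heads=3,
                                 window_size=7, shift_size=3)
    out = block(tokens)         # SW-MSA + MLP keep the token shape: (1, 3136, 96)
    print(tokens.shape, out.shape)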