Passed
Push — master (47e1d1...83e87b)
by Jeremy
01:40
created

vit_model.Block.__init__()   A

Complexity

Conditions 2

Size

Total Lines 11
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric   Value
cc       2
eloc     10
nop      11
dl       0
loc      11
rs       9.9
c        0
b        0
f        0

How to fix

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different, data.

There are several approaches to avoid long parameter lists; one common option is to group related arguments into a parameter object, as sketched below.
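For the flagged Block.__init__() (11 parameters including self, matching nop 11 above), a parameter object can absorb the rarely-changed options. The sketch below is one possible refactoring, not part of the reviewed commit: BlockConfig is a hypothetical name, and it assumes the Attention, Mlp and DropPath definitions from the listing further down are available in the same module.

from dataclasses import dataclass
from typing import Callable, Optional

import torch.nn as nn
from timm.layers import DropPath


@dataclass
class BlockConfig:
    # Hypothetical parameter object: groups the Block options that rarely change per call site.
    mlp_ratio: float = 4.
    qkv_bias: bool = False
    qk_scale: Optional[float] = None
    drop_ratio: float = 0.
    attn_drop_ratio: float = 0.
    drop_path_ratio: float = 0.
    act_layer: Callable[..., nn.Module] = nn.GELU
    norm_layer: Callable[..., nn.Module] = nn.LayerNorm


class Block(nn.Module):
    def __init__(self, dim, num_heads, cfg: Optional[BlockConfig] = None):
        super().__init__()
        cfg = cfg or BlockConfig()
        self.norm1 = cfg.norm_layer(dim)
        # Attention and Mlp are the classes defined in the listing below.
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=cfg.qkv_bias, qk_scale=cfg.qk_scale,
                              atte_drop_ration=cfg.attn_drop_ratio, proj_drop_ration=cfg.drop_ratio)
        self.drop_path = DropPath(cfg.drop_path_ratio) if cfg.drop_path_ratio > 0 else nn.Identity()
        self.norm2 = cfg.norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * cfg.mlp_ratio),
                       act_layer=cfg.act_layer, drop=cfg.drop_ratio)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

Call sites that only care about dim and num_heads stay short (Block(768, 12)), while experiments that tweak dropout or normalization pass a single BlockConfig instead of a growing keyword list.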

from functools import partial

import torch
import torch.nn as nn
from timm.layers import DropPath


class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
        # img_size: input image size; patch_size: size of each image patch; in_c: input channels
        # embed_dim: embedding dimension; norm_layer: optional normalization layer
        super().__init__()
        img_size = (img_size, img_size)   # expand the image size into a 2-D tuple
        patch_size = (patch_size, patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])  # 224/16, 224/16: size of the new "image" measured in patches
        self.num_patches = self.grid_size[0] * self.grid_size[1]  # 14*14 = 196 patches in total

        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)  # a 16x16 conv with stride 16 is equivalent to splitting the image into patches: B, 3, 224, 224 -> B, 768, 14, 14
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()  # apply the norm layer if one is given, otherwise leave the output unchanged

    def forward(self, x):
        B, C, H, W = x.shape   # shape of the input tensor
        assert H == self.img_size[0] and W == self.img_size[1],\
            f"input image size {H}*{W} does not match the expected size {self.img_size[0]}*{self.img_size[1]}"
        # B, 3, 224, 224 -> B, 768, 14, 14 -> B, 768, 196 -> B, 196, 768
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)  # normalize with the norm layer
        return x


class Attention(nn.Module):
    # dim: token dimension (768); num_heads: number of attention heads; qkv_bias: whether to add a bias when producing Q, K, V
    # qk_scale: scaling factor for QK; if None, 1/sqrt(embed_dim_per_head) is used
    # atte_drop_ration: dropout rate on the attention scores, against overfitting; proj_drop_ration: dropout rate of the final projection layer
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, atte_drop_ration=0., proj_drop_ration=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads  # dimension of each attention head
        self.scale = qk_scale or head_dim ** -0.5  # scaling factor for QK
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # one linear layer produces Q, K and V together, which runs in parallel and uses fewer parameters
        self.attn_drop = nn.Dropout(atte_drop_ration)
        self.proj_drop = nn.Dropout(proj_drop_ration)
        # concatenate the outputs of all heads, then map them back to the original embedding dim with a linear layer
        self.proj = nn.Linear(dim, dim, bias=qkv_bias)

    def forward(self, x):
        B, N, C = x.shape  # B: batch, N: num_patches + 1 (the +1 is the cls token), C: embed_dim
        # B N 3*C -> B N 3 num_heads C//num_heads -> 3 B num_heads N C//num_heads, which makes the following computations easier
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # slice out Q, K and V, each of shape B num_heads N C//num_heads
        q, k, v = qkv[0], qkv[1], qkv[2]
        # scaled dot product of Q and K gives the attention scores
        # q: [B num_heads N C//num_heads], k.transpose(-2, -1): [B num_heads C//num_heads N]
        attn = (q @ k.transpose(-2, -1)) * self.scale  # B num_heads N N
        attn = attn.softmax(dim=-1)  # softmax over each row so that it sums to 1
        attn = self.attn_drop(attn)  # dropout on the attention scores
        # weighted sum of V with the attention weights
        # attn @ v: B num_heads N C//num_heads
        # transpose: B N num_heads C//num_heads
        # reshape merges the last two dimensions, concatenating the heads back into the full embedding dim
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        # in_features: input dimension; hidden_features: hidden dimension, usually 4x in_features; out_features: output dimension, usually equal to the input dimension
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Block(nn.Module):
    # mlp_ratio: factor for hidden_features, 4x the input dim by default; norm_layer: normalization layer
    # drop_path_ratio: DropPath rate, applied before the residual connection; drop_ratio: dropout used after the final linear of the multi-head self-attention

    def __init__(self, dim, num_heads, mlp_ratio=4, qkv_bias=False, qk_scale=None, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super(Block, self).__init__()
        self.norm1 = norm_layer(dim)  # first layer norm in the transformer encoder block
        # instantiate the multi-head self-attention
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              atte_drop_ration=attn_drop_ratio, proj_drop_ration=drop_ratio)
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)  # number of nodes in the first fully connected layer of the MLP
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 representation_size=None, distilled=False, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
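The listing stops inside VisionTransformer.__init__(), but the classes shown in full can already be exercised. A minimal shape check, assuming the file is importable as vit_model and that torch and timm are installed:

import torch
from vit_model import PatchEmbed, Attention, Block

x = torch.randn(2, 3, 224, 224)                       # a batch of two 224x224 RGB images
print(PatchEmbed()(x).shape)                          # torch.Size([2, 196, 768]): 14*14 patches, 768-dim embeddings
tokens = torch.randn(2, 197, 768)                     # 196 patch tokens + 1 cls token
print(Attention(dim=768, num_heads=8)(tokens).shape)  # torch.Size([2, 197, 768])
print(Block(dim=768, num_heads=12)(tokens).shape)     # torch.Size([2, 197, 768])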