Passed: push to master (aa775a...47e1d1) by Jeremy, ran in 01:51 (queued 14s)

vit_model.Attention.forward()    A

Complexity
    Conditions: 1

Size
    Total Lines: 19
    Code Lines: 10

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric   Value
cc       1
eloc     10
nop      2
dl       0
loc      19
rs       9.9
c        0
b        0
f        0
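The per-function numbers above (cyclomatic complexity, logical lines of code, number of parameters) can be reproduced locally. The report does not name the tool it used; the snippet below is only a sketch that assumes the radon package and a module file named vit_model.py, so the exact values may differ from the table.

import radon.raw
from radon.complexity import cc_visit

source = open("vit_model.py", encoding="utf-8").read()

# Raw size metrics for the whole file: loc, lloc, sloc, comment and blank-line counts.
print(radon.raw.analyze(source))

# Cyclomatic complexity per function; methods are nested under their class block.
for block in cc_visit(source):
    for func in getattr(block, "methods", [block]):
        print(func.name, func.complexity)

The vit_model.py source under review follows.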
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
        # img_size: input image size; patch_size: size of each patch; in_c: number of input channels
        # embed_dim: embedding dimension; norm_layer: optional normalization layer
        super().__init__()
        img_size = (img_size, img_size)   # turn the image size into a 2-tuple
        patch_size = (patch_size, patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])  # 224/16, 224/16: the "image" size measured in patches
        self.num_patches = self.grid_size[0] * self.grid_size[1]  # 14*14 = 196 patches in total

        # A convolution with a 16x16 kernel and stride 16 is equivalent to splitting the image into
        # patches and projecting each one: B, 3, 224, 224 -> B, 768, 14, 14
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()  # use norm_layer if given, otherwise a no-op

    def forward(self, x):
        B, C, H, W = x.shape   # shape of the input tensor
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size {H}*{W} does not match the expected size {self.img_size[0]}*{self.img_size[1]}"
        # B, 3, 224, 224 -> B, 768, 14, 14 -> B, 768, 196 -> B, 196, 768
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)  # apply the normalization layer
        return x
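
# --- Illustrative sketch (added example, not part of the original file) ---
# With the defaults above, PatchEmbed maps (B, 3, 224, 224) to (B, num_patches, embed_dim),
# i.e. (B, 196, 768). The helper name below is ours and exists only for this shape check.
def _patch_embed_shape_check():
    x = torch.randn(2, 3, 224, 224)
    out = PatchEmbed()(x)
    assert out.shape == (2, 196, 768), out.shape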


class Attention(nn.Module):
    # dim: token dimension (768); num_heads: number of attention heads; qkv_bias: whether the QKV projection uses a bias
    # qk_scale: scaling factor for QK; if None, 1/sqrt(embed_dim_per_head) is used
    # atte_drop_ration: dropout rate on the attention scores (against overfitting); proj_drop_ration: dropout rate on the final projection
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, atte_drop_ration=0., proj_drop_ration=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads  # dimension of each attention head
        self.scale = qk_scale or head_dim ** -0.5  # scaling factor for QK
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # a single linear layer produces Q, K and V together: fewer parameters and better parallelism
        self.attn_drop = nn.Dropout(atte_drop_ration)
        self.proj_drop = nn.Dropout(proj_drop_ration)
        # concatenate the outputs of all heads, then map back to the original embedding dim with a linear layer
        self.proj = nn.Linear(dim, dim, bias=qkv_bias)

    def forward(self, x):
        B, N, C = x.shape  # B: batch size, N: num_patches + 1 (the +1 is the class token), C: embed_dim
        # B N 3*C -> B N 3 num_heads C//num_heads -> 3 B num_heads N C//num_heads, which makes the following operations easier
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # slice out Q, K and V, each of shape B num_heads N C//num_heads
        q, k, v = qkv[0], qkv[1], qkv[2]
        # dot product of Q and K, scaled, gives the attention scores
        # Q: [B num_heads N C//num_heads]; k.transpose(-2, -1) gives K: [B num_heads C//num_heads N]
        attn = (q @ k.transpose(-2, -1)) * self.scale  # B num_heads N N
        attn = attn.softmax(dim=-1)  # softmax over each row so that every row sums to 1
        attn = self.attn_drop(attn)  # dropout on the attention weights
        # weighted sum of V with the attention weights
        # attn @ v : B num_heads N C//num_heads
        # transpose: B N num_heads C//num_heads
        # reshape merges the last two dimensions, concatenating the head outputs back to the full embedding dim
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x
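
# --- Illustrative sketch (added example, not part of the original file) ---
# Attention is shape-preserving: (B, N, C) in, (B, N, C) out, with C = num_heads * head_dim.
# For dim=768 and num_heads=8, head_dim = 96 and scale = 96 ** -0.5. The helper name below is ours.
def _attention_shape_check():
    x = torch.randn(2, 197, 768)  # 196 patches + 1 class token
    out = Attention(dim=768, num_heads=8)(x)
    assert out.shape == (2, 197, 768), out.shape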


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
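
# --- Illustrative sketch (added example; the original file continues past this excerpt) ---
# The Mlp definition above is cut off here. For orientation only, a typical ViT-style MLP block
# (two linear layers around an activation, with dropout, as in timm) looks like the standalone
# class below. The name MlpSketch is ours; it is not the author's continuation of Mlp.
class MlpSketch(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.drop(self.act(self.fc1(x)))
        x = self.drop(self.fc2(x))
        return x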