import torch
import torch.nn as nn


class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
        # img_size: input image size; patch_size: size of each patch; in_c: number of input channels;
        # embed_dim: embedding dimension; norm_layer: optional normalization layer
        super().__init__()
        img_size = (img_size, img_size)  # turn the input image size into a 2-tuple
        patch_size = (patch_size, patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])  # 224/16, 224/16: the new "image" size measured in patches
        self.num_patches = self.grid_size[0] * self.grid_size[1]  # 14 * 14 = 196 patches in total

        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)  # a conv with a 16x16 kernel and stride 16 is equivalent to splitting the image into patches: B, 3, 224, 224 -> B, 768, 14, 14
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()  # use norm_layer if given, otherwise an identity (no-op)

    def forward(self, x):
        B, C, H, W = x.shape  # shape of the input tensor
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size {H}*{W} does not match the expected size {self.img_size[0]}*{self.img_size[1]}"
        # B, 3, 224, 224 -> B, 768, 14, 14 -> B, 768, 196 -> B, 196, 768
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)  # apply the normalization layer
        return x
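

# A minimal shape-check sketch for PatchEmbed (illustration only; the helper name
# _demo_patch_embed is hypothetical, not part of the original code). With the default
# 224 / 16 / 768 configuration, one image becomes 196 patch tokens of dimension 768.
def _demo_patch_embed():
    dummy = torch.randn(1, 3, 224, 224)  # a single fake RGB image
    tokens = PatchEmbed()(dummy)         # B, 3, 224, 224 -> B, 196, 768
    print(tokens.shape)                  # torch.Size([1, 196, 768])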


class Attention(nn.Module):
    # dim: token dimension (e.g. 768); num_heads: number of attention heads; qkv_bias: whether to add a bias when producing Q, K, V;
    # qk_scale: scaling factor for Q·K; if None, 1/sqrt(head_dim) is used;
    # atte_drop_ration: dropout rate on the attention weights (regularization); proj_drop_ration: dropout rate after the output projection
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, atte_drop_ration=0., proj_drop_ration=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads  # dimension of each attention head
        self.scale = qk_scale or head_dim ** -0.5  # scaling factor for q·k
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # one linear layer produces Q, K and V together: better parallelism and fewer parameters
        self.attn_drop = nn.Dropout(atte_drop_ration)
        self.proj_drop = nn.Dropout(proj_drop_ration)
        # concatenate the per-head outputs, then map them back to the original embedding dim with a linear layer
        self.proj = nn.Linear(dim, dim, bias=qkv_bias)

    def forward(self, x):
        B, N, C = x.shape  # B: batch size, N: num_patches + 1 (the +1 is the cls token), C: embed_dim
        # B, N, 3*C -> B, N, 3, num_heads, C//num_heads -> 3, B, num_heads, N, C//num_heads, which makes the matrix products below convenient
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # slice out Q, K, V; each has shape B, num_heads, N, C//num_heads
        q, k, v = qkv[0], qkv[1], qkv[2]
        # scaled dot product of q and k gives the attention scores
        # q: [B, num_heads, N, C//num_heads], k.transpose(-2, -1): [B, num_heads, C//num_heads, N]
        attn = (q @ k.transpose(-2, -1)) * self.scale  # B, num_heads, N, N
        attn = attn.softmax(dim=-1)  # softmax over the last dim so each row sums to 1
        attn = self.attn_drop(attn)  # dropout on the attention weights
        # weighted sum of V with the attention weights
        # attn @ v: B, num_heads, N, C//num_heads
        # transpose: B, N, num_heads, C//num_heads
        # reshape merges the last two dimensions, concatenating the heads back to the full embedding dim
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x
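

# A minimal shape-check sketch for Attention (illustration only; the helper name
# _demo_attention is hypothetical). Here N = 197 stands for the 196 patch tokens plus
# the cls token; self-attention keeps the B, N, C shape of its input.
def _demo_attention():
    tokens = torch.randn(1, 197, 768)  # B, N, C with the cls token included
    out = Attention(dim=768, num_heads=8)(tokens)
    print(out.shape)                   # torch.Size([1, 197, 768])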


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
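

# A sketch of the usual two-layer ViT MLP that the signature above suggests
# (fc1 -> activation -> dropout -> fc2 -> dropout). The class name MlpSketch is
# hypothetical and follows the common timm-style pattern; it is an assumption,
# not the original continuation of Mlp.
class MlpSketch(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features        # default: keep the input dimension
        hidden_features = hidden_features or in_features  # default: no expansion
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.drop(self.act(self.fc1(x)))  # expand, activate, regularize
        x = self.drop(self.fc2(x))            # project back to out_features
        return x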