from functools import partial

import torch
import torch.nn as nn
from timm.layers import DropPath


class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
        # img_size: input image size; patch_size: side length of each patch; in_c: number of input channels;
        # embed_dim: embedding dimension; norm_layer: optional normalization layer
        super().__init__()
        img_size = (img_size, img_size)   # expand the image size into a 2D tuple
        patch_size = (patch_size, patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])  # 224/16, 224/16: size of the "image" measured in patches
        self.num_patches = self.grid_size[0] * self.grid_size[1]  # 14 * 14 = 196 patches in total

        # A convolution with a 16x16 kernel and stride 16 is equivalent to splitting the image into patches:
        # B, 3, 224, 224 -> B, 768, 14, 14
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()  # use norm_layer if provided, otherwise identity

    def forward(self, x):
        B, C, H, W = x.shape   # shape of the input tensor
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size {H}*{W} does not match the expected size {self.img_size[0]}*{self.img_size[1]}"
        # B, 3, 224, 224 -> B, 768, 14, 14 -> B, 768, 196 -> B, 196, 768
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)  # apply the normalization layer
        return x


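# A minimal shape-check sketch for PatchEmbed (not part of the original listing). The helper
# name `_demo_patch_embed` and the random input are illustrative assumptions; the expected
# shapes follow the default 224x224 input and 16x16 patches described above.
def _demo_patch_embed():
    patch_embed = PatchEmbed(img_size=224, patch_size=16, in_c=3, embed_dim=768)
    dummy = torch.randn(2, 3, 224, 224)       # B, C, H, W
    tokens = patch_embed(dummy)               # B, 3, 224, 224 -> B, 196, 768
    assert tokens.shape == (2, 196, 768)

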
class Attention(nn.Module):
    # dim: token dimension (768); num_heads: number of attention heads; qkv_bias: whether to add a bias when producing Q, K, V
    # qk_scale: scaling factor for QK; if None, 1/sqrt(embed_dim_per_head) is used
    # atte_drop_ration: dropout rate on the attention scores, to reduce overfitting; proj_drop_ration: dropout rate of the final projection layer
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, atte_drop_ration=0., proj_drop_ration=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads  # dimension of each attention head
        self.scale = qk_scale or head_dim ** -0.5  # QK scaling factor
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # one linear layer produces Q, K and V together, which parallelizes the computation and uses fewer parameters
        self.attn_drop = nn.Dropout(atte_drop_ration)
        self.proj_drop = nn.Dropout(proj_drop_ration)
        # concatenate the outputs of all heads, then map them back to the original embedding dimension
        self.proj = nn.Linear(dim, dim, bias=qkv_bias)

    def forward(self, x):
        B, N, C = x.shape  # B: batch size, N: num_patches + 1 (the +1 is the class token), C: embed_dim
        # B N 3*C -> B N 3 num_heads C//num_heads -> 3 B num_heads N C//num_heads, which is convenient for the operations below
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # slice out Q, K and V, each of shape B num_heads N C//num_heads
        q, k, v = qkv[0], qkv[1], qkv[2]
        # scaled dot product of Q and K gives the attention scores
        # q: [B num_heads N C//num_heads], k.transpose(-2, -1): [B num_heads C//num_heads N]
        attn = (q @ k.transpose(-2, -1)) * self.scale  # B num_heads N N
        attn = attn.softmax(dim=-1)  # softmax over each row so that every row sums to 1
        attn = self.attn_drop(attn)  # dropout on the attention weights
        # weighted sum of V with the attention weights
        # attn @ v: B num_heads N C//num_heads
        # transpose: B N num_heads C//num_heads
        # reshape merges the last two dimensions, concatenating the heads back into the full embedding dimension
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


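# A minimal shape-check sketch for Attention (not part of the original listing). The helper
# name `_demo_attention` and the random input are illustrative assumptions; 197 tokens = 196
# patches + 1 class token.
def _demo_attention():
    attn = Attention(dim=768, num_heads=8)
    tokens = torch.randn(2, 197, 768)          # B, N, C
    out = attn(tokens)                         # self-attention preserves the token shape
    assert out.shape == tokens.shape

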
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        # in_features: input dimension; hidden_features: hidden dimension, usually 4 * in_features;
        # out_features: output dimension, usually equal to the input dimension
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


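# A minimal shape-check sketch for Mlp (not part of the original listing). The helper name
# `_demo_mlp` is an illustrative assumption; with hidden_features = 4 * 768 the features are
# expanded to 3072 and then projected back to 768.
def _demo_mlp():
    mlp = Mlp(in_features=768, hidden_features=768 * 4)
    tokens = torch.randn(2, 197, 768)
    out = mlp(tokens)                          # 768 -> 3072 -> 768, shape preserved
    assert out.shape == tokens.shape

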
class Block(nn.Module):
    # mlp_ratio: sets hidden_features of the MLP, 4x the input dimension by default; norm_layer: normalization layer
    # drop_path_ratio: DropPath rate, applied just before the residual addition
    # drop_ratio: dropout rate used after the final linear layer of the multi-head self-attention and inside the MLP

    def __init__(self, dim, num_heads, mlp_ratio=4, qkv_bias=False, qk_scale=None, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super(Block, self).__init__()
        self.norm1 = norm_layer(dim)  # first layer norm in the transformer encoder block
        # instantiate multi-head self-attention
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              atte_drop_ration=attn_drop_ratio, proj_drop_ration=drop_ratio)
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)  # number of nodes in the first fully connected layer of the MLP
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


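# A minimal sketch showing how encoder blocks are typically stacked (not part of the original
# listing). The helper name `_demo_blocks` and the depth of 12 are illustrative assumptions.
def _demo_blocks():
    blocks = nn.Sequential(*[Block(dim=768, num_heads=12) for _ in range(12)])
    tokens = torch.randn(2, 197, 768)
    out = blocks(tokens)                       # each pre-norm block preserves the token shape
    assert out.shape == tokens.shape

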
class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 representation_size=None, distilled=False, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim
        self.num_tokens = 2 if distilled else 1  # 2 tokens (cls + distillation) in the distilled variant, otherwise only the cls token
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
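

# The listing above stops inside VisionTransformer.__init__. The sketch below is NOT the
# original code; it shows one common way the remaining pieces (class token, position
# embedding, stacked Blocks, final norm and classification head) are assembled, assuming the
# standard ViT-B/16 layout. All names introduced here are illustrative.
class _VisionTransformerSketch(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, drop_ratio=0.):
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.patch_embed = PatchEmbed(img_size, patch_size, in_c, embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))                 # learnable class token
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))   # learnable position embedding
        self.pos_drop = nn.Dropout(drop_ratio)
        self.blocks = nn.Sequential(*[Block(dim=embed_dim, num_heads=num_heads) for _ in range(depth)])
        self.norm = norm_layer(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)                               # classification head

    def forward(self, x):
        x = self.patch_embed(x)                                  # B, 196, 768
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)    # B, 1, 768
        x = torch.cat((cls_token, x), dim=1)                     # B, 197, 768
        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        return self.head(x[:, 0])                                # classify on the class token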