| 1 |  |  | from collections import OrderedDict | 
            
                                                                                                            
                            
            
                                    
            
            
                | 2 |  |  | from functools import partial | 
            
                                                                                                            
                            
            
                                    
            
            
                | 3 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 4 |  |  | import torch | 
            
                                                                                                            
                            
            
                                    
            
            
                | 5 |  |  | import torch.nn as nn | 
            
                                                                                                            
                            
            
                                    
            
            
                | 6 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 7 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 8 |  |  | def drop_path(x, drop_prob: float = 0., training: bool = False): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 9 |  |  |     """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 10 |  |  |     Drop paths(随机深度)每个样本(在残差块的主路径中应用时)。 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 11 |  |  |     这个实现类似于 DropConnect,用于 EfficientNet 等网络,但名字不同,DropConnect 是另一种形式的 dropout。 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 12 |  |  |     链接中有详细的讨论:https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 13 |  |  |     我们使用 'drop path' 而不是 'DropConnect' 来避免混淆,并将参数名用 'survival rate' 来代替。 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 14 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 15 |  |  |     参数: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 16 |  |  |     - x: 输入张量。 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 17 |  |  |     - drop_prob: 丢弃路径的概率。 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 18 |  |  |     - training: 是否处于训练模式。 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 19 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 20 |  |  |     返回: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 21 |  |  |     - 如果不在训练模式或丢弃概率为 0,返回输入张量 x; | 
            
                                                                                                            
                            
            
                                    
            
            
                | 22 |  |  |     - 否则,返回经过丢弃操作后的张量。 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 23 |  |  |     """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 24 |  |  |     if drop_prob == 0. or not training:  # 如果丢弃概率为 0 或不处于训练模式,直接返回原始输入 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 25 |  |  |         return x | 
            
                                                                                                            
                            
            
                                    
            
            
                | 26 |  |  |     keep_prob = 1 - drop_prob  # 保持路径的概率 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 27 |  |  |     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # 生成与 x 的维度匹配的形状,只保持 batch 维度 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 28 |  |  |     random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)  # 生成一个与 x 大小相同的随机张量 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 29 |  |  |     random_tensor.floor_()  # 将随机张量二值化(比如keep_prob若值为0.7,则在前一步0-1的随机分布中大于等于0.3的在这里都变为1,小于0.3的则成为0,即有70%的值被保留下来) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 30 |  |  |     output = x.div(keep_prob) * random_tensor  # 将输入 x 缩放(目的是放大保留下来的部分,以补偿因丢弃部分神经元而导致的输出总期望值的下降)并与随机张量相乘,实现部分路径的丢弃 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 31 |  |  |     return output  # 返回经过 drop path 操作后的张量 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 32 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 33 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 34 |  |  | class DropPath(nn.Module): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 35 |  |  |     """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 36 |  |  |     Drop paths(随机深度)每个样本(在残差块的主路径中应用时)。 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 37 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 38 |  |  |     这是一个 PyTorch 模块,用于在训练期间随机丢弃某些路径,以增强模型的泛化能力。 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 39 |  |  |     """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 40 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 41 |  |  |     def __init__(self, drop_prob=None): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 42 |  |  |         super(DropPath, self).__init__()  # 调用父类 nn.Module 的构造函数 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 43 |  |  |         self.drop_prob = drop_prob  # 初始化丢弃概率 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 44 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 45 |  |  |     def forward(self, x): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 46 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 47 |  |  |         前向传播函数,调用 drop_path 函数。 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 48 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 49 |  |  |         参数: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 50 |  |  |         - x: 输入张量。 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 51 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 52 |  |  |         返回: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 53 |  |  |         - 经过 drop path 操作后的张量。 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 54 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 55 |  |  |         return drop_path(x, self.drop_prob, self.training)  # 调用上面定义的 drop_path 函数 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 56 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 57 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 58 |  |  | class PatchEmbed(nn.Module): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 59 |  |  |     def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 60 |  |  |         # img_size图像大小   patch_size每个图像块patch的大小  in_c 输入通道  embed_dim 嵌入维度  norm_layer 可选的归一化层 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 61 |  |  |         super().__init__() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 62 |  |  |         img_size = (img_size, img_size)   # 将输入图像大小变为二维元组 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 63 |  |  |         patch_size = (patch_size, patch_size) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 64 |  |  |         self.img_size = img_size | 
            
                                                                                                            
                            
            
                                    
            
            
                | 65 |  |  |         self.patch_size = patch_size | 
            
                                                                                                            
                            
            
                                    
            
            
                | 66 |  |  |         self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])  # 224/16, 224/16  以patch为单位形成的新“图像”尺寸 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 67 |  |  |         self.num_patches = self.grid_size[0] * self.grid_size[1]  # 14*14=196 patch总数 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 68 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 69 |  |  |         self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size) # 利用一个卷积核为16*16,步长为16大小进行卷积操作来等效实现将原图拆分成patch   B, 3, 224, 224 -> B, 768, 14, 14 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 70 |  |  |         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()  # 若存在norm layer则使用,否则保持不变 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 71 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 72 |  |  |     def forward(self, x): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 73 |  |  |         B, C, H, W = x.shape   # 获取输入张量的形状 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 74 |  |  |         assert H == self.img_size[0] and W == self.img_size[1],\ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 75 |  |  |         f"输入图像大小{H} * {W}与模型期望大小{self.img_size[0]}*{self.img_size[1]}不匹配" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 76 |  |  |         # B, 3, 224, 224 -> B, 768, 14, 14 -> B, 768, 196 -> B, 196, 768 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 77 |  |  |         x = self.proj(x).flatten(2).transpose(1, 2) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 78 |  |  |         x = self.norm(x) # 使用norm层进行归一化 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 79 |  |  |         return x | 
            
                                                                                                            
                            
            
                                    
            
            
                | 80 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 81 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 82 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 83 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 84 |  |  | class Attention(nn.Module): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 85 |  |  |     # dim输入的token维度768, num_heads注意力头数,qkv_bias生成QKV的时候是否添加偏置, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 86 |  |  |     # qk_scale用于缩放QK的缩放因子,若为None,则使用1/sqrt(embed_dim_pre_head) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 87 |  |  |     # atte_drop_ration注意力分数的dropout的比率,防止过拟合  proj_drop_ration最终投影层的dropout的比率 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 88 |  |  |     def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, atte_drop_ration=0., proj_drop_ration=0.): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 89 |  |  |         super().__init__() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 90 |  |  |         self.num_heads = num_heads | 
            
                                                                                                            
                            
            
                                    
            
            
                | 91 |  |  |         head_dim = dim // num_heads # 每个注意力头的维度 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 92 |  |  |         self.scale = qk_scale or head_dim ** -0.5  # qk的缩放因子 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 93 |  |  |         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) # 通过全连接层生成QKV,为了并行运算提高计算效率,同时参数更少 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 94 |  |  |         self.attn_drop = nn.Dropout(atte_drop_ration) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 95 |  |  |         self.proj_drop = nn.Dropout(proj_drop_ration) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 96 |  |  |         # 将每个head得到的输出进行concat拼接,然后通过线性变换映射回原本的嵌入dim | 
            
                                                                                                            
                            
            
                                    
            
            
                | 97 |  |  |         self.proj = nn.Linear(dim, dim, bias=qkv_bias) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 98 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 99 |  |  |     def forward(self, x): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 100 |  |  |         B, N, C = x.shape  # B为batch,N为num_patch+1,C为embed_dim  +1为clstoken | 
            
                                                                                                            
                            
            
                                    
            
            
                | 101 |  |  |         #  B N 3*C -> B N 3 num_heads, C//self.num_heads -> 3 B num_heads N C//self.num_heads  作用是方便之后的运算 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 102 |  |  |         qkv = self.qkv(x).reshape(B,N,3,self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 103 |  |  |         # 用切片拿到QKV,形状是 B num_heads N C//self.num_heads | 
            
                                                                                                            
                            
            
                                    
            
            
                | 104 |  |  |         q, k, v = qkv[0], qkv[1], qkv[2] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 105 |  |  |         # 计算qk的点积,并进行缩放得到注意力分数 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 106 |  |  |         # Q: [3 B num_heads N C//self.num_heads] k.transpose(-2,-1)  K:[B num_heads C//self.num_heads N] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 107 |  |  |         attn = (q @ k.transpose(-2, -1)) * self.scale  # B num_heads N N | 
            
                                                                                                            
                            
            
                                    
            
            
                | 108 |  |  |         attn = attn.softmax(dim=-1) # 对每行进行处理 使得每行的和为1 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 109 |  |  |         # 注意力权重对V进行加权求和 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 110 |  |  |         # attn @ v : B num_heads N C//self.num_heads | 
            
                                                                                                            
                            
            
                                    
            
            
                | 111 |  |  |         # transpose: B N self.num_heads C//self.num_heads | 
            
                                                                                                            
                            
            
                                    
            
            
                | 112 |  |  |         # reshape将最后两个维度拼接,合并多个头的输出,回到总的嵌入维度 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 113 |  |  |         x = (attn @ v).transpose(1, 2).reshape(B, N, C) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 114 |  |  |         x = self.proj(x) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 115 |  |  |         x = self.proj_drop(x) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 116 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 117 |  |  |         return x | 
            
                                                                                                            
                            
            
                                    
            
            
                | 118 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 119 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 120 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 121 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 122 |  |  | class Mlp(nn.Module): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 123 |  |  |     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 124 |  |  |         # in_features输入的维度, hidden_features隐藏层维度、通常为in_features的4倍, out_features输出维度、通常与输入维度相等 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 125 |  |  |         super().__init__() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 126 |  |  |         out_features = out_features or in_features | 
            
                                                                                                            
                            
            
                                    
            
            
                | 127 |  |  |         hidden_features = hidden_features or in_features | 
            
                                                                                                            
                            
            
                                    
            
            
                | 128 |  |  |         self.fc1 = nn.Linear(in_features, hidden_features) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 129 |  |  |         self.act = act_layer | 
            
                                                                                                            
                            
            
                                    
            
            
                | 130 |  |  |         self.fc2 = nn.Linear(hidden_features, out_features) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 131 |  |  |         self.drop = nn.Dropout(drop) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 132 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 133 |  |  |     def forward(self, x): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 134 |  |  |         x = self.fc1(x) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 135 |  |  |         x = self.act(x) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 136 |  |  |         x = self.drop(x) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 137 |  |  |         x = self.fc2(x) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 138 |  |  |         x = self.drop(x) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 139 |  |  |         return x | 
            
                                                                                                            
                            
            
                                    
            
            
                | 140 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 141 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 142 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 143 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 144 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 145 |  |  | class Block(nn.Module): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 146 |  |  |     # mlp_ratio 计算hidden_features大小 默认为输入4倍   norm_layer正则化层 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 147 |  |  |     # drop_path_ratio 是drop_path的比率,该操作在残差连接之前  drop_ratio 是多头自注意力机制最后的linear后使用的dropout | 
            
                                                                                                            
                            
            
                                    
            
            
                | 148 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 149 |  |  |     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_ratio=0., | 
            
                                                                                                            
                            
            
                                    
            
            
                | 150 |  |  |                  attn_drop_ratio=0., drop_path_ratio=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 151 |  |  |         super(Block, self).__init__() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 152 |  |  |         self.norm1 = norm_layer(dim)  # transformer encoder block中的第一个layer norm | 
            
                                                                                                            
                            
            
                                    
            
            
                | 153 |  |  |         # 实例化多头注意力机制 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 154 |  |  |         self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 155 |  |  |                               atte_drop_ration=attn_drop_ratio, proj_drop_ration=drop_path_ratio) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 156 |  |  |         self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0 else nn.Identity() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 157 |  |  |         self.norm2 = norm_layer(dim) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 158 |  |  |         mlp_hidden_dim = int(dim * mlp_ratio)  # 计算MLP第一个全连接层的节点数 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 159 |  |  |         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio) | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 160 |  |  |  | 
            
                                                                        
                            
            
                                    
            
            
                | 161 |  |  |     def forward(self, x): | 
            
                                                                        
                            
            
                                    
            
            
                | 162 |  |  |         x = x + self.drop_path(self.attn(self.norm1(x))) | 
            
                                                                        
                            
            
                                    
            
            
                | 163 |  |  |         x = x + self.drop_path(self.mlp(self.norm2(x))) | 
            
                                                                        
                            
            
                                    
            
            
                | 164 |  |  |         return x | 
            
                                                                                                            
                            
            
                                    
            
            
                | 165 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 166 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 167 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 168 |  |  | class VisionTransformer(nn.Module): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 169 |  |  |     def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 170 |  |  |                  embed_dim=768, depth=12, num_heads=12,mlp_ratio=4., qkv_bias=True, qk_scale=None, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 171 |  |  |                  representation_size=None, distilled=False, drop_ratio=0., | 
            
                                                                                                            
                            
            
                                    
            
            
                | 172 |  |  |                  attn_drop_ratio=0., drop_path_ratio=0. , embed_layer=PatchEmbed ,norm_layer=None, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 173 |  |  |                  act_layer=None): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 174 |  |  |         super(VisionTransformer, self).__init__() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 175 |  |  |         self.num_classes = num_classes | 
            
                                                                                                            
                            
            
                                    
            
            
                | 176 |  |  |         self.num_features = self.embed_dim = embed_dim | 
            
                                                                                                            
                            
            
                                    
            
            
                | 177 |  |  |         self.num_tokens = 2 if distilled else 1 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 178 |  |  |         # 设置一个较小的参数防止除0 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 179 |  |  |         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 180 |  |  |         act_layer = act_layer or nn.GELU() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 181 |  |  |         self.patch_embed = embed_layer(img_size, patch_size, in_c, embed_dim, norm_layer) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 182 |  |  |         num_patches = self.patch_embed.num_patches  # 得到patches的个数 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 183 |  |  |         # 使用nn.Parameter构建可训练的参数,用零矩阵初始化,第一个为batch,后两个为1*768 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 184 |  |  |         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 185 |  |  |         self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 186 |  |  |         # pos_embed 大小与concat拼接后的大小一致,是197*768 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 187 |  |  |         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 188 |  |  |         self.pos_drop = nn.Dropout(drop_ratio) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 189 |  |  |         # 根据传入的drop_path_ratio 构建等差序列,从0到drop_path_ratio,有depth个元素 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 190 |  |  |         dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 191 |  |  |         # 使用nn.Sequential将列表中的所有模块打包为一个整体 depth对应的是使用了transformer encoder block的数量 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 192 |  |  |         self.block = nn.Sequential(*[ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 193 |  |  |             Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 194 |  |  |                   drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio,drop_path_ratio=dpr[i], | 
            
                                                                                                            
                            
            
                                    
            
            
                | 195 |  |  |                   norm_layer=norm_layer, act_layer=act_layer) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 196 |  |  |             for i in range(depth) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 197 |  |  |         ]) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 198 |  |  |         self.norm = norm_layer(embed_dim)  # 通过transformer后的layer norm | 
            
                                                                                                            
                            
            
                                    
            
            
                | 199 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 200 |  |  |             这段代码中logits层是作为模型最后一层的原始输出值(一般是全连接层,尚未经过归一化),一般需要通过激活函数得到统计概率作为最终输出 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 201 |  |  |             这里的representation size指的是你想要的输出数据的尺寸大小  在小规模的ViT中不需要该参数 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 202 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 203 |  |  |         if representation_size and not distilled: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 204 |  |  |             self.has_logits = True | 
            
                                                                                                            
                            
            
                                    
            
            
                | 205 |  |  |             self.num_features = representation_size | 
            
                                                                                                            
                            
            
                                    
            
            
                | 206 |  |  |             self.pre_logits = nn.Sequential(OrderedDict([ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 207 |  |  |                 ("fc", nn.Linear(embed_dim, representation_size)), | 
            
                                                                                                            
                            
            
                                    
            
            
                | 208 |  |  |                 ("act", nn.Tanh()) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 209 |  |  |             ])) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 210 |  |  |         else: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 211 |  |  |             self.has_logits = False | 
            
                                                                                                            
                            
            
                                    
            
            
                | 212 |  |  |             self.pre_logits = nn.Identity() # 不做任何处理 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 213 |  |  |         # 分类头 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 214 |  |  |         self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 215 |  |  |         self.head_dist = None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 216 |  |  |         if distilled: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 217 |  |  |             self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 218 |  |  |         # 权重初始化 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 219 |  |  |         nn.init.trunc_normal_(self.pos_embed, std=0.02) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 220 |  |  |         if self.dist_token is not None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 221 |  |  |             nn.init.trunc_normal_(self.dist_token, std=0.02) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 222 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 223 |  |  |         nn.init.trunc_normal_(self.cls_token, std=0.02) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 224 |  |  |         self.apply(_init_vit_weights) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 225 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 226 |  |  |     def forward_features(self, x):  # 针对patch embedding部分的forward | 
            
                                                                                                            
                            
            
                                    
            
            
                | 227 |  |  |         # B C H W -> B num_patches embed_dim  196 * 768 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 228 |  |  |         x = self.patch_embed(x) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 229 |  |  |         # 1, 1, 768 -> B, 1, 768 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 230 |  |  |         cls_token = self.cls_token.expand(x.shape[0], -1, -1) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 231 |  |  |         # dist_token存在, 则拼接dist_token和cls_token, 否则只拼接cls_token和输入的patch特征x | 
            
                                                                                                            
                            
            
                                    
            
            
                | 232 |  |  |         if self.dist_token is None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 233 |  |  |             x = torch.cat((cls_token, x), dim=1) # B 197 768 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 234 |  |  |         else: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 235 |  |  |             x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1),x), dim=1) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 236 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 237 |  |  |         x = self.pos_drop(x+self.pos_embed) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 238 |  |  |         x = self.block(x) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 239 |  |  |         x = self.norm(x) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 240 |  |  |         if self.dist_token is None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 241 |  |  |             return self.pre_logits(x[:, 0])  # dist_token为None,利用切片的形式获取cls_token对应的输出 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 242 |  |  |         else: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 243 |  |  |             return x[:, 0], x[:, 1:] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 244 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 245 |  |  |     def forward(self, x): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 246 |  |  |         x = self.forward_features(x) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 247 |  |  |         if self.head_dist is not None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 248 |  |  |             # 知识蒸馏相关知识 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 249 |  |  |             x, x_dist = self.head(x[0]), self.head_dist(x[1]) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 250 |  |  |             # 如果是训练模式且不是脚本模式 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 251 |  |  |             if self.training and not torch.jit.is_scripting(): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 252 |  |  |                 # 则返回两个头部的预测结果 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 253 |  |  |                 return x, x_dist | 
            
                                                                                                            
                            
            
                                    
            
            
                | 254 |  |  |         else: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 255 |  |  |             x = self.head(x) # 最后的linear全连接层 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 256 |  |  |         return x | 
            
                                                                                                            
                            
            
                                    
            
            
                | 257 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 258 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 259 |  |  | def _init_vit_weights(m): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 260 |  |  |     # 判断模块m是否为线形层 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 261 |  |  |     if isinstance(m, nn.Linear): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 262 |  |  |         nn.init.trunc_normal_(m.weight, std=0.01) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 263 |  |  |         if m.bias is not None: # 如果线性层存在偏置项,则将偏置项初始化为0 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 264 |  |  |             nn.init.zeros_(m.bias) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 265 |  |  |     elif isinstance(m, nn.Conv2d): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 266 |  |  |         nn.init.kaiming_normal_(m.weight, mode='fan_out') # 对卷积层的权重做一个初始化,适用于卷积 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 267 |  |  |         if m.bias is not None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 268 |  |  |             nn.init.zeros_(m.bias) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 269 |  |  |     elif isinstance(m, nn.LayerNorm): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 270 |  |  |         nn.init.zeros_(m.bias) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 271 |  |  |         nn.init.ones_(m.weight)  # 对层归一化的权重初始化为1 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 272 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 273 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 274 |  |  | def vit_base_patch16_224(num_classes:int = 1000, pretrained=False): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 275 |  |  |     model = VisionTransformer(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 276 |  |  |                               representation_size=None, num_classes=num_classes) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 277 |  |  |     return model | 
            
                                                                                                            
                            
            
                                    
            
            
                | 278 |  |  |  | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 279 |  |  |  | 
            
                                                        
            
                                    
            
            
                | 280 |  |  |  |