import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from transformers.models.clap.modeling_clap import window_partition, window_reverse


try:
    import os, sys

    kernel_path = os.path.abspath(os.path.join('..'))
    sys.path.append(kernel_path)

    # optional fused CUDA kernels; in the official Swin repo they live under kernels/window_process
    from kernels.window_process.window_process import WindowProcess, WindowProcessReverse
except ImportError:
    WindowProcess = None
    WindowProcessReverse = None

class MLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    # two linear layers with activation and dropout: fc1 -> act -> drop -> fc2 -> drop
    def forward(self, x):
        x = self.drop(self.act(self.fc1(x)))
        x = self.drop(self.fc2(x))
        return x

# 2. Window self-attention
class WindowAttention(nn.Module):  # the number of heads changes across stages so that the dimension handled by each head stays fixed
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # bias table for the relative position encoding; it has (2M-1) * (2M-1) rows because the x and y
        # offsets each range over 2M-1 values, and the table must cover every combination of the two
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )

        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # meshgrid builds the 2D grid coordinates; stack joins them
        coords_flatten = torch.flatten(coords, 1)  # flatten the 2D coordinates first
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # broadcasting: (2, h*w, 1) - (2, 1, h*w) gives the raw relative offsets
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # final 2D relative position index
        self.register_buffer("relative_position_index", relative_position_index)  # registered as a buffer: the index is fixed, not learned

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # truncated-normal initialization clips values that fall outside the distribution, which helps training stability
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
        )
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # mask has shape (nW, N, N); broadcast it over batch and heads, then restore the window-batch layout
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
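
# Worked example (added for illustration; not in the original snippet): for window_size = (2, 2)
# the four tokens sit at (0,0), (0,1), (1,0), (1,1). After shifting both offset axes by M-1 = 1
# and scaling the row offset by 2M-1 = 3, the resulting relative_position_index is
#     [[4, 3, 1, 0],
#      [5, 4, 2, 1],
#      [7, 6, 4, 3],
#      [8, 7, 5, 4]]
# i.e. each of the (2M-1)**2 = 9 distinct relative offsets maps to one row of the bias table,
# and the diagonal (offset (0,0)) always hits the same entry, 4.
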
class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, mlp_ratio=4., qkv_bias=True,
                 qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 fused_window_process=False):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.input_resolution = input_resolution
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        # if the input resolution is no larger than the window, disable shifting and shrink the window to fit
        if min(self.input_resolution) <= self.window_size:
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
            qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # 3. Building the attention mask for shifted windows
        if self.shift_size > 0:
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))  # effectively two cuts along each axis, splitting the input into three strips (see the diagram)

            cnt = 0
            # give each region a distinct label (see the diagram)
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            # flatten the mask per window and subtract pairwise; entries that are nonzero are filled with -100
            # (those token pairs come from regions that are far apart, so they should not attend to each other)
            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)
        self.fused_window_process = fused_window_process
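
    # Illustration (added; not in the original snippet): for each of the nW windows, attn_mask is a
    # (window_size**2, window_size**2) matrix that is 0 where two tokens carry the same region label
    # and -100 where they differ. Added to the attention logits before the softmax, the -100 entries
    # push those probabilities to ~0, so tokens that became neighbors only through the cyclic shift
    # do not attend to each other. E.g. with input_resolution=(8, 8), window_size=4, shift_size=2,
    # there are nW = 4 windows and attn_mask has shape (4, 16, 16).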

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        if self.shift_size > 0:
            if not self.fused_window_process:
                # the cyclic shift moves the shifted-window layout back onto a regular window grid
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
                x_windows = window_partition(shifted_x, self.window_size)
            else:
                x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
        else:
            x_windows = window_partition(x, self.window_size)

        # the snippet broke off here; the remainder follows the reference Swin implementation:
        # window attention, merge windows, reverse the shift, then the two residual branches
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        attn_windows = self.attn(x_windows, mask=self.attn_mask)
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)

        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = window_reverse(attn_windows, self.window_size, H, W)
                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
            else:
                x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
        else:
            x = window_reverse(attn_windows, self.window_size, H, W)
        x = x.view(B, H * W, C)

        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

## 1. Patch embedding
class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)  # convert to a 2-tuple, e.g. 224 -> (224, 224)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution  # saved because flops() reads it below
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape  # 1*3*224*224
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H} * {W}) does not match model ({self.img_size[0]} * {self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # 1*3136*96
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):  # count floating-point operations
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
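

# Minimal smoke test (added for illustration; not part of the original snippet). It assumes the
# imports above resolve. Shapes follow the inline comments: a 224x224 RGB image becomes
# 56*56 = 3136 patch tokens of dim 96, and a shifted Swin block maps tokens to tokens.
if __name__ == "__main__":
    patch_embed = PatchEmbed(img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=nn.LayerNorm)
    tokens = patch_embed(torch.randn(1, 3, 224, 224))
    print(tokens.shape)         # torch.Size([1, 3136, 96])
    print(patch_embed.flops())  # FLOPs of the projection (+ norm)

    attn = WindowAttention(dim=96, window_size=to_2tuple(7), num_heads=3)
    print(attn(torch.randn(4, 49, 96)).shape)  # 4 windows of 7*7 tokens -> torch.Size([4, 49, 96])

    block = SwinTransformerBlock(dim=96, input_resolution=(56, 56), num_heads=3,
                                 window_size=7, shift_size=3)
    print(block(tokens).shape)  # torch.Size([1, 3136, 96])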