import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_


class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

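        # A learnable bias table with one entry per head for every possible relative
        # offset inside a window: row offsets span [-(Wh-1), Wh-1] and column offsets
        # span [-(Ww-1), Ww-1], i.e. (2*Wh - 1) * (2*Ww - 1) combinations in total.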
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )

        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
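        # The excerpt stops here; the lines below are a sketch of how the standard Swin
        # Transformer implementation typically continues, turning coords_flatten into a
        # relative position index (assumed continuation, not part of the original listing).
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift row offsets to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1  # shift column offsets to start from 0
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1  # fold the two offsets into one index
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)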


# 1. PatchEmbed
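# PatchEmbed splits the input image into non-overlapping patch_size x patch_size patches
# and linearly projects each patch to embed_dim channels, implemented as a strided Conv2d.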
class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)  # convert the input to a 2-tuple, e.g. 224 -> (224, 224)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape  # (B, C, H, W), e.g. 1 * 3 * 224 * 224
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H} * {W}) does not match model ({self.img_size[0]} * {self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim), e.g. 1 * 3136 * 96
        if self.norm is not None:
            x = self.norm(x)
        return x
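

# Minimal usage sketch (illustrative, not part of the original listing): with the default
# arguments, a 224x224 RGB image yields 56 * 56 = 3136 patches, each embedded to 96 channels.
if __name__ == "__main__":
    embed = PatchEmbed(img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=nn.LayerNorm)
    dummy = torch.randn(1, 3, 224, 224)
    out = embed(dummy)
    print(out.shape)  # torch.Size([1, 3136, 96])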