Passed
Push — master ( 37b2e6...90df96 )
by Jeremy
01:58 queued 15s
created

swin_transformer_model   A

Complexity

Total Complexity 5

Size/Duplication

Total Lines 56
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 42
dl 0
loc 56
rs 10
c 0
b 0
f 0
wmc 5

3 Methods

Rating   Name   Duplication   Size   Complexity  
A PatchEmbed.__init__() 0 17 2
A PatchEmbed.forward() 0 8 2
A WindowAttention.__init__() 0 16 1
1
import torch
2
import torch.nn as nn
3
import torch.utils.checkpoint as checkpoint
4
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
5
6
7
class WindowAttention(nn.Module):
    """Window-based multi-head self-attention (W-MSA) -- partial implementation.

    Only the constructor exists so far. It records the configuration, picks the
    query/key scaling factor and allocates a learnable, zero-initialised
    relative-position bias table. No ``forward`` is defined yet.

    Args:
        dim: Number of input channels.
        window_size: Window height and width, read as ``window_size[0]`` and
            ``window_size[1]``.
        num_heads: Number of attention heads. ``dim // num_heads`` is used as
            the per-head width.
        qkv_bias: Accepted but not used yet.
        qk_scale: Explicit scaling factor. Any falsy value (``None``, ``0``)
            falls back to ``head_dim ** -0.5``.
        attn_drop: Accepted but not used yet.
        proj_drop: Accepted but not used yet.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads

        per_head = dim // num_heads
        # A truthiness test, not an ``is None`` test: qk_scale=0 also falls back.
        self.scale = qk_scale if qk_scale else per_head ** -0.5

        # One row per relative (dy, dx) offset inside a window.
        # dy takes 2*Wh-1 values and dx takes 2*Ww-1 values; each row holds one bias per head.
        span_h = 2 * window_size[0] - 1
        span_w = 2 * window_size[1] - 1
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(span_h * span_w, num_heads)
        )

        # Stack the row and column index grids, then flatten them to (2, Wh*Ww) coordinates.
        # NOTE(review): the result is not stored or used yet. It looks like the first
        # step toward a relative-position index; finish or remove it.
        grid = torch.stack(torch.meshgrid([
            torch.arange(self.window_size[0]),
            torch.arange(self.window_size[1]),
        ]))
        coords_flatten = grid.flatten(1)
24
25
26
27
28
# 1. Patch embedding: turn the input image into a sequence of patch tokens.
29
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each patch.

    The projection is a ``Conv2d`` whose kernel size equals its stride, so each
    output position is computed from exactly one patch.

    Args:
        img_size: Expected input height/width. Either an int or a pair,
            normalised through ``to_2tuple``.
        patch_size: Patch height/width. Either an int or a pair.
        in_chans: Number of input channels.
        embed_dim: Number of output channels produced per patch.
        norm_layer: Optional factory. When given, ``norm_layer(embed_dim)`` is
            applied to the token sequence after projection.

    Attributes:
        img_size: ``(H, W)`` the module accepts.
        patch_size: ``(ph, pw)`` patch size.
        patches_resolution: ``[H // ph, W // pw]``, the size of the patch grid.
        num_patches: Total number of patch tokens (product of the grid size).
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)  # normalise the input size to a length-2 tuple
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        # Previously computed and then thrown away. Keep it so callers can read the
        # patch grid size (presumably needed by later stages -- confirm).
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # kernel == stride: one non-overlapping patch per output position.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        """Embed ``x`` into a sequence of patch tokens.

        Args:
            x: Image tensor. Assumed to be ``(B, C, H, W)`` (e.g. ``1x3x224x224``),
               and its spatial size must equal ``self.img_size``.

        Returns:
            Tensor of shape ``(B, num_patches, embed_dim)``. With the defaults
            this is ``(B, 3136, 96)``.

        Raises:
            AssertionError: If the spatial size does not match ``self.img_size``.
                NOTE: this check is skipped when Python runs with ``-O``.
        """
        _, _, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H} * {W}) does not match model ({self.img_size[0]} * {self.img_size[1]})."
        # (B, embed_dim, Ph, Pw) -> flatten(2) -> (B, embed_dim, Ph*Pw) -> (B, Ph*Pw, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        if self.norm is not None:
            x = self.norm(x)
        return x