Passed
Push to master (37b2e6...90df96) by Jeremy
01:58 (queued 15s)

swin_transformer_model.WindowAttention.__init__() (grade: A)

Complexity
  Conditions: 1

Size
  Total Lines: 16
  Code Lines: 13

Duplication
  Lines: 0
  Ratio: 0%

Importance
  Changes: 0

Metric  Value
cc      1      (cyclomatic complexity)
eloc    13     (effective lines of code)
nop     8      (number of parameters)
dl      0      (duplicated lines)
loc     16     (total lines)
rs      9.75
c       0
b       0
f       0

How to fix: Many Parameters

Methods with many parameters are not only hard to understand; their parameters also tend to become inconsistent when you later need more, or different, data.

There are several approaches to avoiding long parameter lists; one of them is sketched below.
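One common option is to introduce a parameter object, so that related arguments travel together and gain defaults and type hints in one place. The sketch below groups the eight arguments of the flagged constructor into a dataclass; the name WindowAttentionConfig is ours for illustration, not part of the reported code:

from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class WindowAttentionConfig:
    # Hypothetical parameter object: groups the eight arguments that
    # WindowAttention.__init__() currently takes individually.
    dim: int
    window_size: Tuple[int, int]
    num_heads: int
    qkv_bias: bool = True
    qk_scale: Optional[float] = None
    attn_drop: float = 0.0
    proj_drop: float = 0.0

# The constructor then shrinks to a single parameter:
#     def __init__(self, cfg: WindowAttentionConfig): ...
#     attn = WindowAttention(WindowAttentionConfig(dim=96, window_size=(7, 7), num_heads=3))

The flagged method as reported: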

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_


class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # (Wh, Ww)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5  # default scale: 1/sqrt(head_dim)

        # Learnable relative position bias: one entry per relative offset
        # within the window ((2*Wh-1) * (2*Ww-1) offsets), per head.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )

        # Absolute (row, col) coordinates of every position inside the window.
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # (2, Wh, Ww)
        coords_flatten = torch.flatten(coords, 1)  # (2, Wh*Ww)
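To make the coordinate bookkeeping above concrete, here is a small standalone check (our illustration, not part of the flagged file) for a 2x2 window:

import torch

window_size = (2, 2)
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # shape (2, 2, 2)
coords_flatten = torch.flatten(coords, 1)                   # shape (2, 4)
print(coords_flatten)
# tensor([[0, 0, 1, 1],
#         [0, 1, 0, 1]])
# Column i holds the (row, col) coordinate of window position i; pairwise
# differences of these columns yield the relative offsets that index
# relative_position_bias_table.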
## 1. PatchEmbed
class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)  # convert the input to a 2-tuple (H, W)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # Non-overlapping patch projection: kernel size == stride == patch size.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape  # e.g. 1 x 3 x 224 x 224
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H} * {W}) does not match model ({self.img_size[0]} * {self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # e.g. 1 x 3136 x 96: (B, num_patches, embed_dim)
        if self.norm is not None:
            x = self.norm(x)
        return x
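As a quick shape check (illustrative; uses the default arguments), the module maps a 224x224 RGB image to 56 * 56 = 3136 patch tokens of dimension 96:

patch_embed = PatchEmbed(img_size=224, patch_size=4, in_chans=3, embed_dim=96)
x = torch.randn(1, 3, 224, 224)
out = patch_embed(x)
print(out.shape)  # torch.Size([1, 3136, 96])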