Small methods make your code easier to understand, in particular when combined with a good name. Moreover, when a method is small, finding a good name for it is usually much easier.
For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment serving as a starting point for the new method's name.
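As a small sketch of this heuristic (the function and field names are illustrative, not from any particular codebase), the comment becomes the name of the extracted method:

```python
# Before: a comment explains what the block does.
def process_order(order):
    # validate that the order has items and a shipping address
    if not order["items"]:
        raise ValueError("order has no items")
    if not order.get("address"):
        raise ValueError("order has no shipping address")
    return sum(item["price"] for item in order["items"])

# After: the comment became the method name, and the caller reads
# almost like the comment did.
def validate_order(order):
    """Ensure the order has items and a shipping address."""
    if not order["items"]:
        raise ValueError("order has no items")
    if not order.get("address"):
        raise ValueError("order has no shipping address")

def process_order_refactored(order):
    validate_order(order)
    return sum(item["price"] for item in order["items"])
```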
Commonly applied refactorings include:
If many parameters/temporary variables are present:
Methods with many parameters are not only hard to understand; their parameter lists also tend to grow inconsistent as you need more, or different, data.
There are several approaches to avoid long parameter lists:
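One common approach is to introduce a parameter object: bundle parameters that always travel together into a single value. A minimal sketch, with illustrative names only:

```python
from dataclasses import dataclass

# Before: a long, easy-to-misorder parameter list.
def render_text(text, font, size, bold, italic, color):
    return f"[{font} {size}{'b' if bold else ''}{'i' if italic else ''} {color}] {text}"

# After: the style parameters travel together as one object,
# and defaults live in exactly one place.
@dataclass
class TextStyle:
    font: str = "serif"
    size: int = 12
    bold: bool = False
    italic: bool = False
    color: str = "black"

def render_text_v2(text, style):
    return (f"[{style.font} {style.size}"
            f"{'b' if style.bold else ''}{'i' if style.italic else ''}"
            f" {style.color}] {text}")
```

The PyTorch constructor in the listing below, with its dozen-plus parameters, is a typical candidate for this refactoring.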
```python
import torch
import torch.nn as nn

# Note: the listing jumps from line 1 to line 101 of the original file.
# The enclosing class definition and the helpers WindowAttention, MLP,
# DropPath, to_2tuple and window_partition are defined in the elided part.

def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, mlp_ratio=4., qkv_bias=True,
             qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
             fused_window_process=False):
    super().__init__()
    self.dim = dim
    self.window_size = window_size
    self.num_heads = num_heads
    self.input_resolution = input_resolution
    self.shift_size = shift_size
    self.mlp_ratio = mlp_ratio
    # If the input resolution is no larger than the window, there is no
    # need for a shifted window; shrink the window to the input instead.
    if min(self.input_resolution) <= self.window_size:
        self.shift_size = 0
        self.window_size = min(self.input_resolution)
    assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

    self.norm1 = norm_layer(dim)
    self.attn = WindowAttention(
        dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
        qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop
    )
    self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()
    self.norm2 = norm_layer(dim)
    mlp_hidden_dim = int(dim * mlp_ratio)
    self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    # 3. Mask implementation
    if self.shift_size > 0:
        H, W = self.input_resolution
        img_mask = torch.zeros((1, H, W, 1))
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))  # effectively three cuts along each axis; see the diagram

        cnt = 0
        # Give each region a distinct label; see the diagram
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        # Flatten the mask per window and subtract it from itself: any
        # non-zero entry marks a pair of positions from different regions,
        # which are far apart and should not attend to each other, so
        # fill it with -100.
        mask_windows = window_partition(img_mask, self.window_size)
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    else:
        attn_mask = None

    self.register_buffer("attn_mask", attn_mask)
    self.fused_window_process = fused_window_process

# ... (lines 157-209 of the original file are elided) ...
        return flops
```
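To see the masking logic in isolation, here is a NumPy re-implementation of just the mask construction (a standalone sketch independent of the PyTorch code above, with a simplified `window_partition` of my own; names mirror the listing but nothing here comes from the original file beyond the algorithm):

```python
import numpy as np

def window_partition(x, window_size):
    """Split a (1, H, W, 1) map into (num_windows, window_size*window_size)."""
    _, H, W, _ = x.shape
    x = x.reshape(H // window_size, window_size, W // window_size, window_size)
    return x.transpose(0, 2, 1, 3).reshape(-1, window_size * window_size)

def build_attn_mask(H, W, window_size, shift_size):
    """Build the shifted-window attention mask; assumes shift_size > 0."""
    img_mask = np.zeros((1, H, W, 1))
    slices = (slice(0, -window_size),
              slice(-window_size, -shift_size),
              slice(-shift_size, None))
    # Label the nine regions produced by the two cuts along each axis.
    cnt = 0
    for h in slices:
        for w in slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
    # Pairwise label differences within each window: non-zero means the
    # two positions belong to different regions and must not attend.
    mask_windows = window_partition(img_mask, window_size)
    attn_mask = mask_windows[:, None, :] - mask_windows[:, :, None]
    return np.where(attn_mask != 0, -100.0, 0.0)
```

For a 4x4 input with `window_size=2` and `shift_size=1`, the first window lies entirely in one region (its mask is all zeros), while the last window straddles four regions, so all its cross-position entries are -100.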