model.Decoder.__init__()   A
last analyzed

Complexity
Conditions 1

Size
Total Lines 5
Code Lines 5

Duplication
Lines 5
Ratio 100 %

Importance
Changes 0

Metric   Value
cc       1
eloc     5
nop      9
dl       5
loc      5
rs       10
c        0
b        0
f        0

How to fix: Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent as you need more, or different, data.

There are several approaches to avoiding long parameter lists; one common option, sketched below, is to group the related values into a single parameter object.
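As a hedged illustration of the parameter-object approach (not part of the analyzed file), the long, largely overlapping argument lists of Encoder.__init__(), Decoder.__init__() and Transformer.__init__() in the file below could be replaced by a small configuration dataclass. The name TransformerConfig and its defaults are assumptions for this sketch, not existing code in the project.

from dataclasses import dataclass

# Hypothetical parameter object (not in the analyzed file): it bundles the
# hyperparameters that the constructors below currently accept one by one.
@dataclass
class TransformerConfig:
    vocab_size: int
    pad_idx: int
    d_model: int = 512
    heads: int = 8
    num_layers: int = 6
    d_ff: int = 1024
    max_seq_len: int = 512
    dropout: float = 0.1

A constructor that takes a single config argument keeps its signature stable when hyperparameters are added or changed, which addresses the inconsistency the warning above describes. The full file as analyzed follows.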

from torch import nn
import torch
import matplotlib.pyplot as plt

"""
    input shape: [batch, seq_len, d_model]
"""

class PositionEncoding(nn.Module):
    def __init__(self, d_model, max_seq_len=512):
        super().__init__()
        # shape: [max_seq_len, 1]
        position = torch.arange(0, max_seq_len).unsqueeze(1)
        item = 1 / 10000 ** (torch.arange(0, d_model, 2) / d_model)
        tmp_pos = position * item
        pe = torch.zeros(max_seq_len, d_model)
        pe[:, 0::2] = torch.sin(tmp_pos)
        pe[:, 1::2] = torch.cos(tmp_pos)
        # plt.matshow(pe)
        # plt.show()  # these two lines visualize the positional-encoding pattern
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe, False)  # persistent=False: pe is not stored in the state_dict

    def forward(self, x):
        batch, seq_len, _ = x.shape
        pe = self.pe
        return x + pe[:, :seq_len, :]


def attention(query, key, value, mask=None):
    d_model = key.shape[-1]  # inside MultiHeadAttention this is the per-head size d_k
    # query, key, value shape: [batch, seq_len, d_model]
    att_ = torch.matmul(query, key.transpose(-2, -1)) / d_model ** 0.5
    if mask is not None:
        att_ = att_.masked_fill(mask, -1e9)

    att_score = torch.softmax(att_, dim=-1)
    return torch.matmul(att_score, value)


class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()
        assert d_model % heads == 0   # each attention head handles its own slice of q/k/v
        self.q_linear = nn.Linear(d_model, d_model, bias=False)
        self.k_linear = nn.Linear(d_model, d_model, bias=False)
        self.v_linear = nn.Linear(d_model, d_model, bias=False)
        self.linear = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.heads = heads
        self.d_k = d_model // heads
        self.d_model = d_model

    def forward(self, q, k, v, mask=None):
        # [n, seq_len, d_model] -> [n, heads, seq_len, d_k]
        # the three linear layers project the inputs to q/k/v; they play the role of the "W" matrices
        q = self.q_linear(q).reshape(q.shape[0], -1, self.heads, self.d_k).transpose(1, 2)
        k = self.k_linear(k).reshape(q.shape[0], -1, self.heads, self.d_k).transpose(1, 2)
        v = self.v_linear(v).reshape(q.shape[0], -1, self.heads, self.d_k).transpose(1, 2)
        out = attention(q, k, v, mask)
        out = out.transpose(1, 2).reshape(out.shape[0], -1, self.d_model)
        out = self.linear(out)
        out = self.dropout(out)
        return out


class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(d_ff, d_model, bias=False),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.ffn(x)


class EncoderLayer(nn.Module):
    def __init__(self, heads, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.self_multi_head_att = MultiHeadAttention(heads, d_model, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norms = nn.ModuleList([nn.LayerNorm(d_model) for i in range(2)])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        multi_head_att_out = self.self_multi_head_att(x, x, x, mask)
        multi_head_att_out = self.norms[0](x + multi_head_att_out)
        ffn_out = self.ffn(multi_head_att_out)
        ffn_out = self.norms[1](multi_head_att_out + ffn_out)
        out = self.dropout(ffn_out)
        return out


class Encoder(nn.Module):
    def __init__(self, vocab_size, pad_idx, d_model, heads, num_layers, d_ff, max_seq_len=512, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model, pad_idx)
        self.positional_encode = PositionEncoding(d_model, max_seq_len)
        self.encoder_layers = nn.ModuleList([EncoderLayer(heads, d_model, d_ff, dropout) for i in range(num_layers)])

    def forward(self, x, src_mask):
        embed_x = self.embedding(x)
        pos_encode_x = self.positional_encode(embed_x)
        for layer in self.encoder_layers:
            pos_encode_x = layer(pos_encode_x, src_mask)
        return pos_encode_x


class DecoderLayer(nn.Module):
    def __init__(self, heads, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.masked_att = MultiHeadAttention(heads, d_model, dropout)
        self.att = MultiHeadAttention(heads, d_model, dropout)
        self.norms = nn.ModuleList([nn.LayerNorm(d_model) for i in range(3)])
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encode_kv, dst_mask=None, src_dst_mask=None):
        masked_att_out = self.masked_att(x, x, x, dst_mask)
        masked_att_out = self.norms[0](x + masked_att_out)
        att_out = self.att(masked_att_out, encode_kv, encode_kv, src_dst_mask)
        att_out = self.norms[1](att_out + masked_att_out)
        ffn_out = self.ffn(att_out)
        ffn_out = self.norms[2](ffn_out + att_out)
        out = self.dropout(ffn_out)
        return out


class Decoder(nn.Module):
    def __init__(self, vocab_size, pad_idx, d_model, heads, num_layers, d_ff, max_seq_len=512, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model, pad_idx)
        self.positional_encode = PositionEncoding(d_model, max_seq_len)
        self.decoder_layers = nn.ModuleList([DecoderLayer(heads, d_model, d_ff, dropout) for i in range(num_layers)])

    def forward(self, x, encoder_kv, dst_mask=None, src_dst_mask=None):
        embed_x = self.embedding(x)
        pos_encode_x = self.positional_encode(embed_x)
        for layer in self.decoder_layers:
            pos_encode_x = layer(pos_encode_x, encoder_kv, dst_mask, src_dst_mask)
        return pos_encode_x


class Transformer(nn.Module):
    def __init__(self, enc_vocab_size, dec_vocab_size, pad_idx, d_model, heads, num_layers, d_ff, max_seq_len=512, dropout=0.1):
        super().__init__()
        self.encoder = Encoder(enc_vocab_size, pad_idx, d_model, heads, num_layers, d_ff, max_seq_len, dropout)
        self.decoder = Decoder(dec_vocab_size, pad_idx, d_model, heads, num_layers, d_ff, max_seq_len, dropout)
        self.linear = nn.Linear(d_model, dec_vocab_size)
        self.pad_idx = pad_idx

    def generate_mask(self, query, key, is_triu_mask=False):    # the last flag chooses the causal mask (for masked multi-head attention) instead of a plain padding mask
        device = query.device
        # batch, seq_len
        batch, seq_q = query.shape
        _, seq_k = key.shape
        # batch, head, seq_q, seq_k
        mask = (key == self.pad_idx).unsqueeze(1).unsqueeze(2)
        mask = mask.expand(batch, 1, seq_q, seq_k).to(device)
        if is_triu_mask:
            dst_triu_mask = torch.triu(torch.ones(seq_q, seq_k, dtype=torch.bool), diagonal=1)
            dst_triu_mask = dst_triu_mask.unsqueeze(0).unsqueeze(1).expand(batch, 1, seq_q, seq_k).to(device)
            return mask | dst_triu_mask
        return mask

    def forward(self, src, dst):
        src_mask = self.generate_mask(src, src)  # padding mask for the encoder input
        encoder_out = self.encoder(src, src_mask)
        dst_mask = self.generate_mask(dst, dst, True)
        src_dst_mask = self.generate_mask(dst, src)
        decoder_out = self.decoder(dst, encoder_out, dst_mask, src_dst_mask)
        out = self.linear(decoder_out)
        return out


if __name__ == '__main__':
    # PositionEncoding(512, 100)  # inspect the positional-encoding pattern
    # att = MultiHeadAttention(8, 512, 0.2)  # check that multi-head attention produces the expected shapes
    # x = torch.randn(4, 100, 512)
    # out = att(x, x, x)
    # print(out.shape)
    att = Transformer(100, 200, 0, 512, 8, 6, 1024, 512, 0.1)
    x = torch.randint(0, 100, (4, 64))
    y = torch.randint(0, 200, (4, 64))
    out = att(x, y)
    print(out.shape)
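
To tie the earlier sketch back to the method this report flags, here is one possible shape for Decoder.__init__() once it accepts the hypothetical TransformerConfig object introduced above. This is a suggested refactoring under that assumption, not the project's current code; the class body is otherwise unchanged.

# Hypothetical refactoring of the flagged constructor: the nine parameters of
# Decoder.__init__() collapse into self plus one config object.
class Decoder(nn.Module):
    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.d_model, config.pad_idx)
        self.positional_encode = PositionEncoding(config.d_model, config.max_seq_len)
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(config.heads, config.d_model, config.d_ff, config.dropout)
             for _ in range(config.num_layers)]
        )
    # forward() stays exactly as in the listing above

# Example construction under these assumptions:
# decoder = Decoder(TransformerConfig(vocab_size=200, pad_idx=0))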