import torch
from torch import nn
import matplotlib.pyplot as plt

"""
input shape: [batch, seq_len, d_model]
"""

class PositionEncoding(nn.Module):
    def __init__(self, d_model, max_seq_len=512):
        super().__init__()
        # position shape: [max_seq_len, 1]
        position = torch.arange(0, max_seq_len).unsqueeze(1)
        # inverse frequencies for the even dimensions: 1 / 10000^(2i / d_model)
        item = 1 / 10000 ** (torch.arange(0, d_model, 2) / d_model)
        tmp_pos = position * item
        pe = torch.zeros(max_seq_len, d_model)
        pe[:, 0::2] = torch.sin(tmp_pos)
        pe[:, 1::2] = torch.cos(tmp_pos)
        # plt.matshow(pe)
        # plt.show()  # these two lines visualize the positional-encoding matrix
        pe = pe.unsqueeze(0)  # [1, max_seq_len, d_model]
        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        batch, seq_len, _ = x.shape
        pe = self.pe
        return x + pe[:, :seq_len, :]


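# A minimal sanity check for PositionEncoding (a sketch, not part of the original code);
# the values d_model=16, seq_len=10, batch=2 below are arbitrary assumptions:
# pos_enc = PositionEncoding(d_model=16, max_seq_len=32)
# dummy = torch.zeros(2, 10, 16)                 # [batch, seq_len, d_model]
# print(pos_enc(dummy).shape)                    # torch.Size([2, 10, 16])

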
def attention(query, key, value, mask=None):
    # query, key, value shape: [batch, heads, seq_len, d_k]
    # (also works for 3-D inputs [batch, seq_len, d_model]; scaling uses the last dim)
    d_k = key.shape[-1]
    att_ = torch.matmul(query, key.transpose(-2, -1)) / d_k ** 0.5
    if mask is not None:
        att_ = att_.masked_fill(mask, -1e9)

    att_score = torch.softmax(att_, dim=-1)
    return torch.matmul(att_score, value)


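# Hedged illustration of the attention helper (not in the original source); the shapes
# batch=2, heads=1, seq_len=4, d_k=8 are assumptions, and the mask hides the last key:
# q = k = v = torch.randn(2, 1, 4, 8)
# pad_mask = torch.zeros(2, 1, 4, 4, dtype=torch.bool)
# pad_mask[..., -1] = True                       # forbid attending to the last key position
# print(attention(q, k, v, pad_mask).shape)      # torch.Size([2, 1, 4, 8])

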
class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()
        # each head works on its own d_model // heads slice of the q/k/v projections
        assert d_model % heads == 0
        self.q_linear = nn.Linear(d_model, d_model, bias=False)
        self.k_linear = nn.Linear(d_model, d_model, bias=False)
        self.v_linear = nn.Linear(d_model, d_model, bias=False)
        self.linear = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.heads = heads
        self.d_k = d_model // heads
        self.d_model = d_model

    def forward(self, q, k, v, mask=None):
        # [batch, seq_len, d_model] -> [batch, heads, seq_len, d_k]
        # the three linear layers hold the implicit W_q / W_k / W_v projection matrices
        q = self.q_linear(q).reshape(q.shape[0], -1, self.heads, self.d_k).transpose(1, 2)
        k = self.k_linear(k).reshape(k.shape[0], -1, self.heads, self.d_k).transpose(1, 2)
        v = self.v_linear(v).reshape(v.shape[0], -1, self.heads, self.d_k).transpose(1, 2)
        out = attention(q, k, v, mask)
        # merge the heads back: [batch, heads, seq_len, d_k] -> [batch, seq_len, d_model]
        out = out.transpose(1, 2).reshape(out.shape[0], -1, self.d_model)
        out = self.linear(out)
        out = self.dropout(out)
        return out


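# A hedged shape check for MultiHeadAttention (sketch only; heads=8, d_model=512,
# seq_len=100, batch=4 are assumed values, mirroring the commented test in __main__):
# mha = MultiHeadAttention(8, 512, 0.1)
# x = torch.randn(4, 100, 512)
# print(mha(x, x, x).shape)                      # torch.Size([4, 100, 512])

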
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # position-wise feed-forward network: expand to d_ff, then project back to d_model
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(d_ff, d_model, bias=False),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.ffn(x)


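# Hedged example for FeedForward (not in the original code): the network is applied
# independently at every position, so only the last dimension is expanded and projected
# back. Assumed sizes: d_model=512, d_ff=1024.
# ffn = FeedForward(512, 1024)
# print(ffn(torch.randn(4, 100, 512)).shape)     # torch.Size([4, 100, 512])

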
class EncoderLayer(nn.Module):
    def __init__(self, heads, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.self_multi_head_att = MultiHeadAttention(heads, d_model, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(2)])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # post-norm residual blocks: sublayer -> add residual -> LayerNorm
        multi_head_att_out = self.self_multi_head_att(x, x, x, mask)
        multi_head_att_out = self.norms[0](x + multi_head_att_out)
        ffn_out = self.ffn(multi_head_att_out)
        ffn_out = self.norms[1](multi_head_att_out + ffn_out)
        out = self.dropout(ffn_out)
        return out


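# Hedged sketch of one EncoderLayer pass (not original code; heads=8, d_model=512,
# d_ff=1024, seq_len=100, batch=4 are assumed values):
# layer = EncoderLayer(8, 512, 1024)
# x = torch.randn(4, 100, 512)
# print(layer(x).shape)                          # torch.Size([4, 100, 512])

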
class Encoder(nn.Module):
    def __init__(self, vocab_size, pad_idx, d_model, heads, num_layers, d_ff, max_seq_len=512, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.positional_encode = PositionEncoding(d_model, max_seq_len)
        self.encoder_layers = nn.ModuleList([EncoderLayer(heads, d_model, d_ff, dropout) for _ in range(num_layers)])

    def forward(self, x, src_mask):
        embed_x = self.embedding(x)
        pos_encode_x = self.positional_encode(embed_x)
        for layer in self.encoder_layers:
            pos_encode_x = layer(pos_encode_x, src_mask)
        return pos_encode_x


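# Hedged end-to-end check for the Encoder alone (a sketch; vocab_size=100, pad_idx=0,
# and the broadcastable [batch, 1, 1, seq_len] padding-mask shape are assumptions):
# enc = Encoder(100, 0, 512, 8, 6, 1024)
# tokens = torch.randint(1, 100, (4, 64))
# pad_mask = (tokens == 0).unsqueeze(1).unsqueeze(2)   # [4, 1, 1, 64]
# print(enc(tokens, pad_mask).shape)                   # torch.Size([4, 64, 512])

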
class DecoderLayer(nn.Module):
    def __init__(self, heads, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.masked_att = MultiHeadAttention(heads, d_model, dropout)
        self.att = MultiHeadAttention(heads, d_model, dropout)
        self.norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(3)])
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encode_kv, dst_mask=None, src_dst_mask=None):
        # masked self-attention over the target sequence
        masked_att_out = self.masked_att(x, x, x, dst_mask)
        masked_att_out = self.norms[0](x + masked_att_out)
        # cross-attention: queries from the decoder, keys/values from the encoder output
        att_out = self.att(masked_att_out, encode_kv, encode_kv, src_dst_mask)
        att_out = self.norms[1](att_out + masked_att_out)
        ffn_out = self.ffn(att_out)
        ffn_out = self.norms[2](ffn_out + att_out)
        out = self.dropout(ffn_out)
        return out


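# Hedged shape sketch for DecoderLayer (not in the original code; all sizes assumed):
# cross-attention lets a 50-token target attend over a 64-token encoder output.
# dec_layer = DecoderLayer(8, 512, 1024)
# tgt = torch.randn(4, 50, 512)
# enc_out = torch.randn(4, 64, 512)
# print(dec_layer(tgt, enc_out).shape)           # torch.Size([4, 50, 512])

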
class Decoder(nn.Module):
    def __init__(self, vocab_size, pad_idx, d_model, heads, num_layers, d_ff, max_seq_len=512, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.positional_encode = PositionEncoding(d_model, max_seq_len)
        self.decoder_layers = nn.ModuleList([DecoderLayer(heads, d_model, d_ff, dropout) for _ in range(num_layers)])

    def forward(self, x, encoder_kv, dst_mask=None, src_dst_mask=None):
        embed_x = self.embedding(x)
        pos_encode_x = self.positional_encode(embed_x)
        for layer in self.decoder_layers:
            pos_encode_x = layer(pos_encode_x, encoder_kv, dst_mask, src_dst_mask)
        return pos_encode_x


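# Hedged note (not in the original code): the Decoder expects two masks -- dst_mask
# combines target padding with the causal look-ahead mask for self-attention, while
# src_dst_mask is the source padding mask used in cross-attention; both are built by
# Transformer.generate_mask below. A minimal sketch with assumed sizes:
# dec = Decoder(200, 0, 512, 8, 6, 1024)
# tgt = torch.randint(1, 200, (4, 50))
# enc_out = torch.randn(4, 64, 512)
# print(dec(tgt, enc_out).shape)                 # torch.Size([4, 50, 512])

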
class Transformer(nn.Module):
    def __init__(self, enc_vocab_size, dec_vocab_size, pad_idx, d_model, heads, num_layers, d_ff, max_seq_len=512, dropout=0.1):
        super().__init__()
        self.encoder = Encoder(enc_vocab_size, pad_idx, d_model, heads, num_layers, d_ff, max_seq_len, dropout)
        self.decoder = Decoder(dec_vocab_size, pad_idx, d_model, heads, num_layers, d_ff, max_seq_len, dropout)
        self.linear = nn.Linear(d_model, dec_vocab_size)
        self.pad_idx = pad_idx

    def generate_mask(self, query, key, is_triu_mask=False):
        # is_triu_mask selects whether to add the causal (look-ahead) mask on top of the
        # padding mask, i.e. masked self-attention vs. plain padding masking
        device = query.device
        # query, key shape: [batch, seq_len]
        batch, seq_q = query.shape
        _, seq_k = key.shape
        # mask shape: [batch, 1, seq_q, seq_k], broadcast over the head dimension
        mask = (key == self.pad_idx).unsqueeze(1).unsqueeze(2)
        mask = mask.expand(batch, 1, seq_q, seq_k).to(device)
        if is_triu_mask:
            dst_triu_mask = torch.triu(torch.ones(seq_q, seq_k, dtype=torch.bool), diagonal=1)
            dst_triu_mask = dst_triu_mask.unsqueeze(0).unsqueeze(1).expand(batch, 1, seq_q, seq_k).to(device)
            return mask | dst_triu_mask
        return mask

    def forward(self, src, dst):
        src_mask = self.generate_mask(src, src)  # padding mask for the source sequence
        encoder_out = self.encoder(src, src_mask)
        dst_mask = self.generate_mask(dst, dst, True)  # padding + causal mask for the target
        src_dst_mask = self.generate_mask(dst, src)  # source padding mask for cross-attention
        decoder_out = self.decoder(dst, encoder_out, dst_mask, src_dst_mask)
        out = self.linear(decoder_out)
        return out


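# Hedged illustration of generate_mask (a sketch; the token values and pad_idx=0 are
# assumptions): for a length-3 target with no padding, the causal mask hides future keys.
# model = Transformer(100, 200, 0, 512, 8, 6, 1024)
# dst = torch.tensor([[5, 7, 9]])
# print(model.generate_mask(dst, dst, is_triu_mask=True)[0, 0].int())
# # tensor([[0, 1, 1],
# #         [0, 0, 1],
# #         [0, 0, 0]], dtype=torch.int32)

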
if __name__ == '__main__':
    # PositionEncoding(512, 100)  # visualize the positional-encoding pattern
    # att = MultiHeadAttention(8, 512, 0.2)  # check that the multi-head shapes work out
    # x = torch.randn(4, 100, 512)
    # out = att(x, x, x)
    # print(out.shape)
    att = Transformer(100, 200, 0, 512, 8, 6, 1024, 512, 0.1)
    x = torch.randint(0, 100, (4, 64))
    y = torch.randint(0, 200, (4, 64))
    out = att(x, y)
    print(out.shape)  # expected: torch.Size([4, 64, 200])