QiuBiaoer / transformer
from torch import nn
import torch
import matplotlib.pyplot as plt

"""
input shape: [batch, seq_len, d_model]
"""

class PositionEncoding(nn.Module):
    def __init__(self, d_model, max_seq_len=512):
        super().__init__()
        # shape: [max_seq_len, 1]
        position = torch.arange(0, max_seq_len).unsqueeze(1)
        # frequency term 1 / 10000 ** (2i / d_model), one value per even dimension
        item = 1 / 10000 ** (torch.arange(0, d_model, 2) / d_model)
        tmp_pos = position * item
        pe = torch.zeros(max_seq_len, d_model)
        pe[:, 0::2] = torch.sin(tmp_pos)
        pe[:, 1::2] = torch.cos(tmp_pos)
        # plt.matshow(pe)
        # plt.show()  # these two lines visualize the positional encoding as an image
        pe = pe.unsqueeze(0)  # [1, max_seq_len, d_model]
        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        batch, seq_len, _ = x.shape
        pe = self.pe
        return x + pe[:, :seq_len, :]

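# Hedged sanity-check sketch (added for illustration; _demo_position_encoding is not part of
# the original file): adding the positional encoding leaves the input shape unchanged.
def _demo_position_encoding():
    pos_enc = PositionEncoding(d_model=16, max_seq_len=32)
    x = torch.zeros(2, 10, 16)           # [batch, seq_len, d_model]
    out = pos_enc(x)
    assert out.shape == (2, 10, 16)      # the encoding is simply added, so the shape is preserved
    return out
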
def attention(query, key, value, mask=None):
    # query, key, value shape: [batch, heads, seq_len, d_k] (or [batch, seq_len, d_k] for a single head)
    d_k = key.shape[-1]
    att_ = torch.matmul(query, key.transpose(-2, -1)) / d_k ** 0.5
    if mask is not None:
        att_ = att_.masked_fill(mask, -1e9)
    att_score = torch.softmax(att_, dim=-1)
    return torch.matmul(att_score, value)

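# Hedged usage sketch (added for illustration; _demo_attention is not part of the original file):
# single-head scaled dot-product attention with a causal mask; True entries are blocked.
def _demo_attention():
    q = torch.randn(2, 5, 16)            # [batch, seq_len, d_k]
    k = torch.randn(2, 5, 16)
    v = torch.randn(2, 5, 16)
    causal = torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1)  # hide future positions
    out = attention(q, k, v, causal)     # the [5, 5] mask broadcasts over the batch dimension
    assert out.shape == (2, 5, 16)
    return out
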
class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()
        assert d_model % heads == 0  # each head attends over its own d_model // heads slice of q/k/v
        self.q_linear = nn.Linear(d_model, d_model, bias=False)
        self.k_linear = nn.Linear(d_model, d_model, bias=False)
        self.v_linear = nn.Linear(d_model, d_model, bias=False)
        self.linear = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.heads = heads
        self.d_k = d_model // heads
        self.d_model = d_model

    def forward(self, q, k, v, mask=None):
        # [n, seq_len, d_model] -> [n, heads, seq_len, d_k]
        # the three linear layers project the inputs into q/k/v; the projection matrices
        # W_q, W_k, W_v are implicit in these layers
        q = self.q_linear(q).reshape(q.shape[0], -1, self.heads, self.d_k).transpose(1, 2)
        k = self.k_linear(k).reshape(k.shape[0], -1, self.heads, self.d_k).transpose(1, 2)
        v = self.v_linear(v).reshape(v.shape[0], -1, self.heads, self.d_k).transpose(1, 2)
        out = attention(q, k, v, mask)
        out = out.transpose(1, 2).reshape(out.shape[0], -1, self.d_model)
        out = self.linear(out)
        out = self.dropout(out)
        return out

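# Hedged sketch (added for illustration; _demo_multi_head_attention is not part of the original
# file): self-attention with a padding mask shaped [batch, 1, seq_q, seq_k], which broadcasts
# across the heads dimension inside attention().
def _demo_multi_head_attention():
    mha = MultiHeadAttention(heads=4, d_model=32, dropout=0.0)
    x = torch.randn(2, 6, 32)                              # [batch, seq_len, d_model]
    pad_mask = torch.zeros(2, 1, 6, 6, dtype=torch.bool)   # nothing is masked in this toy example
    out = mha(x, x, x, pad_mask)
    assert out.shape == (2, 6, 32)
    return out
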
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(d_ff, d_model, bias=False),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.ffn(x)

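# Hedged sketch (added for illustration; _demo_feed_forward is not part of the original file):
# the position-wise feed-forward block maps d_model -> d_ff -> d_model at every position.
def _demo_feed_forward():
    ffn = FeedForward(d_model=32, d_ff=64, dropout=0.0)
    x = torch.randn(2, 6, 32)
    assert ffn(x).shape == (2, 6, 32)
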
class EncoderLayer(nn.Module):
    def __init__(self, heads, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.self_multi_head_att = MultiHeadAttention(heads, d_model, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norms = nn.ModuleList([nn.LayerNorm(d_model) for i in range(2)])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        multi_head_att_out = self.self_multi_head_att(x, x, x, mask)
        multi_head_att_out = self.norms[0](x + multi_head_att_out)
        ffn_out = self.ffn(multi_head_att_out)
        ffn_out = self.norms[1](multi_head_att_out + ffn_out)
        out = self.dropout(ffn_out)
        return out

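# Hedged sketch (added for illustration; _demo_encoder_layer is not part of the original file):
# one post-norm encoder layer (self-attention + residual/LayerNorm + FFN) keeps the input shape.
def _demo_encoder_layer():
    layer = EncoderLayer(heads=4, d_model=32, d_ff=64, dropout=0.0)
    x = torch.randn(2, 6, 32)
    out = layer(x)                       # mask defaults to None here
    assert out.shape == (2, 6, 32)
    return out
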
class Encoder(nn.Module):
    def __init__(self, vocab_size, pad_idx, d_model, heads, num_layers, d_ff, max_seq_len=512, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model, pad_idx)
        self.positional_encode = PositionEncoding(d_model, max_seq_len)
        self.encoder_layers = nn.ModuleList([EncoderLayer(heads, d_model, d_ff, dropout) for i in range(num_layers)])

    def forward(self, x, src_mask):
        embed_x = self.embedding(x)
        pos_encode_x = self.positional_encode(embed_x)
        for layer in self.encoder_layers:
            pos_encode_x = layer(pos_encode_x, src_mask)
        return pos_encode_x

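# Hedged sketch (added for illustration; _demo_encoder is not part of the original file): run
# token ids through the Encoder with a padding mask built the same way Transformer.generate_mask
# builds the source-side mask.
def _demo_encoder():
    pad_idx = 0
    enc = Encoder(vocab_size=50, pad_idx=pad_idx, d_model=32, heads=4, num_layers=2, d_ff=64)
    src = torch.randint(1, 50, (2, 7))                     # [batch, seq_len], no padding in this toy batch
    src_mask = (src == pad_idx).unsqueeze(1).unsqueeze(2)  # [batch, 1, 1, seq_len], broadcasts over heads and queries
    out = enc(src, src_mask)
    assert out.shape == (2, 7, 32)
    return out
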
class DecoderLayer(nn.Module):
    def __init__(self, heads, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.masked_att = MultiHeadAttention(heads, d_model, dropout)
        self.att = MultiHeadAttention(heads, d_model, dropout)
        self.norms = nn.ModuleList([nn.LayerNorm(d_model) for i in range(3)])
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encode_kv, dst_mask=None, src_dst_mask=None):
        masked_att_out = self.masked_att(x, x, x, dst_mask)
        masked_att_out = self.norms[0](x + masked_att_out)
        att_out = self.att(masked_att_out, encode_kv, encode_kv, src_dst_mask)
        att_out = self.norms[1](att_out + masked_att_out)
        ffn_out = self.ffn(att_out)
        ffn_out = self.norms[2](ffn_out + att_out)
        out = self.dropout(ffn_out)
        return out

class Decoder(nn.Module):
    def __init__(self, vocab_size, pad_idx, d_model, heads, num_layers, d_ff, max_seq_len=512, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model, pad_idx)
        self.positional_encode = PositionEncoding(d_model, max_seq_len)
        self.decoder_layers = nn.ModuleList([DecoderLayer(heads, d_model, d_ff, dropout) for i in range(num_layers)])

    def forward(self, x, encoder_kv, dst_mask=None, src_dst_mask=None):
        embed_x = self.embedding(x)
        pos_encode_x = self.positional_encode(embed_x)
        for layer in self.decoder_layers:
            pos_encode_x = layer(pos_encode_x, encoder_kv, dst_mask, src_dst_mask)
        return pos_encode_x

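# Hedged sketch (added for illustration; _demo_decoder is not part of the original file): the
# Decoder consumes target ids plus the encoder output, which serves as keys/values in cross-attention.
def _demo_decoder():
    dec = Decoder(vocab_size=60, pad_idx=0, d_model=32, heads=4, num_layers=2, d_ff=64)
    encoder_out = torch.randn(2, 7, 32)   # stand-in for a real encoder output
    dst = torch.randint(1, 60, (2, 5))
    out = dec(dst, encoder_out)           # both masks are left as None in this sketch
    assert out.shape == (2, 5, 32)
    return out
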
class Transformer(nn.Module):
    def __init__(self, enc_vocab_size, dec_vocab_size, pad_idx, d_model, heads, num_layers, d_ff, max_seq_len=512, dropout=0.1):
        super().__init__()
        self.encoder = Encoder(enc_vocab_size, pad_idx, d_model, heads, num_layers, d_ff, max_seq_len, dropout)
        self.decoder = Decoder(dec_vocab_size, pad_idx, d_model, heads, num_layers, d_ff, max_seq_len, dropout)
        self.linear = nn.Linear(d_model, dec_vocab_size)
        self.pad_idx = pad_idx

    def generate_mask(self, query, key, is_triu_mask=False):  # the flag selects a padding-only mask or padding + causal mask for the decoder's masked attention
        device = query.device
        # query, key shape: [batch, seq_len]
        batch, seq_q = query.shape
        _, seq_k = key.shape
        # result shape: [batch, 1, seq_q, seq_k], broadcast over the heads dimension
        mask = (key == self.pad_idx).unsqueeze(1).unsqueeze(2)
        mask = mask.expand(batch, 1, seq_q, seq_k).to(device)
        if is_triu_mask:
            dst_triu_mask = torch.triu(torch.ones(seq_q, seq_k, dtype=torch.bool), diagonal=1)
            dst_triu_mask = dst_triu_mask.unsqueeze(0).unsqueeze(1).expand(batch, 1, seq_q, seq_k).to(device)
            return mask | dst_triu_mask
        return mask

    def forward(self, src, dst):
        src_mask = self.generate_mask(src, src)  # padding mask for the source side
        encoder_out = self.encoder(src, src_mask)
        dst_mask = self.generate_mask(dst, dst, True)
        src_dst_mask = self.generate_mask(dst, src)
        decoder_out = self.decoder(dst, encoder_out, dst_mask, src_dst_mask)
        out = self.linear(decoder_out)
        return out

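# Hedged sketch (added for illustration; _demo_generate_mask is not part of the original file):
# inspect the masks generate_mask produces; the decoder self-attention mask combines padding
# and causal (upper-triangular) masking.
def _demo_generate_mask():
    model = Transformer(enc_vocab_size=50, dec_vocab_size=60, pad_idx=0,
                        d_model=32, heads=4, num_layers=2, d_ff=64)
    src = torch.randint(1, 50, (2, 7))
    dst = torch.randint(1, 60, (2, 5))
    src_mask = model.generate_mask(src, src)        # [2, 1, 7, 7] padding mask
    dst_mask = model.generate_mask(dst, dst, True)  # [2, 1, 5, 5] padding + causal mask
    cross_mask = model.generate_mask(dst, src)      # [2, 1, 5, 7] decoder-to-encoder padding mask
    assert dst_mask.shape == (2, 1, 5, 5) and cross_mask.shape == (2, 1, 5, 7)
    return src_mask, dst_mask, cross_mask
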
if __name__ == '__main__':
    # PositionEncoding(512, 100)  # visualize the positional-encoding pattern
    # att = MultiHeadAttention(8, 512, 0.2)  # check that multi-head attention keeps the expected shapes
    # x = torch.randn(4, 100, 512)
    # out = att(x, x, x)
    # print(out.shape)
    model = Transformer(100, 200, 0, 512, 8, 6, 1024, 512, 0.1)
    x = torch.randint(0, 100, (4, 64))
    y = torch.randint(0, 200, (4, 64))
    out = model(x, y)
    print(out.shape)