from torch import nn
import torch
import matplotlib.pyplot as plt

"""
    input shape: [batch, seq_len, d_model]
"""


class PositionEncoding(nn.Module):
    def __init__(self, d_model, max_seq_len=512):
        super().__init__()
        # position shape: [max_seq_len, 1]
        position = torch.arange(0, max_seq_len).unsqueeze(1)
        # 1 / 10000^(2i/d_model) for each even dimension index 2i (assumes d_model is even)
        item = 1 / (10000 ** (torch.arange(0, d_model, 2) / d_model))
        tmp_pos = position * item
        pe = torch.zeros(max_seq_len, d_model)
        pe[:, 0::2] = torch.sin(tmp_pos)
        pe[:, 1::2] = torch.cos(tmp_pos)
        # plt.matshow(pe)
        # plt.show()  # these two lines visualize the positional encoding matrix
        pe = pe.unsqueeze(0)  # [1, max_seq_len, d_model], broadcast over the batch
        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        _, seq_len, _ = x.shape
        return x + self.pe[:, :seq_len, :]


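# Minimal shape check for PositionEncoding (an illustrative sketch, not part of
# the original module; the helper name and the sizes below are arbitrary assumptions).
def _demo_position_encoding():
    pos_enc = PositionEncoding(d_model=16, max_seq_len=32)
    x = torch.zeros(2, 10, 16)        # [batch, seq_len, d_model]
    out = pos_enc(x)                  # pe is sliced to seq_len and broadcast over the batch
    assert out.shape == (2, 10, 16)

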
def attention(query, key, value, mask=None):
    # query, key, value shape: [batch, heads, seq_len, d_k] when called from MultiHeadAttention
    d_k = key.shape[-1]
    # scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V
    att_ = torch.matmul(query, key.transpose(-2, -1)) / d_k ** 0.5
    if mask is not None:
        att_ = att_.masked_fill(mask, -1e9)
    att_score = torch.softmax(att_, dim=-1)
    return torch.matmul(att_score, value)


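# Quick demonstration of the attention helper with a boolean mask (a sketch with
# made-up sizes and a hypothetical helper name; True positions are blocked).
def _demo_attention():
    q = k = v = torch.randn(1, 2, 5, 8)                    # [batch, heads, seq_len, d_k]
    causal = torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1)
    out = attention(q, k, v, causal)                       # mask broadcasts over batch and heads
    assert out.shape == (1, 2, 5, 8)

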
class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()
        # each head works on its own d_model // heads slice of the q/k/v projections
        assert d_model % heads == 0
        self.q_linear = nn.Linear(d_model, d_model, bias=False)
        self.k_linear = nn.Linear(d_model, d_model, bias=False)
        self.v_linear = nn.Linear(d_model, d_model, bias=False)
        self.linear = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.heads = heads
        self.d_k = d_model // heads
        self.d_model = d_model

    def forward(self, q, k, v, mask=None):
        # [batch, seq_len, d_model] -> [batch, heads, seq_len, d_k]
        # the three linear layers play the role of the W_q, W_k, W_v projection matrices
        q = self.q_linear(q).reshape(q.shape[0], -1, self.heads, self.d_k).transpose(1, 2)
        k = self.k_linear(k).reshape(k.shape[0], -1, self.heads, self.d_k).transpose(1, 2)
        v = self.v_linear(v).reshape(v.shape[0], -1, self.heads, self.d_k).transpose(1, 2)
        out = attention(q, k, v, mask)
        # [batch, heads, seq_len, d_k] -> [batch, seq_len, d_model]
        out = out.transpose(1, 2).reshape(out.shape[0], -1, self.d_model)
        out = self.linear(out)
        out = self.dropout(out)
        return out


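# Shape check for MultiHeadAttention used as cross-attention, where the query
# sequence length differs from the key/value length (illustrative sketch only;
# the sizes and helper name are assumptions).
def _demo_multi_head_attention():
    mha = MultiHeadAttention(heads=4, d_model=32)
    q = torch.randn(2, 7, 32)          # decoder-side queries
    kv = torch.randn(2, 11, 32)        # encoder-side keys/values
    out = mha(q, kv, kv)
    assert out.shape == (2, 7, 32)     # output follows the query sequence length

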
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(d_ff, d_model, bias=False),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.ffn(x)


class EncoderLayer(nn.Module):
    def __init__(self, heads, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.self_multi_head_att = MultiHeadAttention(heads, d_model, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(2)])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # self-attention sub-layer with residual connection and post-norm
        multi_head_att_out = self.self_multi_head_att(x, x, x, mask)
        multi_head_att_out = self.norms[0](x + multi_head_att_out)
        # feed-forward sub-layer with residual connection and post-norm
        ffn_out = self.ffn(multi_head_att_out)
        ffn_out = self.norms[1](multi_head_att_out + ffn_out)
        out = self.dropout(ffn_out)
        return out


class Encoder(nn.Module):
    def __init__(self, vocab_size, pad_idx, d_model, heads, num_layers, d_ff, max_seq_len=512, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.positional_encode = PositionEncoding(d_model, max_seq_len)
        self.encoder_layers = nn.ModuleList([EncoderLayer(heads, d_model, d_ff, dropout) for _ in range(num_layers)])

    def forward(self, x, src_mask):
        embed_x = self.embedding(x)
        pos_encode_x = self.positional_encode(embed_x)
        for layer in self.encoder_layers:
            pos_encode_x = layer(pos_encode_x, src_mask)
        return pos_encode_x


class DecoderLayer(nn.Module):
    def __init__(self, heads, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.masked_att = MultiHeadAttention(heads, d_model, dropout)
        self.att = MultiHeadAttention(heads, d_model, dropout)
        self.norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(3)])
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encode_kv, dst_mask=None, src_dst_mask=None):
        # masked self-attention over the target sequence
        masked_att_out = self.masked_att(x, x, x, dst_mask)
        masked_att_out = self.norms[0](x + masked_att_out)
        # cross-attention: queries from the decoder, keys/values from the encoder
        att_out = self.att(masked_att_out, encode_kv, encode_kv, src_dst_mask)
        att_out = self.norms[1](att_out + masked_att_out)
        ffn_out = self.ffn(att_out)
        ffn_out = self.norms[2](ffn_out + att_out)
        out = self.dropout(ffn_out)
        return out


class Decoder(nn.Module):
    def __init__(self, vocab_size, pad_idx, d_model, heads, num_layers, d_ff, max_seq_len=512, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.positional_encode = PositionEncoding(d_model, max_seq_len)
        self.decoder_layers = nn.ModuleList([DecoderLayer(heads, d_model, d_ff, dropout) for _ in range(num_layers)])

    def forward(self, x, encoder_kv, dst_mask=None, src_dst_mask=None):
        embed_x = self.embedding(x)
        pos_encode_x = self.positional_encode(embed_x)
        for layer in self.decoder_layers:
            pos_encode_x = layer(pos_encode_x, encoder_kv, dst_mask, src_dst_mask)
        return pos_encode_x


class Transformer(nn.Module):
    def __init__(self, enc_vocab_size, dec_vocab_size, pad_idx, d_model, heads, num_layers, d_ff, max_seq_len=512, dropout=0.1):
        super().__init__()
        self.encoder = Encoder(enc_vocab_size, pad_idx, d_model, heads, num_layers, d_ff, max_seq_len, dropout)
        self.decoder = Decoder(dec_vocab_size, pad_idx, d_model, heads, num_layers, d_ff, max_seq_len, dropout)
        self.linear = nn.Linear(d_model, dec_vocab_size)
        self.pad_idx = pad_idx

    def generate_mask(self, query, key, is_triu_mask=False):
        # is_triu_mask selects between a combined causal (look-ahead) + padding mask
        # for the masked self-attention and a plain padding mask
        device = query.device
        # query, key shape: [batch, seq_len]
        batch, seq_q = query.shape
        _, seq_k = key.shape
        # padding mask, broadcastable to [batch, heads, seq_q, seq_k]
        mask = (key == self.pad_idx).unsqueeze(1).unsqueeze(2)
        mask = mask.expand(batch, 1, seq_q, seq_k).to(device)
        if is_triu_mask:
            # upper-triangular mask that blocks attention to future positions
            dst_triu_mask = torch.triu(torch.ones(seq_q, seq_k, dtype=torch.bool), diagonal=1)
            dst_triu_mask = dst_triu_mask.unsqueeze(0).unsqueeze(1).expand(batch, 1, seq_q, seq_k).to(device)
            return mask | dst_triu_mask
        return mask

    def forward(self, src, dst):
        src_mask = self.generate_mask(src, src)            # padding mask for the source sequence
        encoder_out = self.encoder(src, src_mask)
        dst_mask = self.generate_mask(dst, dst, True)      # padding + causal mask for the target
        src_dst_mask = self.generate_mask(dst, src)        # padding mask for encoder-decoder attention
        decoder_out = self.decoder(dst, encoder_out, dst_mask, src_dst_mask)
        out = self.linear(decoder_out)
        return out


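# Small illustration of generate_mask (a sketch with made-up sizes and a
# hypothetical helper name): True entries are positions attention may not use.
def _demo_generate_mask():
    model = Transformer(enc_vocab_size=10, dec_vocab_size=10, pad_idx=0,
                        d_model=8, heads=2, num_layers=1, d_ff=16)
    dst = torch.tensor([[5, 7, 3, 0]])               # last token is padding (pad_idx=0)
    mask = model.generate_mask(dst, dst, is_triu_mask=True)
    assert mask.shape == (1, 1, 4, 4)
    assert bool(mask[0, 0, 0, 1])                    # causal part: no peeking at future tokens
    assert bool(mask[0, 0, 2, 3])                    # padding part: the padded key is always masked

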
if __name__ == '__main__':
    # PositionEncoding(512, 100)  # inspect the positional encoding pattern
    # att = MultiHeadAttention(8, 512, 0.2)  # check that multi-head attention keeps the expected shapes
    # x = torch.randn(4, 100, 512)
    # out = att(x, x, x)
    # print(out.shape)
    att = Transformer(100, 200, 0, 512, 8, 6, 1024, 512, 0.1)
    x = torch.randint(0, 100, (4, 64))
    y = torch.randint(0, 200, (4, 64))
    out = att(x, y)
    print(out.shape)  # expected: torch.Size([4, 64, 200])