@@ 110-123 (lines=14) @@ | ||
107 | return out |
|
108 | ||
109 | ||
class Encoder(nn.Module):
    """Transformer encoder: token embedding + positional encoding, then a
    stack of self-attention encoder layers.

    Args:
        vocab_size: size of the source vocabulary.
        pad_idx: embedding index treated as padding (its vector is zeroed
            and receives no gradient updates).
        d_model: model/embedding dimension.
        heads: number of attention heads per layer.
        num_layers: number of stacked ``EncoderLayer`` blocks.
        d_ff: hidden size of each layer's feed-forward sublayer.
        max_seq_len: maximum sequence length supported by the positional
            encoding table.
        dropout: dropout rate forwarded to each encoder layer.
    """

    def __init__(self, vocab_size, pad_idx, d_model, heads, num_layers, d_ff, max_seq_len=512, dropout=0.1):
        super().__init__()
        # padding_idx passed by keyword: the bare third positional argument
        # was correct but easy to misread.
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.positional_encode = PositionEncoding(d_model, max_seq_len)
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(heads, d_model, d_ff, dropout) for _ in range(num_layers)]
        )

    def forward(self, x, src_mask=None):
        """Encode token ids ``x`` into contextual representations.

        ``src_mask`` now defaults to ``None`` for consistency with
        ``Decoder.forward``; existing positional callers are unaffected.
        Assumes ``x`` holds integer token ids — shape presumably
        (batch, seq); TODO confirm against callers.
        """
        embed_x = self.embedding(x)
        pos_encode_x = self.positional_encode(embed_x)
        for layer in self.encoder_layers:
            pos_encode_x = layer(pos_encode_x, src_mask)
        return pos_encode_x
|
124 | ||
125 | ||
126 | ||
@@ 152-164 (lines=13) @@ | ||
149 | ||
150 | ||
151 | ||
class Decoder(nn.Module):
    """Transformer decoder stack.

    Embeds target token ids, adds positional information, and applies
    ``num_layers`` decoder blocks that attend over both the target
    sequence and the encoder output ``encoder_kv``.
    """

    def __init__(self, vocab_size, pad_idx, d_model, heads, num_layers, d_ff, max_seq_len=512, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model, pad_idx)
        self.positional_encode = PositionEncoding(d_model, max_seq_len)
        blocks = [DecoderLayer(heads, d_model, d_ff, dropout) for _ in range(num_layers)]
        self.decoder_layers = nn.ModuleList(blocks)

    def forward(self, x, encoder_kv, dst_mask=None, src_dst_mask=None):
        """Run target ids ``x`` through the decoder, attending to ``encoder_kv``.

        ``dst_mask`` masks target self-attention; ``src_dst_mask`` masks
        cross-attention over the encoder output. Both are optional.
        """
        hidden = self.positional_encode(self.embedding(x))
        for block in self.decoder_layers:
            hidden = block(hidden, encoder_kv, dst_mask, src_dst_mask)
        return hidden
|
165 | ||
166 | ||
167 |