1 | import os |
||
2 | import logging |
||
3 | from argparse import ArgumentParser |
||
4 | |||
5 | from vocab import Vocab |
||
6 | from lmdataset import LMDataset |
||
7 | from lm import NeuralLM |
||
8 | from deepy.dataset import SequentialMiniBatches |
||
9 | from deepy.trainers import SGDTrainer, LearningRateAnnealer |
||
10 | from deepy.layers import IRNN, Dense |
||
11 | |||
12 | |||
logging.basicConfig(level=logging.INFO)

# Directory holding the PTB corpus files, next to this script.
# Fixed: build the path with os.path.join instead of manual
# string concatenation with os.sep.
resource_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), "resources")

# NOTE(review): vocab_path and train_path deliberately point at the same
# file — the vocabulary is built from the training corpus itself.
vocab_path = os.path.join(resource_dir, "ptb.train.txt")
train_path = os.path.join(resource_dir, "ptb.train.txt")
valid_path = os.path.join(resource_dir, "ptb.valid.txt")

# Character-level vocabulary, capped at the 1000 most frequent entries.
vocab = Vocab(char_based=True)
vocab.load(vocab_path, max_size=1000)

# Character-level language model: an IRNN recurrent layer producing a
# full output sequence, followed by a softmax over the vocabulary.
model = NeuralLM(input_dim=vocab.size, input_tensor=3)
model.stack(
    IRNN(hidden_size=100, output_type="sequence"),
    Dense(vocab.size, "softmax"))
28 | |||
if __name__ == '__main__':
    # Command-line interface: where to load/save model parameters, and an
    # (unused here) sampling prefix option.
    parser = ArgumentParser()
    parser.add_argument("--model", default=os.path.join(os.path.dirname(__file__), "models", "char_irnn_model1.gz"))
    parser.add_argument("--sample", default="")
    cli_args = parser.parse_args()

    # Resume training from a previously saved parameter file when present.
    if os.path.exists(cli_args.model):
        model.load_params(cli_args.model)

    # Character-level PTB dataset with a 30-step history window, served in
    # sequential mini-batches of 20.
    dataset = LMDataset(vocab, train_path, valid_path, history_len=30, char_based=True, max_tokens=300)
    minibatches = SequentialMiniBatches(dataset, batch_size=20)

    # Plain SGD with learning-rate annealing applied at each epoch.
    sgd = SGDTrainer(model)
    lr_annealer = LearningRateAnnealer()

    sgd.run(minibatches, epoch_controllers=[lr_annealer])

    # Persist the trained parameters for later reuse/sampling.
    model.save_params(cli_args.model)