# Train a character-level RNN language model on Penn Treebank text with deepy.
import os
import logging
from argparse import ArgumentParser

from vocab import Vocab
from lmdataset import LMDataset
from lm import NeuralLM
from deepy.dataset import SequentialMiniBatches
from deepy.trainers import SGDTrainer, LearningRateAnnealer
from deepy.layers import RNN, Dense


logging.basicConfig(level=logging.INFO)

resource_dir = os.path.abspath(os.path.dirname(__file__)) + os.sep + "resources"
vocab_path = os.path.join(resource_dir, "ptb.train.txt")
train_path = os.path.join(resource_dir, "ptb.train.txt")
valid_path = os.path.join(resource_dir, "ptb.valid.txt")
vocab = Vocab(char_based=True)
vocab.load(vocab_path, max_size=1000)
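# Two stacked recurrent layers that output a state at every time step,
# followed by a softmax over the vocabulary. input_tensor=3 requests a
# 3-D input tensor (batch, time, feature).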
model = NeuralLM(input_dim=vocab.size, input_tensor=3)
model.stack(
    RNN(hidden_size=100, output_type="sequence"),
    RNN(hidden_size=100, output_type="sequence"),
    Dense(vocab.size, "softmax"))


if __name__ == '__main__':
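    # Command-line options: where to load/save model parameters, and an
    # optional --sample argument (unused in the training path shown here).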
    ap = ArgumentParser()
    ap.add_argument("--model", default=os.path.join(os.path.dirname(__file__), "models", "char_rnn_model1.gz"))
    ap.add_argument("--sample", default="")
    args = ap.parse_args()
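    # Warm-start from a previously saved model if the checkpoint exists.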
    if os.path.exists(args.model):
        model.load_params(args.model)
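    # Character-based dataset over the PTB files with a 30-step history
    # window, served as sequential mini-batches of 20 sequences each.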
    lmdata = LMDataset(vocab, train_path, valid_path, history_len=30, char_based=True, max_tokens=300)
    batch = SequentialMiniBatches(lmdata, batch_size=20)
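    # Plain SGD training with a learning-rate annealing controller.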
    trainer = SGDTrainer(model)
    annealer = LearningRateAnnealer()
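    # Run the training loop with the annealer attached as a controller.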
    trainer.run(batch, controllers=[annealer])
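    # Persist the trained parameters to the --model path.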
    model.save_params(args.model)