#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Train an LSTM recurrent language model (RNNLM) with the deepy framework."""

import os
import logging
from argparse import ArgumentParser

from utils import load_data
from lm import NeuralLM
from deepy.trainers import SGDTrainer, LearningRateAnnealer, AdamTrainer
from deepy.layers import LSTM
from layers import FullOutputLayer


logging.basicConfig(level=logging.INFO)

# Default path for the trained model parameters (gzip-pickled by deepy),
# resolved relative to this script's directory.
default_model = os.path.join(os.path.dirname(__file__), "models", "lstm_rnnlm.gz")
18 | |||
if __name__ == '__main__':
    # Command line:
    #   --model PATH  load parameters from PATH if it exists, and save back to
    #                 it after training (falls back to `default_model`)
    #   --small       train on the small data set
    ap = ArgumentParser()
    ap.add_argument("--model", default="",
                    help="path to load and save model parameters")
    ap.add_argument("--small", action="store_true",
                    help="train on the small data set")
    args = ap.parse_args()

    # Build the batched corpus iterator and the LSTM language model.
    # persistent_state carries the LSTM state across batches; the state is
    # reset whenever input token 0 is seen (reset_state_for_input=0).
    vocab, lmdata = load_data(small=args.small, history_len=5, batch_size=64)
    model = NeuralLM(vocab.size)
    model.stack(LSTM(hidden_size=100, output_type="sequence",
                     persistent_state=True, batch_size=lmdata.size,
                     reset_state_for_input=0),
                FullOutputLayer(vocab.size))

    # Resume from previously saved parameters when the file exists
    # (os.path.exists("") is False, so the empty default skips this).
    if os.path.exists(args.model):
        model.load_params(args.model)

    # Plain SGD with a light L2 penalty; the annealer lowers the learning
    # rate over the course of training.
    trainer = SGDTrainer(model, {"learning_rate": LearningRateAnnealer.learning_rate(1.2),
                                 "weight_l2": 1e-7})
    annealer = LearningRateAnnealer(trainer)

    trainer.run(lmdata, controllers=[annealer])

    # Fix: honor the user-supplied --model path on save as well as on load.
    # Previously the trained parameters were always written to default_model,
    # silently discarding the resume path the user asked for.
    model.save_params(args.model if args.model else default_model)