e2edutch.train   (rating: A)

Complexity

Total Complexity 13

Size/Duplication

Total Lines 116
Duplicated Lines 0 %

Importance

Changes 0

Metric                              Value
wmc  (weighted method count)        13
eloc (executable lines of code)     94
dl   (duplicated lines)             0
loc  (lines of code)                116
rs                                  10
c                                   0
b                                   0
f                                   0

2 Functions

Rating   Name           Duplication   Size   Complexity
A        get_parser()   0             25     1
D        main()         0             72     12
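get_parser() below defines the command-line interface of the training script: a positional configuration name plus optional overrides for the training/evaluation data files and for the config files. Assuming the package is installed so the module can be run directly, a hypothetical invocation could look like the following (the config name and file paths are placeholders, not taken from this report):

    python -m e2edutch.train my_config \
        --train train.dutch.jsonlines \
        --eval dev.dutch.jsonlines \
        --eval_conll dev.dutch.conll \
        --verbose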
#!/usr/bin/env python
# issue: Missing module docstring

import os
import time
import argparse
import logging
import tensorflow.compat.v1 as tf
# issue: Unable to import 'tensorflow.compat.v1'

tf.disable_v2_behavior()

from e2edutch import util
# issue: Import "from e2edutch import util" should be placed at the top of the module
from e2edutch import coref_model as cm
# issue: Import "from e2edutch import coref_model as cm" should be placed at the top of the module


def get_parser():
    # issue: Missing function or method docstring
    parser = argparse.ArgumentParser()
    parser = argparse.ArgumentParser()
    parser.add_argument('config')
    parser.add_argument('--train',
                        type=str,
                        default=None,
                        help="jsonlines file used for training")
    parser.add_argument('--eval',
                        type=str,
                        default=None,
                        help="jsonlines file used for evaluating")
    parser.add_argument('--eval_conll',
                        type=str,
                        default=None,
                        help="conll file used for evaluating")
    parser.add_argument('--cfg_file',
                        type=str,
                        default=None,
                        help="config file")
    parser.add_argument('--model_cfg_file',
                        type=str,
                        default=None,
                        help="model config file")
    parser.add_argument('-v', '--verbose', action='store_true')
    return parser


def main(args=None):
    # issue: Missing function or method docstring
    # issue: Comprehensibility - this function exceeds the maximum number of variables (21/15)
    args = get_parser().parse_args()
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
    config = util.initialize_from_env(args.config,
                                      args.cfg_file,
                                      args.model_cfg_file)
    # Overwrite train and eval file if specified
    if args.train is not None:
        config['train_path'] = args.train
    if args.eval is not None:
        config['eval_path'] = args.eval
    if args.eval_conll is not None:
        config['conll_eval_path'] = args.eval_conll

    report_frequency = config["report_frequency"]
    eval_frequency = config["eval_frequency"]

    model = cm.CorefModel(config)
    saver = tf.train.Saver()

    log_dir = os.path.join(config['log_root'], config['log_dir'])
    writer = tf.summary.FileWriter(log_dir, flush_secs=20)

    max_f1 = 0

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        model.start_enqueue_thread(session)
        accumulated_loss = 0.0

        ckpt = tf.train.get_checkpoint_state(log_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Restoring from: {}".format(ckpt.model_checkpoint_path))
            saver.restore(session, ckpt.model_checkpoint_path)

        initial_time = time.time()
        while True:
            tf_loss, tf_global_step, _ = session.run(
                [model.loss, model.global_step, model.train_op])
            accumulated_loss += tf_loss

            # Report average loss and throughput every report_frequency steps
            if tf_global_step % report_frequency == 0:
                total_time = time.time() - initial_time
                steps_per_second = tf_global_step / total_time

                average_loss = accumulated_loss / report_frequency
                print("[{}] loss={:.2f}, steps/s={:.2f}"
                      .format(tf_global_step,
                              average_loss,
                              steps_per_second))
                writer.add_summary(util.make_summary(
                    {"loss": average_loss}), tf_global_step)
                accumulated_loss = 0.0

            # Periodically checkpoint, evaluate, and keep a copy of the best model by F1
            if tf_global_step % eval_frequency == 0:
                saver.save(session, os.path.join(log_dir, "model"),
                           global_step=tf_global_step)
                eval_summary, eval_f1 = model.evaluate(session)

                if eval_f1 > max_f1:
                    max_f1 = eval_f1
                    util.copy_checkpoint(os.path.join(
                        log_dir, "model-{}".format(tf_global_step)),
                                         os.path.join(log_dir, "model.max.ckpt"))

                writer.add_summary(eval_summary, tf_global_step)
                writer.add_summary(util.make_summary(
                    {"max_eval_f1": max_f1}), tf_global_step)

                print("[{}] eval_f1={:.2f}, max_f1={:.2f}".format(
                    tf_global_step, eval_f1, max_f1))


if __name__ == "__main__":
    main()
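main() reads a handful of keys from the configuration object returned by util.initialize_from_env(). The sketch below lists only the keys referenced directly in this file, written as a plain Python dict purely for illustration; the actual configuration format and any additional settings required by cm.CorefModel are defined elsewhere in the project, and every value shown is an invented placeholder:

    config = {
        "train_path": "train.dutch.jsonlines",   # overridden by --train
        "eval_path": "dev.dutch.jsonlines",      # overridden by --eval
        "conll_eval_path": "dev.dutch.conll",    # overridden by --eval_conll
        "report_frequency": 100,                 # steps between loss reports
        "eval_frequency": 1000,                  # steps between checkpoint + evaluation
        "log_root": "logs",                      # base directory for logs and checkpoints
        "log_dir": "my_config",                  # run-specific subdirectory
    }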