1
|
|
|
#!/usr/bin/env python |
|
|
|
|
2
|
|
|
import os |
3
|
|
|
import time |
4
|
|
|
import argparse |
5
|
|
|
import logging |
6
|
|
|
import tensorflow.compat.v1 as tf |
|
|
|
|
7
|
|
|
tf.disable_v2_behavior() |
8
|
|
|
|
9
|
|
|
from e2edutch import util |
|
|
|
|
10
|
|
|
from e2edutch import coref_model as cm |
|
|
|
|
11
|
|
|
|
12
|
|
|
|
13
|
|
|
def get_parser():
    """Build the command-line argument parser for the training script.

    The parser accepts one positional argument, ``config``, the optional
    string options ``--train``, ``--eval``, ``--eval_conll``, ``--cfg_file``
    and ``--model_cfg_file`` (each defaulting to ``None``), and a
    ``-v``/``--verbose`` switch.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    # All optional path arguments share the same type and default, so they
    # are listed as (flag, help text) pairs and registered in order.
    path_options = (
        ('--train', "jsonlines file used for training"),
        ('--eval', "jsonlines file used for evaluating"),
        ('--eval_conll', "conll file used for evaluating"),
        ('--cfg_file', "config file"),
        ('--model_cfg_file', "model config file"),
    )

    cli = argparse.ArgumentParser()
    cli.add_argument('config')
    for flag, description in path_options:
        cli.add_argument(flag, type=str, default=None, help=description)
    cli.add_argument('-v', '--verbose', action='store_true')
    return cli
38
|
|
|
|
39
|
|
|
|
40
|
|
|
def main(args=None):
    """Train the coreference model, reporting, evaluating and checkpointing.

    Parses the command line, loads the configuration through
    ``util.initialize_from_env``, builds a ``cm.CorefModel`` and then runs
    training steps in an endless loop. Every ``report_frequency`` global
    steps the average loss is printed and written as a summary; every
    ``eval_frequency`` global steps a checkpoint is saved, the model is
    evaluated, and the checkpoint is copied to ``model.max.ckpt`` when the
    evaluation F1 exceeds the best value seen in this run.

    The loop has no exit condition; it ends only when an exception (for
    example ``KeyboardInterrupt``) propagates out of it.

    Args:
        args: optional list of command-line argument strings. When ``None``
            argparse reads ``sys.argv[1:]``.
    """
    # Forward ``args`` so programmatic callers can supply their own argv;
    # parse_args(None) keeps the previous sys.argv behaviour.
    args = get_parser().parse_args(args)
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
    config = util.initialize_from_env(args.config,
                                      args.cfg_file,
                                      args.model_cfg_file)
    # Overwrite train and eval file if specified
    if args.train is not None:
        config['train_path'] = args.train
    if args.eval is not None:
        config['eval_path'] = args.eval
    if args.eval_conll is not None:
        config['conll_eval_path'] = args.eval_conll

    report_frequency = config["report_frequency"]
    eval_frequency = config["eval_frequency"]

    model = cm.CorefModel(config)
    saver = tf.train.Saver()

    log_dir = os.path.join(config['log_root'], config['log_dir'])
    writer = tf.summary.FileWriter(log_dir, flush_secs=20)

    # Best evaluation F1 seen during this run (not restored from disk).
    max_f1 = 0

    try:
        with tf.Session() as session:
            session.run(tf.global_variables_initializer())
            model.start_enqueue_thread(session)
            accumulated_loss = 0.0

            # Resume from the latest checkpoint in log_dir when one exists.
            ckpt = tf.train.get_checkpoint_state(log_dir)
            if ckpt and ckpt.model_checkpoint_path:
                print("Restoring from: {}".format(ckpt.model_checkpoint_path))
                saver.restore(session, ckpt.model_checkpoint_path)

            initial_time = time.time()
            while True:
                tf_loss, tf_global_step, _ = session.run(
                    [model.loss, model.global_step, model.train_op])
                accumulated_loss += tf_loss

                if tf_global_step % report_frequency == 0:
                    total_time = time.time() - initial_time
                    # NOTE(review): tf_global_step may already be non-zero
                    # after restoring a checkpoint while total_time only
                    # covers this run, so steps/s could be overstated after
                    # a resume -- confirm whether that is intended.
                    steps_per_second = tf_global_step / total_time

                    average_loss = accumulated_loss / report_frequency
                    print("[{}] loss={:.2f}, steps/s={:.2f}"
                          .format(tf_global_step,
                                  average_loss,
                                  steps_per_second))
                    writer.add_summary(util.make_summary(
                        {"loss": average_loss}), tf_global_step)
                    accumulated_loss = 0.0

                if tf_global_step % eval_frequency == 0:
                    saver.save(session, os.path.join(log_dir, "model"),
                               global_step=tf_global_step)
                    eval_summary, eval_f1 = model.evaluate(session)

                    if eval_f1 > max_f1:
                        max_f1 = eval_f1
                        # Keep a copy of the best checkpoint so far under a
                        # fixed name.
                        util.copy_checkpoint(
                            os.path.join(
                                log_dir,
                                "model-{}".format(tf_global_step)),
                            os.path.join(log_dir, "model.max.ckpt"))

                    writer.add_summary(eval_summary, tf_global_step)
                    writer.add_summary(util.make_summary(
                        {"max_eval_f1": max_f1}), tf_global_step)

                    print("[{}] eval_f1={:.2f}, max_f1={:.2f}".format(
                        tf_global_step, eval_f1, max_f1))
    finally:
        # The loop only ends via an exception/interrupt; make sure pending
        # summaries are flushed and the event file is closed either way.
        writer.close()
112
|
|
|
|
113
|
|
|
|
114
|
|
|
# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
116
|
|
|
|