|
1
|
|
|
import logging |
|
2
|
|
|
import argparse |
|
3
|
|
|
import e2edutch.util |
|
4
|
|
|
import e2edutch.coref_model as cm |
|
5
|
|
|
import tensorflow.compat.v1 as tf |
|
6
|
|
|
tf.disable_v2_behavior() |
|
7
|
|
|
|
|
8
|
|
|
|
|
9
|
|
|
def print_predictions(example):
    """Print every predicted coreference cluster of ``example`` to stdout.

    The nested ``example["sentences"]`` structure is flattened with
    ``e2edutch.util.flatten``. Each mention in
    ``example["predicted_clusters"]`` is indexed as ``(start, end)`` with an
    inclusive end, and is rendered as the space-joined words of that span.
    One line is printed per cluster.
    """
    tokens = e2edutch.util.flatten(example["sentences"])
    for cluster in example["predicted_clusters"]:
        mention_texts = []
        for span in cluster:
            first, last = span[0], span[1]
            # The end index is inclusive, hence the +1 on the slice bound.
            mention_texts.append(" ".join(tokens[first:last + 1]))
        print(u"Predicted cluster: {}".format(mention_texts))
|
14
|
|
|
|
|
15
|
|
|
|
|
16
|
|
|
def make_predictions(text, model, session=None):
    """Run the coreference model on ``text`` and return the annotated example.

    Args:
        text: Input passed to ``e2edutch.util.create_example``.
        model: Model object providing ``tensorize_example``,
            ``input_tensors``, ``predictions``, ``head_scores``,
            ``get_predicted_antecedents`` and ``get_predicted_clusters``.
        session: TensorFlow session used to evaluate the model. When omitted,
            the current default session is used (for example the one entered
            by an enclosing ``with tf.Session():`` block); failing that, a
            module-level ``session`` variable is used if one exists, which is
            what this function historically relied on.

    Returns:
        The example dict created from ``text``, extended with the keys
        ``"predicted_clusters"``, ``"top_spans"`` (a list of
        ``(start, end)`` int tuples) and ``"head_scores"``.

    Raises:
        RuntimeError: If no session was passed and none can be found.
    """
    if session is None:
        session = tf.get_default_session()
    if session is None:
        # Legacy fallback: older callers exposed the session as a module
        # global instead of passing it in.
        session = globals().get("session")
    if session is None:
        raise RuntimeError(
            "make_predictions needs a TensorFlow session: pass `session=` "
            "or call it inside a `with tf.Session():` block")

    example = e2edutch.util.create_example(text)
    tensorized_example = model.tensorize_example(example, is_training=False)
    feed_dict = dict(zip(model.input_tensors, tensorized_example))

    # The first three prediction outputs are not needed here; head_scores is
    # appended to the fetch list so it is evaluated in the same run.
    (_, _, _, mention_starts, mention_ends,
     antecedents, antecedent_scores, head_scores) = session.run(
        model.predictions + [model.head_scores], feed_dict=feed_dict)

    predicted_antecedents = model.get_predicted_antecedents(
        antecedents, antecedent_scores)

    example["predicted_clusters"], _ = model.get_predicted_clusters(
        mention_starts, mention_ends, predicted_antecedents)
    # Materialise the spans: a bare zip() is a one-shot iterator on Python 3,
    # so it could be consumed only once and could not be serialised.
    example["top_spans"] = [
        (int(start), int(end))
        for start, end in zip(mention_starts, mention_ends)
    ]
    example["head_scores"] = head_scores.tolist()
    return example
|
32
|
|
|
|
|
33
|
|
|
|
|
34
|
|
|
def get_parser():
    """Build the command-line parser for the prediction script.

    Positional arguments:
        config: Passed as-is (string).
        input_file: Path opened for reading via ``argparse.FileType('r')``.

    Optional arguments:
        --cfg_file: String, defaults to ``None``.
        --model_cfg_file: String, defaults to ``None``.
        -v / --verbose: Boolean flag, off by default.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('config')
    cli.add_argument('input_file', type=argparse.FileType('r'))

    # Both optional config overrides share the same shape; only the flag
    # name and the help text differ.
    optional_paths = (
        ('--cfg_file', "config file"),
        ('--model_cfg_file', "model config file"),
    )
    for flag, description in optional_paths:
        cli.add_argument(flag, type=str, default=None, help=description)

    cli.add_argument('-v', '--verbose', action='store_true')
    return cli
|
49
|
|
|
|
|
50
|
|
|
|
|
51
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command line, load the model, then
    # predict and print coreference clusters for the given input file.
    args = get_parser().parse_args()
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    config = e2edutch.util.initialize_from_env(
        args.config, args.cfg_file, args.model_cfg_file)
    model = cm.CorefModel(config)

    # The session stays open, bound to the module-level name `session`,
    # for the whole prediction run.
    with tf.Session() as session:
        model.restore(session)
        text = args.input_file.read()
        # Nothing to predict for an empty input file.
        if text:
            example = make_predictions(text, model)
            print_predictions(example)
|
63
|
|
|
|