1
|
|
|
import logging |
2
|
|
|
import argparse |
3
|
|
|
import e2edutch.util |
4
|
|
|
import e2edutch.coref_model as cm |
5
|
|
|
import tensorflow.compat.v1 as tf |
6
|
|
|
tf.disable_v2_behavior() |
7
|
|
|
|
8
|
|
|
|
9
|
|
|
def print_predictions(example):
    """Print every predicted cluster of *example* as a list of mention strings.

    The words of ``example["sentences"]`` are flattened with
    ``e2edutch.util.flatten``; each mention ``m`` in a cluster of
    ``example["predicted_clusters"]`` is rendered as the words from index
    ``m[0]`` up to and including index ``m[1]``, joined by single spaces.
    One line is printed per cluster.
    """
    tokens = e2edutch.util.flatten(example["sentences"])
    for cluster in example["predicted_clusters"]:
        mention_texts = []
        for mention in cluster:
            first, last = mention[0], mention[1]
            # The end index is inclusive, hence the +1 in the slice.
            mention_texts.append(" ".join(tokens[first:last + 1]))
        print(u"Predicted cluster: {}".format(mention_texts))
14
|
|
|
|
15
|
|
|
|
16
|
|
|
def make_predictions(text, model, session=None):
    """Run the coreference model on *text* and return the annotated example.

    Args:
        text: Input text; it is turned into an example with
            ``e2edutch.util.create_example``.
        model: Model object. This function uses its ``tensorize_example``,
            ``input_tensors``, ``predictions``, ``head_scores``,
            ``get_predicted_antecedents`` and ``get_predicted_clusters``
            members.
        session: TensorFlow session used to evaluate the model. When omitted,
            the current default session (e.g. the one entered with
            ``with tf.Session() as session:``) is used, so existing
            two-argument calls made inside such a block keep working.

    Returns:
        The example dict extended with three keys: ``"predicted_clusters"``,
        ``"top_spans"`` (a list of ``(start, end)`` int tuples) and
        ``"head_scores"`` (the result of ``head_scores.tolist()``).

    Raises:
        RuntimeError: If no session is passed and no default session is
            active.
    """
    # Previously this function silently depended on a module-level global
    # named `session`, which only exists when the file is run as a script.
    if session is None:
        session = tf.get_default_session()
    if session is None:
        raise RuntimeError(
            "make_predictions needs a TensorFlow session: pass `session` or "
            "call it inside a `with tf.Session():` block")

    example = e2edutch.util.create_example(text)
    tensorized_example = model.tensorize_example(example, is_training=False)
    feed_dict = dict(zip(model.input_tensors, tensorized_example))

    # The first three prediction outputs are not needed here; head_scores is
    # appended to the fetch list and therefore comes back last.
    (_, _, _, mention_starts, mention_ends, antecedents, antecedent_scores,
     head_scores) = session.run(
        model.predictions + [model.head_scores], feed_dict=feed_dict)

    predicted_antecedents = model.get_predicted_antecedents(
        antecedents, antecedent_scores)

    example["predicted_clusters"], _ = model.get_predicted_clusters(
        mention_starts, mention_ends, predicted_antecedents)
    # Materialize the pairs: a bare zip() is a one-shot iterator on Python 3,
    # which would be empty after the first traversal and is not serializable.
    example["top_spans"] = [
        (int(start), int(end))
        for start, end in zip(mention_starts, mention_ends)]
    # presumably head_scores is a NumPy array (it exposes .tolist()) -- confirm
    example["head_scores"] = head_scores.tolist()
    return example
32
|
|
|
|
33
|
|
|
|
34
|
|
|
def get_parser():
    """Build the command-line parser for this script.

    Positional arguments:
        config: required string.
        input_file: required path, opened for reading by argparse.

    Options:
        --cfg_file: string, default ``None`` ("config file").
        --model_cfg_file: string, default ``None`` ("model config file").
        -v / --verbose: boolean flag, ``False`` unless given.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('config')
    # input_file has no default, so it must always be supplied;
    # argparse.FileType maps the pseudo-argument "-" to standard input.
    cli.add_argument('input_file', type=argparse.FileType('r'))

    # The two optional config paths share everything except name and help.
    optional_paths = (
        ('--cfg_file', "config file"),
        ('--model_cfg_file', "model config file"),
    )
    for flag, description in optional_paths:
        cli.add_argument(flag, type=str, default=None, help=description)

    cli.add_argument('-v', '--verbose', action='store_true')
    return cli
49
|
|
|
|
50
|
|
|
|
51
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command line, build and restore the
    # model, then print the predicted clusters for the input text.
    args = get_parser().parse_args()
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
    config = e2edutch.util.initialize_from_env(
        args.config, args.cfg_file, args.model_cfg_file)
    model = cm.CorefModel(config)
    # The session is bound to the module-level name `session` and is also
    # TensorFlow's default session for the duration of this block; do not
    # rename it without checking the functions called below.
    with tf.Session() as session:
        model.restore(session)
        # argparse opened the input file; close it as soon as it is read
        # instead of leaving the handle open until interpreter exit.
        with args.input_file as input_file:
            text = input_file.read()
        # Nothing to predict for empty input.
        if text:
            print_predictions(make_predictions(text, model))
63
|
|
|
|