demo.make_predictions()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 16
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 14
dl 0
loc 16
rs 9.7
c 0
b 0
f 0
cc 1
nop 2
1
import logging
2
import argparse
3
import e2edutch.util
4
import e2edutch.coref_model as cm
5
import tensorflow.compat.v1 as tf
6
tf.disable_v2_behavior()
7
8
9
def print_predictions(example):
    """Print each predicted coreference cluster of ``example`` on its own line.

    The words of all sentences in ``example["sentences"]`` are flattened into
    one sequence. Every mention in a cluster is then rendered as the words
    from index ``mention[0]`` up to and including index ``mention[1]``.
    """
    flat_words = e2edutch.util.flatten(example["sentences"])

    def mention_text(mention):
        # The end index is inclusive, hence the ``+ 1`` in the slice.
        first, last = mention[0], mention[1]
        return " ".join(flat_words[first:last + 1])

    for predicted_cluster in example["predicted_clusters"]:
        rendered = list(map(mention_text, predicted_cluster))
        print(u"Predicted cluster: {}".format(rendered))
14
15
16
def make_predictions(text, model, session=None):
    """Run the coreference model on raw text and attach its predictions.

    Args:
        text: Raw input text, turned into a model example by
            ``e2edutch.util.create_example``.
        model: The coreference model (the script passes a
            ``cm.CorefModel`` whose weights were restored into the session).
        session: The ``tf.Session`` to run the model in. When omitted, the
            current default session is used, e.g. the one entered with
            ``with tf.Session() as session:``.

    Returns:
        The example dict created from ``text``, extended with:
        ``"predicted_clusters"`` (from ``model.get_predicted_clusters``),
        ``"top_spans"`` (list of ``(start, end)`` int tuples, one per
        candidate mention), and ``"head_scores"`` (converted with
        ``.tolist()``).

    Raises:
        ValueError: if no session is given and no default session is active.
    """
    if session is None:
        # Previously this function silently depended on a module-global
        # ``session`` that only exists when the file runs as a script.
        session = tf.get_default_session()
        if session is None:
            raise ValueError(
                "make_predictions() needs a tf.Session: pass one explicitly "
                "or call it inside a `with tf.Session() as session:` block.")

    example = e2edutch.util.create_example(text)
    tensorized_example = model.tensorize_example(example, is_training=False)
    feed_dict = dict(zip(model.input_tensors, tensorized_example))

    # The first three outputs of model.predictions are not needed here.
    (_, _, _, mention_starts, mention_ends, antecedents, antecedent_scores,
     head_scores) = session.run(model.predictions + [model.head_scores],
                                feed_dict=feed_dict)

    predicted_antecedents = model.get_predicted_antecedents(
        antecedents, antecedent_scores)

    example["predicted_clusters"], _ = model.get_predicted_clusters(
        mention_starts, mention_ends, predicted_antecedents)
    # Materialize as a list: a bare zip() iterator can be consumed only once
    # and cannot be serialized.
    example["top_spans"] = [
        (int(start), int(end))
        for start, end in zip(mention_starts, mention_ends)]
    example["head_scores"] = head_scores.tolist()
    return example
32
33
34
def get_parser():
    """Build the command-line argument parser for this prediction script.

    Arguments:
        config: configuration identifier, passed as the first argument to
            ``e2edutch.util.initialize_from_env``.
        input_file: text file to run predictions on, opened for reading by
            ``argparse.FileType``; ``-`` reads from standard input.
        --cfg_file / --model_cfg_file: optional config file paths
            (default ``None``).
        -v / --verbose: flag; the script enables DEBUG logging when set.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config',
                        help="model configuration "
                             "(passed to e2edutch.util.initialize_from_env)")
    # argparse.FileType already maps '-' to sys.stdin, so no explicit
    # stdin default is needed.
    parser.add_argument('input_file', type=argparse.FileType('r'),
                        help="text file to process ('-' for standard input)")
    parser.add_argument('--cfg_file',
                        type=str,
                        default=None,
                        help="config file")
    parser.add_argument('--model_cfg_file',
                        type=str,
                        default=None,
                        help="model config file")
    parser.add_argument('-v', '--verbose', action='store_true',
                        help="enable debug logging")
    return parser
49
50
51
if __name__ == "__main__":
    # Script entry point: parse CLI args, build the model from the config,
    # restore its weights, and print predicted clusters for the input text.
    args = get_parser().parse_args()
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
    config = e2edutch.util.initialize_from_env(
        args.config, args.cfg_file, args.model_cfg_file)
    model = cm.CorefModel(config)
    # NOTE: keep the module-level name ``session`` and the ``with`` form:
    # make_predictions() locates the session through one of them.
    with tf.Session() as session:
        model.restore(session)
        # argparse.FileType leaves the file open; close it once read.
        with args.input_file as input_file:
            text = input_file.read()
        # Nothing to predict for empty input.
        if text:
            print_predictions(make_predictions(text, model))
63