Completed
Push — master (fc9889...278673)
by Dafne van
18s, queued 16s, created

e2edutch.predict   A

Complexity

Total Complexity 19

Size/Duplication

Total Lines 161
Duplicated Lines 0%

Importance

Changes 0
Metric   Value
wmc      19
eloc     116
dl       0
loc      161
rs       10
c        0
b        0
f        0

3 Methods

Rating  Name                     Duplication  Size  Complexity
A       Predictor.__init__()     0            5     1
A       Predictor.predict()      0            28    1
A       Predictor.end_session()  0            3     1

3 Functions

Rating  Name              Duplication  Size  Complexity
A       get_parser()      0            15    1
A       read_jsonlines()  0            4     2
D       main()            0            75    13
  1  import sys
       ⚠ Missing module docstring
  2  import json
  3  import os
  4  import io
  5  import collections
  6  import argparse
  7  import logging
  8
  9  from e2edutch import conll
 10  from e2edutch import minimize
 11  from e2edutch import util
 12  from e2edutch import coref_model as cm
 13  from e2edutch import naf
 14
 15  import tensorflow.compat.v1 as tf
       ⚠ Unable to import 'tensorflow.compat.v1'
       ⚠ third party import "import tensorflow.compat.v1 as tf" should be placed before "from e2edutch import conll"
 16
 17
 18  class Predictor(object):
       ⚠ Missing class docstring
       ⚠ Class 'Predictor' inherits from object, can be safely removed from bases in python3
 19      def __init__(self, model_name='final', cfg_file=None):
 20          self.config = util.initialize_from_env(model_name, cfg_file)
 21          self.session = tf.compat.v1.Session()
 22          self.model = cm.CorefModel(self.config)
 23          self.model.restore(self.session)
 24
 25      def predict(self, example):
 26          """
 27          Predict coreference spans for a tokenized text.
 28
 29
 30          Args:
 31              example (dict): dict with the following fields:
 32                                sentences ([[str]])
 33                                doc_id (str)
 34                                clusters ([[(int, int)]]) (optional)
 35
 36          Returns:
 37              [[(int, int)]]: a list of clusters. The items of the cluster are
 38                              spans, denoted by their start and end token index
 39
 40          """
 41          tensorized_example = self.model.tensorize_example(
 42              example, is_training=False)
 43          feed_dict = {i: t for i, t in zip(
       ⚠ Unused Code: Unnecessary use of a comprehension
 44              self.model.input_tensors, tensorized_example)}
 45          _, _, _, top_span_starts, top_span_ends, top_antecedents, top_antecedent_scores = self.session.run(
       ⚠ Coding Style: This line is too long as per the coding-style (107/100).
 46              self.model.predictions, feed_dict=feed_dict)
 47          predicted_antecedents = self.model.get_predicted_antecedents(
 48              top_antecedents, top_antecedent_scores)
 49          predicted_clusters, _ = self.model.get_predicted_clusters(
 50              top_span_starts, top_span_ends, predicted_antecedents)
 51
 52          return predicted_clusters
 53
 54      def end_session(self):
       ⚠ Missing function or method docstring
 55          self.session.close()
 56          tf.reset_default_graph()
 57
 58
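Taken together, the three A-rated methods form a small load/predict/close lifecycle. A minimal usage sketch, assuming a trained 'final' model sits wherever util.initialize_from_env looks for it; the example fields follow the predict() docstring, 'doc_1' is a hypothetical id, and the model may well require more fields than the docstring lists:

    from e2edutch.predict import Predictor

    predictor = Predictor(model_name='final')   # restores config, session and weights
    example = {
        'doc_id': 'doc_1',
        'sentences': [['Jan', 'ziet', 'Piet', '.'],
                      ['Hij', 'zwaait', '.']],
    }
    clusters = predictor.predict(example)       # [[(start, end), ...], ...]
    predictor.end_session()                     # close the session, reset the graph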
 59  def get_parser():
       ⚠ Missing function or method docstring
 60      parser = argparse.ArgumentParser()
 61      parser.add_argument('config')
 62      parser.add_argument('input_filename')
 63      parser.add_argument('-o', '--output_file',
 64                          type=argparse.FileType('w'), default=sys.stdout)
 65      parser.add_argument('-f', '--format_out', default='conll',
 66                          choices=['conll', 'jsonlines', 'naf'])
 67      parser.add_argument('-c', '--word_col', type=int, default=2)
 68      parser.add_argument('--cfg_file',
 69                          type=str,
 70                          default=None,
 71                          help="config file")
 72      parser.add_argument('-v', '--verbose', action='store_true')
 73      return parser
 74
 75
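get_parser() defines two positional arguments (config, input_filename) plus output, format, word-column, config-file and verbosity options. A quick sketch of how a command line resolves; the argument values here are hypothetical:

    args = get_parser().parse_args(
        ['final', 'doc.conll', '-f', 'jsonlines', '-v'])
    print(args.config, args.input_filename)  # final doc.conll
    print(args.format_out, args.verbose)     # jsonlines True
    print(args.word_col)                     # 2 (the default)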
 76  def read_jsonlines(input_filename):
       ⚠ Missing function or method docstring
 77      for line in open(input_filename).readlines():
 78          example = json.loads(line)
 79          yield example
 80
 81
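Beyond the missing docstring flagged above, read_jsonlines() never closes the file it opens, and readlines() reads everything eagerly. A sketch of an equivalent generator, same behaviour assumed:

    def read_jsonlines(input_filename):
        """Yield one JSON-decoded example per line of a .jsonlines file."""
        with open(input_filename) as f:  # closed once the generator is exhausted
            for line in f:               # iterate lazily instead of readlines()
                yield json.loads(line)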
 82  def main(args=None):
       ⚠ Missing function or method docstring
       ⚠ Comprehensibility: This function exceeds the maximum number of variables (21/15).
 83      parser = get_parser()
 84      args = parser.parse_args()
 85      if args.verbose:
 86          logging.basicConfig(level=logging.DEBUG)
 87      # config = util.initialize_from_env(args.config, args.cfg_file)
 88
 89      # Input file in .jsonlines format or .conll.
 90      input_filename = args.input_filename
 91
 92      ext_input = os.path.splitext(input_filename)[-1]
 93      if ext_input not in ['.conll', '.jsonlines', '.txt', '.naf']:
 94          raise Exception(
 95              'Input file should be .naf, .conll, .txt or .jsonlines, but is {}.'
 96              .format(ext_input))
 97
 98      if ext_input == '.conll':
 99          labels = collections.defaultdict(set)
100          stats = collections.defaultdict(int)
101          docs = minimize.minimize_partition(
102              input_filename, labels, stats, args.word_col)
103      elif ext_input == '.jsonlines':
104          docs = read_jsonlines(input_filename)
105      elif ext_input == '.naf':
106          naf_obj = naf.get_naf(input_filename)
107          jsonlines_obj, term_ids, tok_ids = naf.get_jsonlines(naf_obj)
       ⚠ Unused Code: The variable tok_ids seems to be unused.
108          docs = [jsonlines_obj]
109      else:
110          text = open(input_filename).read()
111          docs = [util.create_example(text)]
112
113      output_file = args.output_file
114      predictor = Predictor(args.config, args.cfg_file)
115      sentences = {}
116      predictions = {}
117      for example_num, example in enumerate(docs):
118          # logging.info(example['doc_key'])
119          example["predicted_clusters"], _ = predictor.predict(example)
120          if args.format_out == 'jsonlines':
121              output_file.write(json.dumps(example))
122              output_file.write("\n")
123          else:
124              predictions[example['doc_key']] = example["predicted_clusters"]
125              sentences[example['doc_key']] = example["sentences"]
126          if example_num % 100 == 0:
127              logging.info("Decoded {} examples.".format(example_num + 1))
       ⚠ Use lazy % formatting in logging functions
128      if args.format_out == 'conll':
129          conll.output_conll(output_file, sentences, predictions)
130      elif args.format_out == 'naf':
131          # Check number of docs - what to do if multiple?
132          # Create naf obj if input format was not naf
133          if ext_input != '.naf':
134              # To do: add linguistic processing layers for terms and tokens
135              logging.warn(
       ⚠ Using deprecated method warn()
       ⚠ Coding Style Best Practice: Use lazy % formatting in logging functions
136                  'Outputting NAF when input was not naf,'
137                  + 'no dependency information available')
138              for doc_key in sentences:
139                  naf_obj, term_ids = naf.get_naf_from_sentences(
140                      sentences[doc_key])
141                  naf_obj = naf.create_coref_layer(
142                      naf_obj, predictions[doc_key], term_ids)
143                  naf_obj = naf.add_linguistic_processors(naf_obj)
144                  buffer = io.BytesIO()
145                  naf_obj.dump(buffer)
146                  output_file.write(buffer.getvalue().decode('utf-8'))
147                  # To do: make separate outputs?
148                  # To do: use dependency information from conll?
149          else:
150              # We only have one input doc
151              naf_obj = naf.create_coref_layer(
152                  naf_obj, example["predicted_clusters"], term_ids)
       ⚠ The variable example does not seem to be defined in case the for loop on line 117 is not entered. Are you sure this can never be the case?
       ⚠ The variable naf_obj does not seem to be defined for all execution paths.
       ⚠ The variable term_ids does not seem to be defined for all execution paths.
       ⚠ Bug: The loop variable example might not be defined here.
153              naf_obj = naf.add_linguistic_processors(naf_obj)
154              buffer = io.BytesIO()
155              naf_obj.dump(buffer)
156              output_file.write(buffer.getvalue().decode('utf-8'))
157
158
159  if __name__ == "__main__":
160      main()
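Most of the remaining flagged issues have one-line fixes. A self-contained sketch; the variable values are stand-ins for the real ones inside predict() and main():

    import logging

    logging.basicConfig(level=logging.INFO)

    # Line 43: dict(zip(...)) builds the feed dict without the flagged comprehension.
    input_tensors = ['tokens', 'text_len']        # stand-in for model.input_tensors
    tensorized_example = ([['Jan', 'ziet']], 2)   # stand-in for the tensorized example
    feed_dict = dict(zip(input_tensors, tensorized_example))

    # Line 127: lazy % formatting defers the string build until the record is emitted.
    example_num = 0
    logging.info("Decoded %d examples.", example_num + 1)

    # Line 135: logging.warning() replaces the deprecated warn(); note the original
    # concatenation is also missing a space between 'naf,' and 'no'.
    logging.warning('Outputting NAF when input was not naf, '
                    'no dependency information available')

For the undefined-variable warnings at lines 151-152, guarding the branch (for example, only entering it when ext_input == '.naf' and docs is non-empty) would make naf_obj, term_ids and the loop variable example well-defined on every path. Two further points the report does not flag: predict() returns a single value while line 119 unpacks two, so one of the two sides presumably needs adjusting, and main() accepts an args parameter that it immediately overwrites with parser.parse_args().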