e2edutch.predict.Predictor.predict()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 28
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 12
dl 0
loc 28
rs 9.8
c 0
b 0
f 0
cc 1
nop 2
1
import sys
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
import json
3
import os
4
import io
5
import collections
6
import argparse
7
import logging
8
9
from e2edutch import conll
10
from e2edutch import minimize
11
from e2edutch import util
12
from e2edutch import coref_model as cm
13
from e2edutch import naf
14
15
import tensorflow.compat.v1 as tf
0 ignored issues
show
introduced by
Unable to import 'tensorflow.compat.v1'
Loading history...
introduced by
third party import "import tensorflow.compat.v1 as tf" should be placed before "from e2edutch import conll"
Loading history...
16
17
# Module-level logger; its level is raised to INFO by Predictor(verbose=True)
# and by the -v/--verbose CLI flag in main().
logger = logging.getLogger('e2edutch')
18
19
20
class Predictor:
    """
    A predictor object loads a pretrained e2e model to predict coreferences.

    It can be used to predict coreferences on tokenized text.
    """

    def __init__(self, model_name='final', config=None, verbose=False):
        """
        Build the coreference model and open a tensorflow session for it.

        Args:
            model_name (str): name of the pretrained model used to look up a
                              default configuration when `config` is None.
            config (dict): model configuration; when falsy, a default one is
                           loaded from the environment.
            verbose (bool): when True, set the module logger to INFO level.

        Raises:
            Exception: when restoring the model fails, which happens when a
                       previous session has not been closed yet.
        """
        if verbose:
            logger.setLevel(logging.INFO)

        if config:
            self.config = config
        else:
            # If no configuration is provided, try to get a default config.
            self.config = util.initialize_from_env(model_name=model_name)

        # Clear tensorflow context before building a new graph.
        tf.reset_default_graph()
        # `tf` is already tensorflow.compat.v1, so tf.Session() suffices
        # (the original tf.compat.v1.Session() was redundantly nested).
        self.session = tf.Session()

        try:
            self.model = cm.CorefModel(self.config)
            self.model.restore(self.session)
        except ValueError as err:
            # Chain the original error so the root cause stays visible.
            raise Exception("Trying to reload the model while the previous " +
                            "session hasn't been ended. Close the existing " +
                            "session with predictor.end_session()") from err

    def predict(self, example):
        """
        Predict coreference spans for a tokenized text.


        Args:
            example (dict): dict with the following fields:
                              sentences ([[str]])
                              doc_id (str)
                              clusters ([[(int, int)]]) (optional)

        Returns:
            [[(int, int)]]: a list of clusters. The items of the cluster are
                            spans, denoted by their start and end token index

        """
        tensorized_example = self.model.tensorize_example(
            example, is_training=False)
        # dict(zip(...)) instead of an unnecessary dict comprehension.
        feed_dict = dict(zip(self.model.input_tensors, tensorized_example))
        (_, _, _, top_span_starts, top_span_ends, top_antecedents,
         top_antecedent_scores) = self.session.run(
            self.model.predictions, feed_dict=feed_dict)
        predicted_antecedents = self.model.get_predicted_antecedents(
            top_antecedents, top_antecedent_scores)
        predicted_clusters, _ = self.model.get_predicted_clusters(
            top_span_starts, top_span_ends, predicted_antecedents)

        return predicted_clusters

    def end_session(self):
        """
        Close the session, clearing the tensorflow model context.
        """
        self.session.close()
        tf.reset_default_graph()
83
84
85
def get_parser():
    """
    Build the command line argument parser for the prediction script.

    Returns:
        argparse.ArgumentParser: parser with a positional model config name
        and input filename, plus optional output file, output format
        (conll/jsonlines/naf), conll word column, config file and verbosity.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config')
    parser.add_argument('input_filename')
    parser.add_argument('-o', '--output_file',
                        type=argparse.FileType('w'), default=sys.stdout)
    parser.add_argument('-f', '--format_out', default='conll',
                        choices=['conll', 'jsonlines', 'naf'])
    # Column of the conll input that holds the word forms.
    parser.add_argument('-c', '--word_col', type=int, default=2)
    parser.add_argument('--cfg_file',
                        type=str,
                        default=None,
                        help="config file")
    parser.add_argument('-v', '--verbose', action='store_true')
    return parser
100
101
102
def read_jsonlines(input_filename):
    """
    Read a .jsonlines file and yield one parsed document per line.

    Args:
        input_filename (str): path to a file with one json object per line.

    Yields:
        dict: the parsed json object of each line.
    """
    # Use a context manager so the file handle is always closed, and iterate
    # the file lazily instead of materializing all lines with readlines().
    with open(input_filename) as input_file:
        for line in input_file:
            yield json.loads(line)
106
107
108
def main(args=None):
    """
    Command line entry point: predict coreferences for a single input file.

    Reads a .conll, .jsonlines, .txt or .naf file, predicts coreference
    clusters for each document, and writes the result in the requested
    output format (conll, jsonlines or naf).

    Args:
        args ([str]): command line arguments; when None, sys.argv[1:] is
                      used by argparse.

    Raises:
        Exception: when the input file extension is not supported.
    """
    parser = get_parser()
    # Pass the explicit argument list through; the original called
    # parse_args() and silently ignored the `args` parameter.
    args = parser.parse_args(args)
    if args.verbose:
        logger.setLevel(logging.INFO)

    # Input file in .jsonlines format or .conll.
    input_filename = args.input_filename

    ext_input = os.path.splitext(input_filename)[-1]
    if ext_input not in ['.conll', '.jsonlines', '.txt', '.naf']:
        raise Exception(
            'Input file should be .naf, .conll, .txt or .jsonlines, but is {}.'
            .format(ext_input))

    if ext_input == '.conll':
        labels = collections.defaultdict(set)
        stats = collections.defaultdict(int)
        docs = minimize.minimize_partition(
            input_filename, labels, stats, args.word_col)
    elif ext_input == '.jsonlines':
        docs = read_jsonlines(input_filename)
    elif ext_input == '.naf':
        naf_obj = naf.get_naf(input_filename)
        # tok_ids is not needed here; term_ids is reused for the naf output.
        jsonlines_obj, term_ids, _ = naf.get_jsonlines(naf_obj)
        docs = [jsonlines_obj]
    else:
        # Close the text file handle instead of leaking it.
        with open(input_filename) as text_file:
            text = text_file.read()
        docs = [util.create_example(text)]

    output_file = args.output_file

    config = util.initialize_from_env(cfg_file=args.cfg_file,
                                      model_cfg_file=args.config)
    predictor = Predictor(config=config)

    sentences = {}
    predictions = {}
    for example_num, example in enumerate(docs):
        # predict() returns the list of clusters directly; the original
        # tuple-unpacked it ("clusters, _ = ..."), which only works by
        # accident when exactly two clusters are predicted.
        example["predicted_clusters"] = predictor.predict(example)
        if args.format_out == 'jsonlines':
            output_file.write(json.dumps(example))
            output_file.write("\n")
        else:
            predictions[example['doc_key']] = example["predicted_clusters"]
            sentences[example['doc_key']] = example["sentences"]
        if example_num % 100 == 0:
            # Lazy %-args: the message is only formatted when INFO is enabled.
            logger.info("Decoded %d examples.", example_num + 1)
    if args.format_out == 'conll':
        conll.output_conll(output_file, sentences, predictions)
    elif args.format_out == 'naf':
        # Check number of docs - what to do if multiple?
        # Create naf obj if input format was not naf
        if ext_input != '.naf':
            # To do: add linguistic processing layers for terms and tokens
            # warning() instead of deprecated warn(); a space was missing
            # between the two concatenated message parts.
            logger.warning(
                'Outputting NAF when input was not naf, '
                'no dependency information available')
            for doc_key in sentences:
                naf_obj, term_ids = naf.get_naf_from_sentences(
                    sentences[doc_key])
                naf_obj = naf.create_coref_layer(
                    naf_obj, predictions[doc_key], term_ids)
                naf_obj = naf.add_linguistic_processors(naf_obj)
                buffer = io.BytesIO()
                naf_obj.dump(buffer)
                output_file.write(buffer.getvalue().decode('utf-8'))
                # To do, make separate outputs?
                # To do, use dependency information from conll?
        else:
            # We only have one input doc: naf_obj and term_ids were set when
            # the .naf input was read, and `example` holds its predictions.
            naf_obj = naf.create_coref_layer(
                naf_obj, example["predicted_clusters"], term_ids)
            naf_obj = naf.add_linguistic_processors(naf_obj)
            buffer = io.BytesIO()
            naf_obj.dump(buffer)
            output_file.write(buffer.getvalue().decode('utf-8'))
184
185
186
# Allow the module to be run directly as a script.
if __name__ == "__main__":
    main()
188