import re
import tempfile
import subprocess
import operator
import collections

# Matches header lines such as "#begin document (<doc_id>); part <nn>",
# where the "part" suffix is optional.
BEGIN_DOCUMENT_REGEX = re.compile(
    r"#begin document \(?([^\);]*)\)?;?(?: part (\d+))?")
# Extracts recall, precision and F1 from the "Coreference:" summary line
# printed by the official CoNLL-2012 scorer.
COREF_RESULTS_REGEX = re.compile(
    r".*Coreference: Recall: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\t"
    r"Precision: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tF1: ([0-9.]+)%.*",
    re.DOTALL)
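
# Example (illustrative header, not from this repo): applied to
#   "#begin document (bc/cctv/00/cctv_0000); part 000"
# BEGIN_DOCUMENT_REGEX captures the groups
#   ("bc/cctv/00/cctv_0000", "000").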


def get_doc_key(doc_id, part=None):
    if part is None:
        return doc_id
    else:
        return '{}.p.{}'.format(doc_id, part)


def get_reverse_doc_key(doc_key):
    segments = doc_key.split('.p.')
    if len(segments) > 1:
        part = segments[-1]
        doc_id = '.p.'.join(segments[:-1])
    else:
        doc_id = doc_key
        part = None
    return doc_id, part
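
# Example (illustrative round trip between the two helpers above):
#   >>> get_doc_key("bc/cctv/00/cctv_0000", "0")
#   'bc/cctv/00/cctv_0000.p.0'
#   >>> get_reverse_doc_key("bc/cctv/00/cctv_0000.p.0")
#   ('bc/cctv/00/cctv_0000', '0')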


def get_prediction_map(predictions):
    """Index predicted clusters by token position.

    For each doc_key, builds three maps from token index to cluster ids:
    start_map (mentions starting there), end_map (mentions ending there)
    and word_map (single-token mentions).
    """
    prediction_map = {}
    for doc_key, clusters in predictions.items():
        start_map = collections.defaultdict(list)
        end_map = collections.defaultdict(list)
        word_map = collections.defaultdict(list)
        for cluster_id, mentions in enumerate(clusters):
            for start, end in mentions:
                if start == end:
                    word_map[start].append(cluster_id)
                else:
                    start_map[start].append((cluster_id, end))
                    end_map[end].append((cluster_id, start))
        # Sort so that wider mentions open first and, at a shared end token,
        # later-starting (inner) mentions close first, keeping the output
        # brackets properly nested.
        for k, v in start_map.items():
            start_map[k] = [cluster_id for cluster_id, end in sorted(
                v, key=operator.itemgetter(1), reverse=True)]
        for k, v in end_map.items():
            end_map[k] = [cluster_id for cluster_id, start in sorted(
                v, key=operator.itemgetter(1), reverse=True)]
        prediction_map[doc_key] = (start_map, end_map, word_map)
    return prediction_map
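
# Example (illustrative): for one document with clusters [[(0, 2), (5, 5)]],
# cluster 0 produces start_map == {0: [0]}, end_map == {2: [0]} and
# word_map == {5: [0]} (the (5, 5) mention spans a single token).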


def clusters_to_brackets(sentences, predictions):
    """Convert one document's predicted clusters into per-token CoNLL
    coreference bracket strings, one list per sentence.
    """
    prediction_map = get_prediction_map({'': predictions})
    start_map, end_map, word_map = prediction_map['']
    word_index = 0
    brackets_list = []
    for sent in sentences:
        sent_brackets_list = []
        for _ in sent:
            coref_list = []
            if word_index in end_map:
                for cluster_id in end_map[word_index]:
                    coref_list.append("{})".format(cluster_id))
            if word_index in word_map:
                for cluster_id in word_map[word_index]:
                    coref_list.append("({})".format(cluster_id))
            if word_index in start_map:
                for cluster_id in start_map[word_index]:
                    coref_list.append("({}".format(cluster_id))
            coref = '-' if len(coref_list) == 0 else "|".join(coref_list)
            sent_brackets_list.append(coref)
            word_index += 1
        brackets_list.append(sent_brackets_list)
    return brackets_list
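
# Example (illustrative): sentences = [["John", "saw", "him", "."]] with
# predictions = [[(0, 0), (2, 2)]] (one cluster of two single-token
# mentions) yields [['(0)', '-', '(0)', '-']].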


def output_conll(output_file, sentences, predictions):
    """
    Write the tokens and coreference annotations in CoNLL-2012 format.

    Args:
        output_file (File or IOBase): file object to write the CoNLL
            output to
        sentences (dict): keys are the doc_keys, values are the sentences of
            that doc
        predictions (dict): keys are the doc_keys, values are the predicted
            clusters of that doc
    """
    for doc_key in sentences:
        brackets = clusters_to_brackets(sentences[doc_key],
                                        predictions[doc_key])
        doc_id, part = get_reverse_doc_key(doc_key)
        if part is None:
            output_file.write("#begin document ({});\n\n".format(doc_id))
        else:
            output_file.write(
                "#begin document ({}); part {}\n\n".format(doc_id, part))
        for sent, brack_sent in zip(sentences[doc_key], brackets):
            for i, word in enumerate(sent):
                coref = brack_sent[i]
                line = '\t'.join([doc_id, str(i), word, coref])
                output_file.write(line + '\n')
            output_file.write('\n')
        output_file.write('#end document\n')
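
# Illustrative output for a part-less doc_key "doc1", one sentence
# ["John", "saw", "him"] and predictions [[(0, 0), (2, 2)]]
# (columns are tab-separated):
#   #begin document (doc1);
#
#   doc1  0  John  (0)
#   doc1  1  saw   -
#   doc1  2  him   (0)
#
#   #end document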


def output_conll_align(input_file, output_file, predictions):
    """Copy a gold CoNLL file to output_file, replacing the last column
    of each token row with the predicted coreference brackets so the
    output stays token-aligned with the gold file.
    """
    prediction_map = get_prediction_map(predictions)

    word_index = 0
    for line in input_file.readlines():
        row = line.split()
        if len(row) == 0:
            output_file.write("\n")
        elif row[0].startswith("#"):
            begin_match = re.match(BEGIN_DOCUMENT_REGEX, line)
            if begin_match:
                doc_key = get_doc_key(*begin_match.groups())
                start_map, end_map, word_map = prediction_map[doc_key]
                word_index = 0
            output_file.write(line)
            output_file.write("\n")
        else:
            coref_list = []
            if word_index in end_map:
                for cluster_id in end_map[word_index]:
                    coref_list.append("{})".format(cluster_id))
            if word_index in word_map:
                for cluster_id in word_map[word_index]:
                    coref_list.append("({})".format(cluster_id))
            if word_index in start_map:
                for cluster_id in start_map[word_index]:
                    coref_list.append("({}".format(cluster_id))

            if len(coref_list) == 0:
                row[-1] = "-"
            else:
                row[-1] = "|".join(coref_list)

            output_file.write(" ".join(row))
            output_file.write("\n")
            word_index += 1
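
# Note: word_index is reset at every "#begin document" header and advances
# once per token row, so the mention spans in `predictions` are interpreted
# as document-level token offsets, matching get_prediction_map above.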


def official_conll_eval(gold_path, predicted_path,
                        metric, official_stdout=False):
    """Run the official CoNLL-2012 scorer and parse its summary line."""
    cmd = ["conll-2012/scorer/v8.01/scorer.pl",
           metric, gold_path, predicted_path, "none"]
    # Pipe stderr as well, so scorer errors can be surfaced below.
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE)
    stdout, stderr = process.communicate()

    stdout = stdout.decode("utf-8")
    if stderr:
        print(stderr.decode("utf-8"))

    if official_stdout:
        print("Official result for {}".format(metric))
        print(stdout)

    coref_results_match = re.match(COREF_RESULTS_REGEX, stdout)
    recall = float(coref_results_match.group(1))
    precision = float(coref_results_match.group(2))
    f1 = float(coref_results_match.group(3))
    return {"r": recall, "p": precision, "f": f1}
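
# official_conll_eval shells out to the reference scorer, e.g. for MUC:
#   conll-2012/scorer/v8.01/scorer.pl muc gold.conll pred.conll none
# (file names here are illustrative; the scorer script must exist at that
# relative path).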


def evaluate_conll(gold_path, predictions, official_stdout=False):
    """Score `predictions` against the gold CoNLL file at `gold_path`.

    Returns a dict mapping each metric name to its {"r", "p", "f"} scores.
    """
    with tempfile.NamedTemporaryFile(delete=False, mode="w") as pred_file:
        with open(gold_path, "r") as gold_file:
            output_conll_align(gold_file, pred_file, predictions)
        print("Predicted conll file: {}".format(pred_file.name))
    # Score after the `with` block, so the temporary file is flushed and
    # closed before the scorer reads it.
    return {m: official_conll_eval(gold_path, pred_file.name,
                                   m, official_stdout)
            for m in ("muc", "bcub", "ceafe")}
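

# Usage sketch (paths and doc_keys are illustrative, not part of this
# module):
#   predictions = {"bc/cctv/00/cctv_0000.p.0": [[(0, 2), (5, 5)]]}
#   results = evaluate_conll("dev.english.v4_gold_conll", predictions)
#   print(results["muc"]["f"], results["bcub"]["f"], results["ceafe"]["f"])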