import logging
import os  # used by the commented-out GPU configuration below

from pathlib import Path

import stanza
import tensorflow.compat.v1 as tf  # used by the commented-out GPU configuration below

from stanza.models.common.doc import Document, Span
from stanza.pipeline.processor import Processor, register_processor

from e2edutch import util
from e2edutch import coref_model as cm
from e2edutch.download import download_data
from e2edutch.predict import Predictor

logger = logging.getLogger('e2edutch')
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler())


# Add a 'clusters' property to Document: clusters is a list of cluster,
# and each cluster is a list of Span (one Span per mention of one entity).
def clusterSetter(self, value):
    if isinstance(value, list):
        self._clusters = value
    else:
        logger.error('Clusters must be a list')


Document.add_property('clusters', default=[], setter=clusterSetter)
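# After processing, a document's clusters look like, e.g. (hypothetical
# illustration):
#   doc.clusters == [[<Span 'Jan'>, <Span 'zijn'>], [<Span 'boek'>, <Span 'het'>]]
# i.e. one inner list per entity, with one Span per mention of that entity.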

@register_processor('coref')
class CorefProcessor(Processor):
    '''Processor that appends coreference information.'''
    _requires = set(['tokenize'])
    _provides = set(['coref'])
    def __init__(self, config, pipeline, use_gpu):
        # Make e2edutch follow Stanza's GPU settings: set the 'GPU'
        # environment variable so that util.initialize_from_env picks it up.
        # if use_gpu:
        #     os.environ['GPU'] = ' '.join(
        #         device.name for device in
        #         tf.config.experimental.list_physical_devices('GPU'))
        # else:
        #     os.environ.pop('GPU', None)

        self.e2econfig = util.initialize_from_env(model_name='final')

        # Override datapath and log_root: store the e2edutch data with the
        # Stanza resources, i.e. in a 'stanza_resources/nl/coref' directory.
        self.e2econfig['datapath'] = Path(config['model_path']).parent
        self.e2econfig['log_root'] = Path(config['model_path']).parent

        # Download the data files if they are not present yet.
        download_data(self.e2econfig)

        # Start and stop a session once, so that all models are cached
        # before the first call to process().
        predictor = Predictor(config=self.e2econfig)
        predictor.end_session()

    def _set_up_model(self, *args):
        # The e2edutch model is initialized in __init__; nothing to do here.
        pass

    def process(self, doc):
        predictor = Predictor(config=self.e2econfig)

        # Build the example argument for predict(): a dict with the fields
        #   sentences ([[str]])
        #   doc_id (str)
        #   clusters ([[(int, int)]]) (optional)
        example = {}
        example['sentences'] = []
        example['doc_id'] = 'document_from_stanza'  # TODO check what this should be
        example['doc_key'] = 'undocumented'  # TODO check what this should be

        for sentence in doc.sentences:
            example['sentences'].append([word.text for word in sentence.words])
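
        # Illustration (hypothetical input): for the two-sentence document
        # 'Jan slaapt. Hij droomt.' the example ends up as
        #   {'sentences': [['Jan', 'slaapt', '.'], ['Hij', 'droomt', '.']],
        #    'doc_id': 'document_from_stanza', 'doc_key': 'undocumented'}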

        predicted_clusters = predictor.predict(example)  # a list of tuples

        # Add the predicted clusters back to the Stanza document.
        clusters = []
        for predicted_cluster in predicted_clusters:  # a tuple of mentions
            cluster = []
            for predicted_reference in predicted_cluster:  # a (start, end) word pair
                start, end = predicted_reference

                # start and end are document-level word indices; find the
                # sentence containing this reference by walking through the
                # sentences and tracking the index range each one covers.
                sentence_id = 0
                sentence = doc.sentences[0]
                sentence_start_word = 0
                sentence_end_word = len(sentence.words) - 1

                while sentence_end_word < start:
                    sentence_start_word = sentence_end_word + 1

                    # Move to the next sentence.
                    sentence_id += 1
                    sentence = doc.sentences[sentence_id]
                    sentence_end_word = sentence_start_word + len(sentence.words) - 1

                # Recount the word indices from the start of this sentence.
                start -= sentence_start_word
                end -= sentence_start_word
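
                # Worked example (hypothetical): with sentence lengths 3 and 4,
                # sentence 0 covers document words 0-2 and sentence 1 covers
                # words 3-6, so a reference (4, 5) lands in sentence 1 with
                # local indices start=1, end=2.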

                # Build a Span from the Tokens that back these words.
                span = Span(
                    tokens=[word.parent for word in sentence.words[start:end + 1]],
                    doc=doc,
                    type='COREF',
                    sent=doc.sentences[sentence_id],
                )
                cluster.append(span)

            clusters.append(cluster)

        doc.clusters = clusters

        predictor.end_session()

        return doc
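
# Example usage (a sketch, assuming this module is importable as
# e2edutch.stanza; importing it registers the 'coref' processor, and the
# model data is downloaded on first use):
#
#     import stanza
#     import e2edutch.stanza
#
#     nlp = stanza.Pipeline('nl', processors='tokenize,coref')
#     doc = nlp('Dit is een tekst. Deze tekst bevat coreferenties.')
#     print(doc.clusters)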