1
|
|
|
import os |
|
|
|
|
2
|
|
|
import stanza |
|
|
|
|
3
|
|
|
import logging |
|
|
|
|
4
|
|
|
|
5
|
|
|
from pathlib import Path |
|
|
|
|
6
|
|
|
|
7
|
|
|
from e2edutch import util |
8
|
|
|
from e2edutch import coref_model as cm |
|
|
|
|
9
|
|
|
from e2edutch.download import download_data |
10
|
|
|
from e2edutch.predict import Predictor |
11
|
|
|
|
12
|
|
|
from stanza.pipeline.processor import Processor, register_processor |
|
|
|
|
13
|
|
|
from stanza.models.common.doc import Document |
|
|
|
|
14
|
|
|
|
15
|
|
|
import tensorflow.compat.v1 as tf |
|
|
|
|
16
|
|
|
|
17
|
|
|
# Package-level logger used by the processor below.
# NOTE(review): setting a level and attaching a StreamHandler at import
# time is unusual for library code — normally the host application
# configures logging and libraries only call getLogger(). Confirm this
# eager configuration is intended.
logger = logging.getLogger('e2edutch')
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler())
20
|
|
|
|
21
|
|
|
|
22
|
|
|
@register_processor('coref')
class CorefProcessor(Processor):
    """Stanza processor that appends coreference information.

    Wraps the e2edutch ``Predictor``: construction resolves the e2edutch
    configuration, points its data/log paths at the Stanza resources
    directory (next to ``config['model_path']``), downloads the model
    data if missing, and warms the model cache by opening and closing
    one session.
    """
    # Tokenized input is required; this processor contributes 'coref'.
    _requires = set(['tokenize'])
    _provides = set(['coref'])

    def __init__(self, config, pipeline, use_gpu):
        """Initialize the e2edutch configuration and model data.

        Parameters
        ----------
        config : dict
            Stanza processor config; only ``config['model_path']`` is
            read — its parent directory becomes the e2edutch data root.
        pipeline : stanza.Pipeline
            The owning pipeline (unused here, required by the interface).
        use_gpu : bool
            Stanza's GPU preference (currently not forwarded, see below).
        """
        # Make e2edutch follow Stanza's GPU settings:
        # set the environment value for GPU, so that initialize_from_env
        # picks it up.
        # if use_gpu:
        #     os.environ['GPU'] = ' '.join(tf.config.experimental.list_physical_devices('GPU'))
        # else:
        #     if 'GPU' in os.environ['GPU']:
        #         os.environ.pop('GPU')

        self.e2econfig = util.initialize_from_env(model_name='final')

        # Override datapath and log_root:
        # store e2edata with the Stanza resources, i.e. a
        # 'stanza_resources/nl/coref' directory.
        self.e2econfig['datapath'] = Path(config['model_path']).parent
        self.e2econfig['log_root'] = Path(config['model_path']).parent

        # Download data files if not present.
        download_data(self.e2econfig)

        # Start and stop a session once so all models are cached.
        predictor = Predictor(config=self.e2econfig)
        predictor.end_session()

    def _set_up_model(self, *args):
        # All model setup happens in __init__; nothing to do here.
        # (Removed a stray debug print() that wrote to stdout.)
        pass

    def process(self, doc):
        """Run coreference prediction over ``doc`` and return it.

        Parameters
        ----------
        doc : stanza.models.common.doc.Document
            A tokenized document ('tokenize' must have run).

        Returns
        -------
        stanza.models.common.doc.Document
            The same document. NOTE(review): the predicted clusters are
            currently only logged, not attached to the document —
            presumably attaching them is still TODO.
        """
        predictor = Predictor(config=self.e2econfig)
        try:
            # Build the example argument for predict:
            # example (dict): dict with the following fields:
            #   sentences ([[str]])
            #   doc_id (str)
            #   clusters ([[(int, int)]]) (optional)
            example = {}
            # BUG FIX: predict() expects tokenized sentences ([[str]],
            # per the format documented above), not raw sentence text
            # (one str per sentence, as the original code passed).
            example['sentences'] = [
                [word.text for word in sentence.words]
                for sentence in doc.sentences
            ]
            example['doc_id'] = 'document_from_stanza'
            example['doc_key'] = 'undocumented'

            predicted_clusters = predictor.predict(example)
            # Log via the module logger instead of print(); library code
            # should not write directly to stdout.
            logger.info('Predicted coreference clusters: %s',
                        predicted_clusters)
        finally:
            # Always release the TF session, even if prediction raises.
            predictor.end_session()

        return doc
75
|
|
|
|