sutime.sutime.SUTime._load_java_wrapper_class()   A
last analyzed

Complexity

Conditions 4

Size

Total Lines 16
Code Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 13
dl 0
loc 16
rs 9.75
c 0
b 0
f 0
cc 4
nop 2
1
# -*- coding: utf-8 -*-
2
"""A Python wrapper for Stanford CoreNLP's SUTime."""
3
4
import glob
5
import importlib
6
import json
7
import logging
8
import os
9
import socket
10
import sys
11
import threading
12
from pathlib import Path
13
from typing import Dict, List, Optional
14
15
import jpype  # pyre-ignore[21]
16
17
# Timeout value handed to socket.setdefaulttimeout() at import time (that
# function interprets it as seconds). The identifier's spelling ("SOCKED")
# is kept as-is because it is a module-level name.
SOCKED_DEFAULT_TIMEOUT = 15
18
# Import-time side effect: sets the default timeout applied to socket
# objects created afterwards anywhere in this process.
socket.setdefaulttimeout(SOCKED_DEFAULT_TIMEOUT)
19
20
21
class SUTime(object):
    """Python wrapper for SUTime (CoreNLP) by Stanford.

    On construction the wrapper validates the requested language, optionally
    starts a JVM (through JPype) with all jars found below ``jars`` on the
    class path, and instantiates the Java ``SUTimeWrapper`` class. Text is
    then annotated with :meth:`parse`.
    """

    _sutime_python_jar = 'stanford-corenlp-sutime-python-1.4.0.jar'
    _sutime_java_class = 'edu.stanford.nlp.python.SUTimeWrapper'
    _corenlp_version = '4.0.0'

    # full name or ISO 639-1 code
    _languages = {
        'arabic': 'arabic',
        'ar': 'arabic',
        'chinese': 'chinese',
        'zh': 'chinese',
        'english': 'english',
        'british': 'british',
        'en': 'english',
        'french': 'french',
        'fr': 'french',
        'german': 'german',
        'de': 'german',
        'spanish': 'spanish',
        'es': 'spanish',
    }

    # https://github.com/stanfordnlp/CoreNLP/tree/master/src/edu/stanford/nlp/time/rules
    _supported_languages = {'british', 'english', 'spanish'}

    # Jar file names that must be present below `self.jars`; the CoreNLP
    # entries are derived from `_corenlp_version` so both stay in sync.
    _required_jars = {
        'stanford-corenlp-{0}-models.jar'.format(_corenlp_version),
        'stanford-corenlp-{0}.jar'.format(_corenlp_version),
        'gson-2.8.6.jar',
        'slf4j-simple-1.7.30.jar',
    }

    def __init__(
        self,
        jars: Optional[str] = None,
        jvm_started: Optional[bool] = False,
        mark_time_ranges: Optional[bool] = False,
        include_range: Optional[bool] = False,
        jvm_flags: Optional[List[str]] = None,
        language: Optional[str] = 'english',
    ):
        """Initialize `SUTime` wrapper.

        Args:
            jars (Optional[str]): Path to previously downloaded SUTime Java
                dependencies. Defaults to None, in which case the `jars`
                directory next to this module is used.
            jvm_started (Optional[bool]): Flag to indicate that JVM has been
                already started (with all Java dependencies loaded). Defaults
                to False.
            mark_time_ranges (Optional[bool]): SUTime flag for
                sutime.markTimeRanges. Defaults to False.
                "Whether or not to recognize time ranges such as 'July to
                August'"
            include_range (Optional[bool]): SUTime flag for
                sutime.includeRange. Defaults to False.
                "Whether or not to add range info to the TIMEX3 object"
            jvm_flags (Optional[List[str]]): List of flags passed to JVM. For
                example, this may be used to specify the maximum heap size
                using '-Xmx'. Has no effect if `jvm_started` is set to True.
                Defaults to None.
            language (Optional[str]): Selected language. Currently supported
                are: english (/en), british, spanish (/es). Defaults to
                `english`.

        Raises:
            RuntimeError: If the language is unknown (this includes an empty
                or `None` language), if the language model jar of a
                supported non-English language is missing, or if required
                jars are missing while the class path is built.
            SystemExit: If the Java wrapper class cannot be loaded.
        """
        self.mark_time_ranges = mark_time_ranges
        self.include_range = include_range
        self._is_loaded = False
        self._sutime = None
        self._lock = threading.Lock()
        module_root = Path(__file__).resolve().parent
        self.jars = Path(jars) if jars else module_root / 'jars'

        self._check_language_model_dependency(
            language.lower() if language else '',
        )

        if not jvm_started:
            self._classpath = self._create_classpath()
            self._start_jvm(jvm_flags)

        # NOTE(review): the caller-supplied value (not the normalized name
        # from `_languages`) is forwarded to Java - confirm that
        # SUTimeWrapper accepts codes such as 'es' or mixed-case names.
        self._load_java_wrapper_class(language)

    def parse(
        self, input_str: str, reference_date: Optional[str] = '',
    ) -> List[Dict]:
        """Parse datetime information out of string input.

        It invokes the SUTimeWrapper.annotate() function in Java.

        Args:
            input_str (str): The input as string that has to be parsed.
            reference_date (Optional[str]): Optional reference date for
                SUTime. Defaults to `''`, in which case `annotate()` is
                called with the input string only.

        Returns:
            A list of dicts with the result from the `SUTimeWrapper.annotate()`
            call.

        Raises:
            RuntimeError: An error occurs when CoreNLP is not loaded.
        """
        if self._is_loaded is False:
            raise RuntimeError('Please load SUTime first!')

        if reference_date:
            return json.loads(str(self._sutime.annotate(
                input_str, reference_date,
            )))
        return json.loads(str(self._sutime.annotate(input_str)))

    def _load_java_wrapper_class(self, language: Optional[str]):
        """Instantiate the Java `SUTimeWrapper` and mark the wrapper loaded.

        Args:
            language (Optional[str]): Language value forwarded unchanged to
                the Java wrapper's constructor.

        Raises:
            SystemExit: If attaching the thread or creating the Java object
                fails for any reason.
        """
        try:
            # make it thread-safe
            if threading.active_count() > 1:
                if not jpype.isThreadAttachedToJVM():
                    jpype.attachThreadToJVM()
            # The context manager only releases a lock that was actually
            # acquired, so a failure in the attach calls above can no longer
            # trigger "release unlocked lock" and hide the real error.
            with self._lock:
                wrapper = jpype.JClass(self._sutime_java_class)
                self._sutime = wrapper(
                    self.mark_time_ranges, self.include_range, language,
                )
                self._is_loaded = True
        except Exception as exc:
            sys.exit('Could not load JVM: {0}'.format(exc))

    def _check_language_model_dependency(self, language: str):
        """Validate the language and the presence of its model jar.

        Args:
            language (str): Lower-cased language name or ISO 639-1 code.

        Raises:
            RuntimeError: If `language` is not a key of `_languages`, or if
                it maps to a supported non-English language whose model jar
                does not exist below `self.jars`.
        """
        if language not in self._languages:
            raise RuntimeError('Unsupported language: {0}'.format(language))
        normalized_language = self._languages[language]

        # Known but unsupported languages only produce a warning.
        if normalized_language not in self._supported_languages:
            logging.warning('{0}: {1}. {2}.'.format(
                normalized_language.capitalize(),
                'is not (yet) supported by SUTime',
                'Falling back to default model',
            ))
            return

        # English and British never require the language-specific jar.
        if normalized_language in {'english', 'british'}:
            return

        language_model_file = (
            self.jars / 'stanford-corenlp-{0}-models-{1}.jar'.format(
                self._corenlp_version,
                normalized_language,
            ))

        # A plain existence check: globbing the literal path would
        # misinterpret directories containing characters such as '[' or '*'.
        if not language_model_file.exists():
            raise RuntimeError(
                'Missing language model for {0}! Run {1} {2} {3}'.format(
                    normalized_language.capitalize(),
                    'mvn dependency:copy-dependencies',
                    '-DoutputDirectory=./sutime/jars -P',
                    normalized_language,
                ),
            )

    def _start_jvm(self, additional_flags: Optional[List[str]]):
        """Start the JVM with the computed class path unless one is running.

        Args:
            additional_flags (Optional[List[str]]): Extra JVM flags appended
                after the `-Djava.class.path` flag; ignored when empty.
        """
        flags = ['-Djava.class.path={0}'.format(self._classpath)]
        if additional_flags:
            flags.extend(additional_flags)
        logging.info('jpype.isJVMStarted(): {0}'.format(jpype.isJVMStarted()))
        if not jpype.isJVMStarted():
            jpype.startJVM(jpype.getDefaultJVMPath(), *flags)

    def _create_classpath(self):
        """Build the Java class path from the jars below `self.jars`.

        The wrapper jar shipped in the package's own `jars` directory is
        always listed first, followed by every `*.jar` file found while
        walking `self.jars`.

        Returns:
            str: All jar paths joined with `os.pathsep`.

        Raises:
            RuntimeError: If any file name in `_required_jars` was not found
                below `self.jars`.
        """
        # `import importlib` alone does not guarantee that the `util`
        # submodule is loaded, so import it explicitly.
        import importlib.util

        spec = importlib.util.find_spec('sutime')
        if spec is not None and spec.origin:
            package_root = Path(spec.origin).parent
        else:
            # Package not importable under the name 'sutime': fall back to
            # this module's own directory, as `__init__` does.
            package_root = Path(__file__).resolve().parent

        jars = [package_root / 'jars' / self._sutime_python_jar]
        jar_file_names = set()
        for top, _, files in os.walk(self.jars):
            for file_name in files:
                if file_name.endswith('.jar'):
                    jars.append(Path(top, file_name))
                    jar_file_names.add(file_name)

        missing_jars = sorted(self._required_jars - jar_file_names)
        if missing_jars:
            logging.warning(missing_jars)
            raise RuntimeError(
                'Not all necessary Java dependencies have been downloaded!',
            )
        return os.pathsep.join(str(jar) for jar in jars)
210