Passed
Pull Request — master (#663)
by Juho
02:48
created

annif.project.AnnifProject._initialize_backend()   A

Complexity

Conditions 3

Size

Total Lines 9
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 9
nop 2
dl 0
loc 9
rs 9.95
c 0
b 0
f 0
1
"""Project management functionality for Annif"""
2
3
import enum
4
import os.path
5
from shutil import rmtree
6
7
import annif
8
import annif.analyzer
9
import annif.backend
10
import annif.corpus
11
import annif.suggestion
12
import annif.transform
13
from annif.datadir import DatadirMixin
14
from annif.exception import (
15
    AnnifException,
16
    ConfigurationException,
17
    NotInitializedException,
18
    NotSupportedException,
19
)
20
21
logger = annif.logger
22
23
24
class Access(enum.IntEnum):
    """Access levels that can be assigned to a project.

    The members are integers, so two levels can be compared numerically;
    ``private`` has the lowest value and ``public`` the highest.
    """

    # listed in ascending numeric order
    private = 1
    hidden = 2
    public = 3
30
31
32
class AnnifProject(DatadirMixin):
33
    """Class representing the configuration of a single Annif project."""
34
35
    # Lazily created helpers: each one stays None until the matching
    # property (or _initialize_vocab) fills it in on first use.
    _analyzer = None
    _backend = None
    _transform = None
    _vocab = None
    _vocab_lang = None
    # set to True at the end of initialize()
    initialized = False

    # access level name used when the configuration has no "access" setting
    DEFAULT_ACCESS = "public"
45
46
    def __init__(self, project_id, config, datadir, registry):
        """Set up a project from its configuration.

        project_id -- identifier of the project; also the fallback for its name
        config -- mapping-like configuration; must contain "language"
        datadir -- base data directory, passed on to DatadirMixin
        registry -- registry object; used later to look up the vocabulary

        Raises ConfigurationException if the "access" setting is not valid.
        """
        DatadirMixin.__init__(self, datadir, "projects", project_id)

        # objects handed in by the caller
        self.project_id = project_id
        self.config = config
        self.registry = registry
        self._base_datadir = datadir

        # settings read from the configuration
        self.name = config.get("name", project_id)
        self.language = config["language"]  # mandatory: there is no fallback
        self.analyzer_spec = config.get("analyzer", None)
        self.transform_spec = config.get("transform", "pass")
        self.vocab_spec = config.get("vocab", None)

        # needs self.config and self.project_id, so it comes last
        self._init_access()
58
59
    def _init_access(self):
        """Resolve the "access" setting into an Access member.

        Reads "access" from the configuration, falling back to
        DEFAULT_ACCESS, and stores the matching member in self.access.

        Raises ConfigurationException if the value does not name a member
        of Access.
        """
        access = self.config.get("access", self.DEFAULT_ACCESS)
        try:
            # Look the name up among the enum members only. getattr() would
            # also accept any other attribute of the class (for example
            # "__members__" or "bit_length") and store a non-Access object.
            self.access = Access[access]
        except KeyError as err:
            raise ConfigurationException(
                "'{}' is not a valid access setting".format(access),
                project_id=self.project_id,
            ) from err
68
69
    def _initialize_analyzer(self):
        """Create the analyzer up front, if the project configures one."""
        if self.analyzer_spec:
            # reading the property builds the analyzer and caches it
            description = str(self.analyzer)
            logger.debug(
                "Project '%s': initialized analyzer: %s",
                self.project_id,
                description,
            )
        # otherwise: not configured, so assume the project does not need one
76
77
    def _initialize_subjects(self):
        """Access the subjects of the project so that they get loaded.

        An AnnifException raised while doing so is logged as a warning
        instead of being propagated.
        """
        try:
            description = str(self.subjects)
            logger.debug(
                "Project '%s': initialized subjects: %s",
                self.project_id,
                description,
            )
        except AnnifException as exc:
            logger.warning(exc.format_message())
85
86
    def _initialize_backend(self, parallel):
        """Initialize the backend, passing the parallel flag through.

        Nothing is initialized when no backend object is available. An
        AnnifException raised along the way is logged as a warning
        instead of being propagated.
        """
        logger.debug("Project '%s': initializing backend", self.project_id)
        try:
            backend = self.backend
            if backend:
                backend.initialize(parallel)
            else:
                logger.debug("Cannot initialize backend: does not exist")
        except AnnifException as exc:
            logger.warning(exc.format_message())
95
96
    def initialize(self, parallel=False):
        """Prepare the project for use by initializing its analyzer,
        subjects and backend. Calling this on an already initialized
        project does nothing. Pass parallel=True when the project is going
        to be used for parallel processing."""
        if not self.initialized:
            logger.debug("Initializing project '%s'", self.project_id)

            self._initialize_analyzer()
            self._initialize_subjects()
            self._initialize_backend(parallel)

            # only reached when none of the steps above raised
            self.initialized = True
111
112
    def _suggest_with_backend(self, text, backend_params):
        """Ask the backend for suggestions for a single text.

        backend_params maps backend ids to parameter dicts (or is None);
        only the entry for this project's backend is passed on.
        """
        params_by_backend = {} if backend_params is None else backend_params
        own_params = params_by_backend.get(self.backend.backend_id, {})
        suggestions = self.backend.suggest(text, own_params)
        logger.debug(
            "Got %d hits from backend %s",
            len(suggestions),
            self.backend.backend_id,
        )
        return suggestions
119
120
    def _suggest_batch_with_backend(self, corpus, backend_params):
        """Ask the backend for suggestions for a whole corpus at once.

        backend_params maps backend ids to parameter dicts (or is None);
        only the entry for this project's backend is passed on.
        """
        params_by_backend = {} if backend_params is None else backend_params
        own_params = params_by_backend.get(self.backend.backend_id, {})
        suggestion_sets = self.backend.suggest_batch(corpus, own_params)
        logger.debug(
            "Got %d hit sets from backend %s",
            len(suggestion_sets),
            self.backend.backend_id,
        )
        return suggestion_sets
129
130
    @property
    def analyzer(self):
        """The analyzer of this project, created on first access.

        Raises ConfigurationException if no analyzer has been configured.
        """
        if self._analyzer is not None:
            return self._analyzer
        if not self.analyzer_spec:
            raise ConfigurationException(
                "analyzer setting is missing", project_id=self.project_id
            )
        self._analyzer = annif.analyzer.get_analyzer(self.analyzer_spec)
        return self._analyzer
140
141
    @property
    def transform(self):
        """The transform of this project, created on first access."""
        if self._transform is not None:
            return self._transform
        self._transform = annif.transform.get_transform(
            self.transform_spec, project=self
        )
        return self._transform
148
149
    @property
    def backend(self):
        """The backend object of this project, created on first access.

        Returns None (after logging a warning) when creating the backend
        fails with a ValueError. Raises ConfigurationException if the
        configuration has no "backend" setting.
        """
        if self._backend is not None:
            return self._backend
        if "backend" not in self.config:
            raise ConfigurationException(
                "backend setting is missing", project_id=self.project_id
            )
        backend_id = self.config["backend"]
        try:
            backend_cls = annif.backend.get_backend(backend_id)
            self._backend = backend_cls(
                backend_id, config_params=self.config, project=self
            )
        except ValueError:
            # leave self._backend as None; callers have to cope with that
            logger.warning(
                "Could not create backend %s, "
                "make sure you've installed optional dependencies",
                backend_id,
            )
        return self._backend
169
170
    def _initialize_vocab(self):
        """Fetch the vocabulary and its language from the registry.

        Raises ConfigurationException if no vocabulary has been configured.
        """
        if self.vocab_spec is None:
            raise ConfigurationException(
                "vocab setting is missing", project_id=self.project_id
            )
        vocab_and_lang = self.registry.get_vocab(self.vocab_spec, self.language)
        self._vocab, self._vocab_lang = vocab_and_lang
178
179
    @property
    def vocab(self):
        """The vocabulary of this project, loaded on first access."""
        if self._vocab is not None:
            return self._vocab
        self._initialize_vocab()
        return self._vocab
184
185
    @property
    def vocab_lang(self):
        """The language of the vocabulary, resolved on first access."""
        if self._vocab_lang is not None:
            return self._vocab_lang
        self._initialize_vocab()
        return self._vocab_lang
190
191
    @property
    def subjects(self):
        """The subjects of this project's vocabulary."""
        vocabulary = self.vocab
        return vocabulary.subjects
194
195
    def _get_info(self, key):
        """Return attribute `key` of the backend, or None if unavailable.

        None is returned when there is no backend object or when an
        AnnifException is raised; in the latter case a warning is logged.
        """
        try:
            backend = self.backend
            if backend is None:
                return None
            return getattr(backend, key)
        except AnnifException as exc:
            logger.warning(exc.format_message())
            return None
203
204
    @property
    def is_trained(self):
        """The backend's is_trained attribute, or None if it cannot be read."""
        trained = self._get_info("is_trained")
        return trained
207
208
    @property
    def modification_time(self):
        """The backend's modification_time attribute, or None if it cannot
        be read."""
        modified = self._get_info("modification_time")
        return modified
211
212
    def suggest(self, text, backend_params=None):
        """Suggest subjects for the given text by passing it to the backend.
        Returns a list of SubjectSuggestion objects ordered by decreasing
        score.

        backend_params optionally maps backend ids to parameter dicts.

        Raises NotInitializedException if the backend reports that it is not
        trained. If the train state cannot be determined (None), a warning
        is logged and the suggestion is attempted anyway.
        """
        # Read the train state once: every access goes through the backend,
        # so reading it twice could log the same warning twice or observe
        # two different answers.
        trained = self.is_trained
        if trained is None:
            logger.warning("Could not get train state information.")
        elif not trained:
            raise NotInitializedException("Project is not trained.")
        logger.debug(
            'Suggesting subjects for text "%s..." (len=%d)', text[:20], len(text)
        )
        text = self.transform.transform_text(text)
        hits = self._suggest_with_backend(text, backend_params)
        logger.debug("%d hits from backend", len(hits))
        return hits
227
228
    def suggest_batch(self, corpus, backend_params=None):
        """Suggest subjects for the given documents using batches of
        documents in their operations when possible.

        backend_params optionally maps backend ids to parameter dicts.

        Raises NotInitializedException if the backend reports that it is not
        trained. If the train state cannot be determined (None), a warning
        is logged and the suggestion is attempted anyway.
        """
        import logging  # local import: only the DEBUG level constant is needed

        # Read the train state once: every access goes through the backend,
        # so reading it twice could log the same warning twice or observe
        # two different answers.
        trained = self.is_trained
        if trained is None:
            logger.warning("Could not get train state information.")
        elif not trained:
            raise NotInitializedException("Project is not trained.")
        corpus = self.transform.transform_corpus(corpus)
        # Counting the documents iterates over the whole corpus, so only do
        # it when the debug message is actually going to be emitted.
        # NOTE(review): assumes annif.logger is a standard logging.Logger.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "Suggesting subjects for a batch of %d documents",
                sum(1 for _ in corpus.documents),
            )
        return self._suggest_batch_with_backend(corpus, backend_params)
242
243
    def train(self, corpus, backend_params=None, jobs=0):
        """Train the backend of the project on the given corpus.

        The corpus is run through the project's transform first, unless it
        is the string "cached", which is handed to the backend unchanged.
        backend_params optionally maps backend ids to parameter dicts.
        """
        training_corpus = corpus
        if corpus != "cached":
            training_corpus = self.transform.transform_corpus(corpus)
        params_by_backend = {} if backend_params is None else backend_params
        own_params = params_by_backend.get(self.backend.backend_id, {})
        self.backend.train(training_corpus, own_params, jobs)
251
252
    def learn(self, corpus, backend_params=None):
        """Continue training the project with additional documents.

        backend_params optionally maps backend ids to parameter dicts.

        Raises NotSupportedException if the backend is not an
        AnnifLearningBackend.
        """
        params_by_backend = {} if backend_params is None else backend_params
        own_params = params_by_backend.get(self.backend.backend_id, {})
        transformed = self.transform.transform_corpus(corpus)
        if not isinstance(self.backend, annif.backend.backend.AnnifLearningBackend):
            raise NotSupportedException(
                "Learning not supported by backend", project_id=self.project_id
            )
        self.backend.learn(transformed, own_params)
264
265
    def hyperopt(self, corpus, trials, jobs, metric, results_file):
        """Optimize the hyperparameters of the project against the given
        metric, using a validation corpus.

        Returns whatever the backend's optimizer returns from optimize().
        Raises NotSupportedException if the backend is not an
        AnnifHyperoptBackend.
        """
        if not isinstance(self.backend, annif.backend.hyperopt.AnnifHyperoptBackend):
            raise NotSupportedException(
                "Hyperparameter optimization not supported " "by backend",
                project_id=self.project_id,
            )
        optimizer = self.backend.get_hp_optimizer(corpus, metric)
        return optimizer.optimize(trials, jobs, results_file)
276
277
    def dump(self):
        """Return a dict describing this project."""
        info = {
            "project_id": self.project_id,
            "name": self.name,
            "language": self.language,
        }
        # only the configured backend id is reported, not the backend object
        info["backend"] = {"backend_id": self.config.get("backend")}
        info["is_trained"] = self.is_trained
        info["modification_time"] = self.modification_time
        return info
287
288
    def remove_model_data(self):
        """Delete the data directory of this project, if there is one.

        A warning is logged instead when the directory does not exist.
        """
        model_dir = self._datadir_path
        if not os.path.isdir(model_dir):
            logger.warning(
                "No model data to remove for project {}.".format(self.project_id)
            )
            return
        rmtree(model_dir)
        logger.info("Removed model data for project {}.".format(self.project_id))
298