Passed
Pull Request — master (#663)
by Juho
03:24
created

annif.project.AnnifProject.suggest_batch()   A

Complexity

Conditions 3

Size

Total Lines 11
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 8
nop 3
dl 0
loc 11
rs 10
c 0
b 0
f 0
1
"""Project management functionality for Annif"""
2
3
import enum
4
import itertools
5
import os.path
6
from shutil import rmtree
7
8
import annif
9
import annif.analyzer
10
import annif.backend
11
import annif.corpus
12
import annif.suggestion
13
import annif.transform
14
from annif.datadir import DatadirMixin
15
from annif.exception import (
16
    AnnifException,
17
    ConfigurationException,
18
    NotInitializedException,
19
    NotSupportedException,
20
)
21
22
# Reuse the logger object exposed by the top-level annif package for all
# log output emitted by this module.
logger = annif.logger
23
24
25
class Access(enum.IntEnum):
    """Access levels that can be assigned to a project.

    The members are integers, so they can be compared with the ordinary
    ordering operators: private < hidden < public.
    """

    private = 1
    hidden = 2
    public = 3
31
32
33
class AnnifProject(DatadirMixin):
    """Class representing the configuration of a single Annif project.

    The analyzer, transform, backend and vocabulary objects are created
    lazily the first time the corresponding property is read and are then
    cached on the instance.
    """

    # defaults for uninitialized instances
    _transform = None
    _analyzer = None
    _backend = None
    _vocab = None
    _vocab_lang = None
    initialized = False

    # default values for configuration settings
    DEFAULT_ACCESS = "public"
    MINIBATCH_SIZE = 32

    def __init__(self, project_id, config, datadir, registry):
        """Set up the project from its configuration section.

        project_id -- identifier of the project; also used as the default name
        config -- mapping of configuration settings; "language" is mandatory
        datadir -- base data directory, passed on to DatadirMixin
        registry -- object providing get_vocab(), used to load the vocabulary

        Raises KeyError if "language" is missing from config and
        ConfigurationException if the "access" setting is not valid.
        """
        DatadirMixin.__init__(self, datadir, "projects", project_id)
        self.project_id = project_id
        self.name = config.get("name", project_id)
        self.language = config["language"]
        self.analyzer_spec = config.get("analyzer", None)
        self.transform_spec = config.get("transform", "pass")
        self.vocab_spec = config.get("vocab", None)
        self.config = config
        self._base_datadir = datadir
        self.registry = registry
        self._init_access()

    def _init_access(self):
        """Resolve the "access" setting into a member of the Access enum.

        Raises ConfigurationException if the setting does not name a member.
        """
        access = self.config.get("access", self.DEFAULT_ACCESS)
        try:
            # Look up by member name. getattr() would also accept any other
            # attribute of the enum class (e.g. "__doc__" or "mro") and let
            # an invalid setting through.
            self.access = Access[access]
        except KeyError:
            raise ConfigurationException(
                "'{}' is not a valid access setting".format(access),
                project_id=self.project_id,
            ) from None

    def _initialize_analyzer(self):
        """Create the analyzer, if one has been configured."""
        if not self.analyzer_spec:
            return  # not configured, so assume it's not needed
        analyzer = self.analyzer
        logger.debug(
            "Project '%s': initialized analyzer: %s", self.project_id, str(analyzer)
        )

    def _initialize_subjects(self):
        """Load the subjects of the vocabulary; failures are only logged."""
        try:
            subjects = self.subjects
            logger.debug(
                "Project '%s': initialized subjects: %s", self.project_id, str(subjects)
            )
        except AnnifException as err:
            logger.warning(err.format_message())

    def _initialize_backend(self, parallel):
        """Create and initialize the backend; failures are only logged."""
        logger.debug("Project '%s': initializing backend", self.project_id)
        try:
            if not self.backend:
                logger.debug("Cannot initialize backend: does not exist")
                return
            self.backend.initialize(parallel)
        except AnnifException as err:
            logger.warning(err.format_message())

    def initialize(self, parallel=False):
        """Initialize this project and its backend so that they are ready to
        be used. If parallel is True, expect that the project will be used
        for parallel processing."""

        if self.initialized:
            return

        logger.debug("Initializing project '%s'", self.project_id)

        self._initialize_analyzer()
        self._initialize_subjects()
        self._initialize_backend(parallel)

        self.initialized = True

    def _select_backend_params(self, backend_params):
        """Return the parameters addressed to this project's backend.

        backend_params maps backend ids to parameter dicts, or is None. An
        empty dict is returned when nothing is given for this backend.
        """
        if backend_params is None:
            backend_params = {}
        return backend_params.get(self.backend.backend_id, {})

    def _suggest_with_backend(self, text, backend_params):
        """Pass a single (already transformed) text to the backend."""
        beparams = self._select_backend_params(backend_params)
        hits = self.backend.suggest(text, beparams)
        logger.debug("Got %d hits from backend %s", len(hits), self.backend.backend_id)
        return hits

    def _batched(self, iterable, n):
        """Yield successive lists of at most n items taken from iterable."""
        # From https://docs.python.org/3/library/itertools.html#itertools-recipes
        it = iter(iterable)
        while True:
            batch = list(itertools.islice(it, n))
            if not batch:
                return
            yield batch

    def _suggest_batch_with_backend(self, corpus, backend_params):
        """Yield the backend's suggest_batch() result for each minibatch of
        at most MINIBATCH_SIZE documents of the corpus."""
        beparams = self._select_backend_params(backend_params)
        for docs_minibatch in self._batched(corpus.documents, self.MINIBATCH_SIZE):
            texts = [doc.text for doc in docs_minibatch]
            yield self.backend.suggest_batch(texts, beparams)

    @property
    def analyzer(self):
        """The analyzer of this project, created on first access.

        Raises ConfigurationException if no analyzer has been configured.
        """
        if self._analyzer is None:
            if self.analyzer_spec:
                self._analyzer = annif.analyzer.get_analyzer(self.analyzer_spec)
            else:
                raise ConfigurationException(
                    "analyzer setting is missing", project_id=self.project_id
                )
        return self._analyzer

    @property
    def transform(self):
        """The transform of this project, created on first access."""
        if self._transform is None:
            self._transform = annif.transform.get_transform(
                self.transform_spec, project=self
            )
        return self._transform

    @property
    def backend(self):
        """The backend of this project, created on first access.

        Raises ConfigurationException if no backend has been configured.
        Returns None (after logging a warning) when creating the backend
        fails with a ValueError.
        """
        if self._backend is None:
            if "backend" not in self.config:
                raise ConfigurationException(
                    "backend setting is missing", project_id=self.project_id
                )
            backend_id = self.config["backend"]
            try:
                backend_class = annif.backend.get_backend(backend_id)
                self._backend = backend_class(
                    backend_id, config_params=self.config, project=self
                )
            except ValueError:
                logger.warning(
                    "Could not create backend %s, "
                    "make sure you've installed optional dependencies",
                    backend_id,
                )
        return self._backend

    def _initialize_vocab(self):
        """Fetch the vocabulary and its language from the registry.

        Raises ConfigurationException if no vocabulary has been configured.
        """
        if self.vocab_spec is None:
            raise ConfigurationException(
                "vocab setting is missing", project_id=self.project_id
            )
        self._vocab, self._vocab_lang = self.registry.get_vocab(
            self.vocab_spec, self.language
        )

    @property
    def vocab(self):
        """The vocabulary of this project, loaded on first access."""
        if self._vocab is None:
            self._initialize_vocab()
        return self._vocab

    @property
    def vocab_lang(self):
        """The vocabulary language returned by the registry, loaded on first
        access."""
        if self._vocab_lang is None:
            self._initialize_vocab()
        return self._vocab_lang

    @property
    def subjects(self):
        """The subjects of the project vocabulary."""
        return self.vocab.subjects

    def _get_info(self, key):
        """Return attribute `key` of the backend.

        Returns None if there is no backend or if an AnnifException is
        raised while reading the value (the exception is logged as a warning).
        """
        try:
            be = self.backend
            if be is not None:
                return getattr(be, key)
        except AnnifException as err:
            logger.warning(err.format_message())
        return None

    @property
    def is_trained(self):
        """The is_trained value reported by the backend, or None if it
        could not be determined."""
        return self._get_info("is_trained")

    @property
    def modification_time(self):
        """The modification_time value reported by the backend, or None if
        it could not be determined."""
        return self._get_info("modification_time")

    def _ensure_trained(self):
        """Check the train state before suggesting.

        Raises NotInitializedException if the backend reports that it is not
        trained. If the state cannot be determined (None), only a warning is
        logged and processing continues.
        """
        # Read the property once: every access queries the backend and may
        # log a warning of its own.
        is_trained = self.is_trained
        if is_trained:
            return
        if is_trained is None:
            logger.warning("Could not get train state information.")
        else:
            raise NotInitializedException("Project is not trained.")

    def suggest(self, text, backend_params=None):
        """Suggest subjects for the given text by passing it to the backend.
        Returns a list of SubjectSuggestion objects ordered by decreasing
        score."""
        self._ensure_trained()
        logger.debug(
            'Suggesting subjects for text "%s..." (len=%d)', text[:20], len(text)
        )
        text = self.transform.transform_text(text)
        hits = self._suggest_with_backend(text, backend_params)
        logger.debug("%d hits from backend", len(hits))
        return hits

    def suggest_batch(self, corpus, backend_params=None):
        """Suggest subjects for the given documents using batches of documents in their
        operations when possible.

        The train state is checked immediately; the returned iterator is
        lazy and calls the backend one minibatch at a time as it is consumed.
        """
        self._ensure_trained()
        corpus = self.transform.transform_corpus(corpus)
        return itertools.chain.from_iterable(
            self._suggest_batch_with_backend(corpus, backend_params)
        )

    def train(self, corpus, backend_params=None, jobs=0):
        """train the project using documents from a metadata source"""
        # The string "cached" is passed through to the backend untransformed.
        if corpus != "cached":
            corpus = self.transform.transform_corpus(corpus)
        beparams = self._select_backend_params(backend_params)
        self.backend.train(corpus, beparams, jobs)

    def learn(self, corpus, backend_params=None):
        """further train the project using documents from a metadata source

        Raises NotSupportedException if the backend is not a learning backend.
        """
        beparams = self._select_backend_params(backend_params)
        corpus = self.transform.transform_corpus(corpus)
        if isinstance(self.backend, annif.backend.backend.AnnifLearningBackend):
            self.backend.learn(corpus, beparams)
        else:
            raise NotSupportedException(
                "Learning not supported by backend", project_id=self.project_id
            )

    def hyperopt(self, corpus, trials, jobs, metric, results_file):
        """optimize the hyperparameters of the project using a validation
        corpus against a given metric

        Raises NotSupportedException if the backend does not support
        hyperparameter optimization.
        """
        if isinstance(self.backend, annif.backend.hyperopt.AnnifHyperoptBackend):
            optimizer = self.backend.get_hp_optimizer(corpus, metric)
            return optimizer.optimize(trials, jobs, results_file)

        raise NotSupportedException(
            "Hyperparameter optimization not supported by backend",
            project_id=self.project_id,
        )

    def dump(self):
        """return this project as a dict"""
        return {
            "project_id": self.project_id,
            "name": self.name,
            "language": self.language,
            "backend": {"backend_id": self.config.get("backend")},
            "is_trained": self.is_trained,
            "modification_time": self.modification_time,
        }

    def remove_model_data(self):
        """remove the data of this project"""
        datadir_path = self._datadir_path
        if os.path.isdir(datadir_path):
            rmtree(datadir_path)
            logger.info("Removed model data for project %s.", self.project_id)
        else:
            logger.warning("No model data to remove for project %s.", self.project_id)
305