Passed
Pull Request — master (#664)
by Juho
03:05
created

annif.project.AnnifProject.hyperopt()   A

Complexity

Conditions 2

Size

Total Lines 10
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 7
nop 6
dl 0
loc 10
rs 10
c 0
b 0
f 0
1
"""Project management functionality for Annif"""
2
3
import enum
4
import os.path
5
from itertools import islice
6
from shutil import rmtree
7
8
import annif
9
import annif.analyzer
10
import annif.backend
11
import annif.corpus
12
import annif.suggestion
13
import annif.transform
14
from annif.datadir import DatadirMixin
15
from annif.exception import (
16
    AnnifException,
17
    ConfigurationException,
18
    NotInitializedException,
19
    NotSupportedException,
20
)
21
22
# Module-level alias for the logger object exposed by the annif package;
# all log calls in this module go through it.
logger = annif.logger
23
24
25
class Access(enum.IntEnum):
    """Access levels that a project can be configured with.

    Members are plain integers (IntEnum), so they compare numerically:
    private < hidden < public.
    """

    private = 1  # lowest numeric level
    hidden = 2
    public = 3  # highest numeric level
31
32
33
class AnnifProject(DatadirMixin):
    """Class representing the configuration of a single Annif project.

    The analyzer, transform, backend and vocabulary are created lazily,
    the first time the corresponding property is accessed.
    """

    # defaults for uninitialized instances; the lazy properties below
    # replace these with real objects on first access
    _transform = None
    _analyzer = None
    _backend = None
    _vocab = None
    _vocab_lang = None
    initialized = False

    # default values for configuration settings
    DEFAULT_ACCESS = "public"
    # number of documents sent to the backend per suggest_batch call
    MINIBATCH_SIZE = 32

    def __init__(self, project_id, config, datadir, registry):
        """Create a project from its configuration settings.

        :param project_id: identifier of the project
        :param config: mapping of configuration settings; "language" is
            required (a missing key raises KeyError), the other settings
            are optional at construction time
        :param datadir: base data directory, handed to DatadirMixin together
            with "projects" and the project id
        :param registry: registry object used to resolve the vocabulary
        :raises ConfigurationException: if the "access" setting is invalid
        """
        DatadirMixin.__init__(self, datadir, "projects", project_id)
        self.project_id = project_id
        self.name = config.get("name", project_id)
        self.language = config["language"]
        self.analyzer_spec = config.get("analyzer", None)
        self.transform_spec = config.get("transform", "pass")
        self.vocab_spec = config.get("vocab", None)
        self.config = config
        self._base_datadir = datadir
        self.registry = registry
        self._init_access()

    def _init_access(self):
        """Parse the "access" setting into an Access member."""
        access = self.config.get("access", self.DEFAULT_ACCESS)
        try:
            # Look up by member name: unlike getattr(), this never returns
            # non-member class attributes such as "__doc__" or "mro", and a
            # non-string/unhashable value is reported as invalid config.
            self.access = Access[access]
        except (KeyError, TypeError):
            raise ConfigurationException(
                "'{}' is not a valid access setting".format(access),
                project_id=self.project_id,
            )

    def _initialize_analyzer(self):
        """Create the analyzer eagerly, if one is configured."""
        if not self.analyzer_spec:
            return  # not configured, so assume it's not needed
        analyzer = self.analyzer
        logger.debug(
            "Project '%s': initialized analyzer: %s", self.project_id, str(analyzer)
        )

    def _initialize_subjects(self):
        """Load the subject vocabulary eagerly; problems are only logged."""
        try:
            subjects = self.subjects
            logger.debug(
                "Project '%s': initialized subjects: %s", self.project_id, str(subjects)
            )
        except AnnifException as err:
            logger.warning(err.format_message())

    def _initialize_backend(self, parallel):
        """Initialize the backend if it can be created; problems are only
        logged."""
        logger.debug("Project '%s': initializing backend", self.project_id)
        try:
            if not self.backend:
                logger.debug("Cannot initialize backend: does not exist")
                return
            self.backend.initialize(parallel)
        except AnnifException as err:
            logger.warning(err.format_message())

    def initialize(self, parallel=False):
        """Initialize this project and its backend so that they are ready to
        be used. If parallel is True, expect that the project will be used
        for parallel processing. Calling this again is a no-op."""

        if self.initialized:
            return

        logger.debug("Initializing project '%s'", self.project_id)

        self._initialize_analyzer()
        self._initialize_subjects()
        self._initialize_backend(parallel)

        self.initialized = True

    def _get_backend_params(self, backend_params):
        """Return the entry of backend_params keyed by this project's
        backend id, or an empty dict if there is none (or backend_params is
        None)."""
        if backend_params is None:
            backend_params = {}
        return backend_params.get(self.backend.backend_id, {})

    def _suggest_with_backend(self, text, backend_params):
        beparams = self._get_backend_params(backend_params)
        hits = self.backend.suggest(text, beparams)
        logger.debug("Got %d hits from backend %s", len(hits), self.backend.backend_id)
        return hits

    def _batched(self, iterable, n):
        """Yield successive lists of at most n items from iterable."""
        # From https://docs.python.org/3/library/itertools.html#itertools-recipes
        it = iter(iterable)
        while True:
            batch = list(islice(it, n))
            if not batch:
                return
            yield batch

    def _suggest_batch_with_backend(self, corpus, backend_params):
        beparams = self._get_backend_params(backend_params)

        hit_sets = []
        for docs_minibatch in self._batched(corpus.documents, self.MINIBATCH_SIZE):
            texts = [doc.text for doc in docs_minibatch]
            hit_sets.extend(self.backend.suggest_batch(texts, beparams))
        logger.debug(
            "Got %d hit sets from backend %s", len(hit_sets), self.backend.backend_id
        )
        return hit_sets

    @property
    def analyzer(self):
        """The analyzer built from the "analyzer" setting (created lazily).

        :raises ConfigurationException: if the analyzer setting is missing
        """
        if self._analyzer is None:
            if self.analyzer_spec:
                self._analyzer = annif.analyzer.get_analyzer(self.analyzer_spec)
            else:
                raise ConfigurationException(
                    "analyzer setting is missing", project_id=self.project_id
                )
        return self._analyzer

    @property
    def transform(self):
        """The transform built from the "transform" setting (created lazily)."""
        if self._transform is None:
            self._transform = annif.transform.get_transform(
                self.transform_spec, project=self
            )
        return self._transform

    @property
    def backend(self):
        """The backend instance (created lazily).

        Returns None, after logging a warning, if the backend class cannot
        be created; creation is retried on the next access.

        :raises ConfigurationException: if the backend setting is missing
        """
        if self._backend is None:
            if "backend" not in self.config:
                raise ConfigurationException(
                    "backend setting is missing", project_id=self.project_id
                )
            backend_id = self.config["backend"]
            try:
                backend_class = annif.backend.get_backend(backend_id)
                self._backend = backend_class(
                    backend_id, config_params=self.config, project=self
                )
            except ValueError:
                logger.warning(
                    "Could not create backend %s, "
                    "make sure you've installed optional dependencies",
                    backend_id,
                )
        return self._backend

    def _initialize_vocab(self):
        """Resolve the vocabulary and its language through the registry."""
        if self.vocab_spec is None:
            raise ConfigurationException(
                "vocab setting is missing", project_id=self.project_id
            )
        self._vocab, self._vocab_lang = self.registry.get_vocab(
            self.vocab_spec, self.language
        )

    @property
    def vocab(self):
        if self._vocab is None:
            self._initialize_vocab()
        return self._vocab

    @property
    def vocab_lang(self):
        if self._vocab_lang is None:
            self._initialize_vocab()
        return self._vocab_lang

    @property
    def subjects(self):
        return self.vocab.subjects

    def _get_info(self, key):
        """Return attribute `key` of the backend, or None if there is no
        backend or an AnnifException is raised while obtaining it."""
        try:
            be = self.backend
            if be is not None:
                return getattr(be, key)
        except AnnifException as err:
            logger.warning(err.format_message())
        return None

    @property
    def is_trained(self):
        return self._get_info("is_trained")

    @property
    def modification_time(self):
        return self._get_info("modification_time")

    def _check_trained(self):
        """Raise NotInitializedException if the backend reports it is not
        trained. If the train state is unknown (None), only log a warning."""
        # Read the state once: each access queries the backend and may log.
        is_trained = self.is_trained
        if is_trained is None:
            logger.warning("Could not get train state information.")
        elif not is_trained:
            raise NotInitializedException("Project is not trained.")

    def suggest(self, text, backend_params=None):
        """Suggest subjects for the given text by passing it to the backend.
        Returns a list of SubjectSuggestion objects ordered by decreasing score.

        :raises NotInitializedException: if the project is not trained
        """
        self._check_trained()
        logger.debug(
            'Suggesting subjects for text "%s..." (len=%d)', text[:20], len(text)
        )
        text = self.transform.transform_text(text)
        hits = self._suggest_with_backend(text, backend_params)
        logger.debug("%d hits from backend", len(hits))
        return hits

    def suggest_batch(self, corpus, backend_params=None):
        """Suggest subjects for the given documents using batches of documents in their
        operations when possible.

        :raises NotInitializedException: if the project is not trained
        """
        self._check_trained()
        corpus = self.transform.transform_corpus(corpus)
        return self._suggest_batch_with_backend(corpus, backend_params)

    def train(self, corpus, backend_params=None, jobs=0):
        """train the project using documents from a metadata source"""
        # the special value "cached" is passed to the backend untransformed
        if corpus != "cached":
            corpus = self.transform.transform_corpus(corpus)
        beparams = self._get_backend_params(backend_params)
        self.backend.train(corpus, beparams, jobs)

    def learn(self, corpus, backend_params=None):
        """further train the project using documents from a metadata source

        :raises NotSupportedException: if the backend cannot learn
        """
        # Import the submodule explicitly rather than relying on another
        # module having loaded annif.backend.backend (this file only
        # imports the annif.backend package).
        from annif.backend.backend import AnnifLearningBackend

        beparams = self._get_backend_params(backend_params)
        corpus = self.transform.transform_corpus(corpus)
        if isinstance(self.backend, AnnifLearningBackend):
            self.backend.learn(corpus, beparams)
        else:
            raise NotSupportedException(
                "Learning not supported by backend", project_id=self.project_id
            )

    def hyperopt(self, corpus, trials, jobs, metric, results_file):
        """optimize the hyperparameters of the project using a validation
        corpus against a given metric

        :raises NotSupportedException: if the backend does not support
            hyperparameter optimization
        """
        # Import the submodule explicitly rather than relying on another
        # module having loaded annif.backend.hyperopt (this file only
        # imports the annif.backend package).
        from annif.backend.hyperopt import AnnifHyperoptBackend

        if isinstance(self.backend, AnnifHyperoptBackend):
            optimizer = self.backend.get_hp_optimizer(corpus, metric)
            return optimizer.optimize(trials, jobs, results_file)

        raise NotSupportedException(
            "Hyperparameter optimization not supported by backend",
            project_id=self.project_id,
        )

    def dump(self):
        """return this project as a dict"""
        return {
            "project_id": self.project_id,
            "name": self.name,
            "language": self.language,
            "backend": {"backend_id": self.config.get("backend")},
            "is_trained": self.is_trained,
            "modification_time": self.modification_time,
        }

    def remove_model_data(self):
        """remove the data of this project (the directory given by
        self._datadir_path, which comes from DatadirMixin)"""
        datadir_path = self._datadir_path
        if os.path.isdir(datadir_path):
            rmtree(datadir_path)
            logger.info("Removed model data for project %s.", self.project_id)
        else:
            logger.warning("No model data to remove for project %s.", self.project_id)
309