Passed
Pull Request — master (#663)
by Juho
05:19
created

annif.project.AnnifProject.transform()   A

Complexity

Conditions 2

Size

Total Lines 7
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 6
nop 1
dl 0
loc 7
rs 10
c 0
b 0
f 0
1
"""Project management functionality for Annif"""
2
3
import enum
4
import itertools
5
import os.path
6
from shutil import rmtree
7
8
import annif
9
import annif.analyzer
10
import annif.backend
11
import annif.corpus
12
import annif.suggestion
13
import annif.transform
14
from annif.datadir import DatadirMixin
15
from annif.exception import (
16
    AnnifException,
17
    ConfigurationException,
18
    NotInitializedException,
19
    NotSupportedException,
20
)
21
22
# Module-level logger: reuses the logger object exposed by the annif package.
logger = annif.logger
23
24
25
class Access(enum.IntEnum):
    """Access levels that can be assigned to a project.

    Being an IntEnum, the members compare as plain integers, so
    ``Access.private < Access.hidden < Access.public``.
    """

    # NOTE: values are significant for ordering comparisons; keep them stable.
    private = 1
    hidden = 2
    public = 3
32
33
class AnnifProject(DatadirMixin):
    """Class representing the configuration of a single Annif project.

    The analyzer, transform, backend and vocabulary are created lazily on
    first access through the corresponding properties and then cached on
    the instance.
    """

    # defaults for uninitialized instances
    _transform = None
    _analyzer = None
    _backend = None
    _vocab = None
    _vocab_lang = None
    initialized = False

    # default values for configuration settings
    DEFAULT_ACCESS = "public"

    def __init__(self, project_id, config, datadir, registry):
        """Create a project from its configuration settings.

        :param project_id: identifier of the project
        :param config: mapping of configuration settings for this project;
            ``language`` is mandatory (looked up with ``[]``), other settings
            fall back to defaults
        :param datadir: base data directory, handed to DatadirMixin together
            with ``"projects"`` and the project id
        :param registry: object providing ``get_vocab(spec, language)``
        :raises ConfigurationException: if the ``access`` setting is invalid
        """
        DatadirMixin.__init__(self, datadir, "projects", project_id)
        self.project_id = project_id
        self.name = config.get("name", project_id)
        self.language = config["language"]
        self.analyzer_spec = config.get("analyzer", None)
        self.transform_spec = config.get("transform", "pass")
        self.vocab_spec = config.get("vocab", None)
        self.config = config
        self._base_datadir = datadir
        self.registry = registry
        self._init_access()

    def _init_access(self):
        """Parse the ``access`` setting into an :class:`Access` member.

        :raises ConfigurationException: if the value is not the name of an
            Access member
        """
        access = self.config.get("access", self.DEFAULT_ACCESS)
        try:
            # Look the name up among enum members only; getattr() would also
            # accept arbitrary class attributes such as "__doc__" or "mro",
            # and raises TypeError (not AttributeError) for non-strings.
            self.access = Access[access]
        except (KeyError, TypeError):
            raise ConfigurationException(
                "'{}' is not a valid access setting".format(access),
                project_id=self.project_id,
            ) from None

    def _initialize_analyzer(self):
        """Eagerly create the analyzer, if one is configured."""
        if not self.analyzer_spec:
            return  # not configured, so assume it's not needed
        analyzer = self.analyzer
        logger.debug(
            "Project '%s': initialized analyzer: %s", self.project_id, str(analyzer)
        )

    def _initialize_subjects(self):
        """Eagerly load the subject vocabulary; failures are only logged."""
        try:
            subjects = self.subjects
            logger.debug(
                "Project '%s': initialized subjects: %s", self.project_id, str(subjects)
            )
        except AnnifException as err:
            logger.warning(err.format_message())

    def _initialize_backend(self, parallel):
        """Initialize the backend if it exists; failures are only logged."""
        logger.debug("Project '%s': initializing backend", self.project_id)
        try:
            if not self.backend:
                logger.debug("Cannot initialize backend: does not exist")
                return
            self.backend.initialize(parallel)
        except AnnifException as err:
            logger.warning(err.format_message())

    def initialize(self, parallel=False):
        """Initialize this project and its backend so that they are ready to
        be used. If parallel is True, expect that the project will be used
        for parallel processing. Calling this again after a completed
        initialization does nothing."""

        if self.initialized:
            return

        logger.debug("Initializing project '%s'", self.project_id)

        self._initialize_analyzer()
        self._initialize_subjects()
        self._initialize_backend(parallel)

        self.initialized = True

    def _get_backend_params(self, backend_params):
        """Select the parameters addressed to this project's backend.

        :param backend_params: mapping of backend id -> parameter dict, or None
        :return: the parameter dict for this backend, or an empty dict
        """
        if backend_params is None:
            backend_params = {}
        return backend_params.get(self.backend.backend_id, {})

    def _check_trained(self):
        """Make sure the backend is trained before suggesting.

        If the training state cannot be determined (``is_trained`` is None),
        a warning is logged and processing continues.

        :raises NotInitializedException: if the backend reports it is untrained
        """
        # Read the property once: every access queries the backend and may
        # log its own warning.
        is_trained = self.is_trained
        if is_trained is None:
            logger.warning("Could not get train state information.")
        elif not is_trained:
            raise NotInitializedException("Project is not trained.")

    def _suggest_with_backend(self, text, backend_params):
        """Pass a single (already transformed) text to the backend."""
        beparams = self._get_backend_params(backend_params)
        hits = self.backend.suggest(text, beparams)
        logger.debug("Got %d hits from backend %s", len(hits), self.backend.backend_id)
        return hits

    def _suggest_batch_with_backend(self, texts, backend_params):
        """Pass a batch of (already transformed) texts to the backend."""
        beparams = self._get_backend_params(backend_params)
        return self.backend.suggest_batch(texts, beparams)

    @property
    def analyzer(self):
        """The analyzer built from the ``analyzer`` setting (cached).

        :raises ConfigurationException: if no analyzer is configured
        """
        if self._analyzer is None:
            if self.analyzer_spec:
                self._analyzer = annif.analyzer.get_analyzer(self.analyzer_spec)
            else:
                raise ConfigurationException(
                    "analyzer setting is missing", project_id=self.project_id
                )
        return self._analyzer

    @property
    def transform(self):
        """The transform built from the ``transform`` setting (cached)."""
        if self._transform is None:
            self._transform = annif.transform.get_transform(
                self.transform_spec, project=self
            )
        return self._transform

    @property
    def backend(self):
        """The backend instance for the ``backend`` setting (cached).

        Returns None (after logging a warning) if the backend class could
        not be created; in that case creation is retried on the next access.

        :raises ConfigurationException: if no backend is configured
        """
        if self._backend is None:
            if "backend" not in self.config:
                raise ConfigurationException(
                    "backend setting is missing", project_id=self.project_id
                )
            backend_id = self.config["backend"]
            try:
                backend_class = annif.backend.get_backend(backend_id)
                self._backend = backend_class(
                    backend_id, config_params=self.config, project=self
                )
            except ValueError:
                logger.warning(
                    "Could not create backend %s, "
                    "make sure you've installed optional dependencies",
                    backend_id,
                )
        return self._backend

    def _initialize_vocab(self):
        """Load the vocabulary and its language from the registry.

        :raises ConfigurationException: if no vocabulary is configured
        """
        if self.vocab_spec is None:
            raise ConfigurationException(
                "vocab setting is missing", project_id=self.project_id
            )
        self._vocab, self._vocab_lang = self.registry.get_vocab(
            self.vocab_spec, self.language
        )

    @property
    def vocab(self):
        """The vocabulary object for this project (loaded on first access)."""
        if self._vocab is None:
            self._initialize_vocab()
        return self._vocab

    @property
    def vocab_lang(self):
        """The vocabulary language returned by the registry (loaded lazily)."""
        if self._vocab_lang is None:
            self._initialize_vocab()
        return self._vocab_lang

    @property
    def subjects(self):
        """The ``subjects`` attribute of the project vocabulary."""
        return self.vocab.subjects

    def _get_info(self, key):
        """Return attribute ``key`` of the backend, or None if unavailable.

        None is returned when the backend could not be created, or when
        accessing it raises an AnnifException (logged as a warning).
        """
        try:
            be = self.backend
            if be is not None:
                return getattr(be, key)
        except AnnifException as err:
            logger.warning(err.format_message())
        return None

    @property
    def is_trained(self):
        """The backend's ``is_trained`` value, or None if unavailable."""
        return self._get_info("is_trained")

    @property
    def modification_time(self):
        """The backend's ``modification_time`` value, or None if unavailable."""
        return self._get_info("modification_time")

    def suggest(self, text, backend_params=None):
        """Suggest subjects for the given text by passing it to the backend.
        Returns a list of SubjectSuggestion objects ordered by decreasing
        score.

        :raises NotInitializedException: if the project is not trained
        """
        self._check_trained()
        logger.debug(
            'Suggesting subjects for text "%s..." (len=%d)', text[:20], len(text)
        )
        text = self.transform.transform_text(text)
        hits = self._suggest_with_backend(text, backend_params)
        logger.debug("%d hits from backend", len(hits))
        return hits

    def suggest_corpus(self, corpus, backend_params=None):
        """Suggest subjects for the given documents corpus in batches of documents.

        Returns a lazy iterator that chains the results of all batches.
        """
        suggestions = (
            self.suggest_batch([doc.text for doc in doc_batch], backend_params)
            for doc_batch in corpus.doc_batches
        )
        return itertools.chain.from_iterable(suggestions)

    def suggest_batch(self, texts, backend_params=None):
        """Suggest subjects for the given documents batch.

        :raises NotInitializedException: if the project is not trained
        """
        self._check_trained()
        texts = [self.transform.transform_text(text) for text in texts]
        return self._suggest_batch_with_backend(texts, backend_params)

    def train(self, corpus, backend_params=None, jobs=0):
        """Train the project using documents from a metadata source.

        The special value ``"cached"`` is passed to the backend without
        transformation (presumably reusing previously prepared training
        data; see the backend implementation).
        """
        if corpus != "cached":
            corpus = self.transform.transform_corpus(corpus)
        beparams = self._get_backend_params(backend_params)
        self.backend.train(corpus, beparams, jobs)

    def learn(self, corpus, backend_params=None):
        """Further train the project using documents from a metadata source.

        :raises NotSupportedException: if the backend does not support
            incremental learning
        """
        # Check support first so that no work is done on an unsupported backend.
        if not isinstance(self.backend, annif.backend.backend.AnnifLearningBackend):
            raise NotSupportedException(
                "Learning not supported by backend", project_id=self.project_id
            )
        beparams = self._get_backend_params(backend_params)
        corpus = self.transform.transform_corpus(corpus)
        self.backend.learn(corpus, beparams)

    def hyperopt(self, corpus, trials, jobs, metric, results_file):
        """Optimize the hyperparameters of the project using a validation
        corpus against a given metric.

        :raises NotSupportedException: if the backend does not support
            hyperparameter optimization
        """
        if isinstance(self.backend, annif.backend.hyperopt.AnnifHyperoptBackend):
            optimizer = self.backend.get_hp_optimizer(corpus, metric)
            return optimizer.optimize(trials, jobs, results_file)

        raise NotSupportedException(
            "Hyperparameter optimization not supported by backend",
            project_id=self.project_id,
        )

    def dump(self):
        """Return this project as a dict."""
        return {
            "project_id": self.project_id,
            "name": self.name,
            "language": self.language,
            "backend": {"backend_id": self.config.get("backend")},
            "is_trained": self.is_trained,
            "modification_time": self.modification_time,
        }

    def remove_model_data(self):
        """Remove the data directory of this project, if it exists.

        ``_datadir_path`` is provided by DatadirMixin.
        """
        datadir_path = self._datadir_path
        if os.path.isdir(datadir_path):
            rmtree(datadir_path)
            logger.info("Removed model data for project %s.", self.project_id)
        else:
            logger.warning(
                "No model data to remove for project %s.", self.project_id
            )
298