Passed
Pull Request — master (#663)
by Juho
03:15
created

annif.project.AnnifProject.learn()   A

Complexity

Conditions 3

Size

Total Lines 11
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 9
nop 3
dl 0
loc 11
rs 9.95
c 0
b 0
f 0
1
"""Project management functionality for Annif"""
2
3
import enum
4
import os.path
5
from shutil import rmtree
6
7
import annif
8
import annif.analyzer
9
import annif.backend
10
import annif.corpus
11
import annif.suggestion
12
import annif.transform
13
from annif.datadir import DatadirMixin
14
from annif.exception import (
15
    AnnifException,
16
    ConfigurationException,
17
    NotInitializedException,
18
    NotSupportedException,
19
)
20
21
# Module-level logger: reuses the logger object exposed by the annif package.
logger = annif.logger
22
23
24
class Access(enum.IntEnum):
    """Access levels that can be configured for a project.

    Being an IntEnum, the members compare as integers; their values grow in
    declaration order (private=1, hidden=2, public=3).
    """

    private = enum.auto()  # 1
    hidden = enum.auto()  # 2
    public = enum.auto()  # 3
30
31
32
class AnnifProject(DatadirMixin):
    """Class representing the configuration of a single Annif project.

    The analyzer, transform, backend and vocabulary are all created lazily
    from the project configuration the first time they are accessed.
    """

    # defaults for uninitialized instances
    _transform = None
    _analyzer = None
    _backend = None
    _vocab = None
    _vocab_lang = None
    initialized = False

    # default values for configuration settings
    DEFAULT_ACCESS = "public"

    def __init__(self, project_id, config, datadir, registry):
        DatadirMixin.__init__(self, datadir, "projects", project_id)
        self.project_id = project_id
        self.name = config.get("name", project_id)
        self.language = config["language"]
        self.analyzer_spec = config.get("analyzer", None)
        self.transform_spec = config.get("transform", "pass")
        self.vocab_spec = config.get("vocab", None)
        self.config = config
        self._base_datadir = datadir
        self.registry = registry
        self._init_access()

    def _init_access(self):
        """Set ``self.access`` from the ``access`` configuration setting.

        Raises ConfigurationException if the setting is not the name of an
        Access member.
        """
        access = self.config.get("access", self.DEFAULT_ACCESS)
        try:
            # Look up by member name only: getattr() would also accept any
            # other attribute of the enum class (e.g. "mro") as a "level".
            self.access = Access[access]
        except KeyError:
            raise ConfigurationException(
                "'{}' is not a valid access setting".format(access),
                project_id=self.project_id,
            )

    def _initialize_analyzer(self):
        """Create the analyzer eagerly, if one is configured."""
        if not self.analyzer_spec:
            return  # not configured, so assume it's not needed
        analyzer = self.analyzer
        logger.debug(
            "Project '%s': initialized analyzer: %s", self.project_id, str(analyzer)
        )

    def _initialize_subjects(self):
        """Load the subject vocabulary eagerly; failures are only logged."""
        try:
            subjects = self.subjects
            logger.debug(
                "Project '%s': initialized subjects: %s", self.project_id, str(subjects)
            )
        except AnnifException as err:
            logger.warning(err.format_message())

    def _initialize_backend(self, parallel):
        """Initialize the backend, if it exists; failures are only logged."""
        logger.debug("Project '%s': initializing backend", self.project_id)
        try:
            if not self.backend:
                logger.debug("Cannot initialize backend: does not exist")
                return
            self.backend.initialize(parallel)
        except AnnifException as err:
            logger.warning(err.format_message())

    def initialize(self, parallel=False):
        """Initialize this project and its backend so that they are ready to
        be used. If parallel is True, expect that the project will be used
        for parallel processing."""

        if self.initialized:
            return

        logger.debug("Initializing project '%s'", self.project_id)

        self._initialize_analyzer()
        self._initialize_subjects()
        self._initialize_backend(parallel)

        self.initialized = True

    def _backend_params_for(self, backend_params):
        """Return the entry of *backend_params* keyed by this project's backend
        id, or an empty dict when *backend_params* is None or has no entry."""
        if backend_params is None:
            backend_params = {}
        return backend_params.get(self.backend.backend_id, {})

    def _check_trained(self):
        """Raise NotInitializedException if the backend reports it is not
        trained. If the train state cannot be determined, only warn."""
        # Evaluate the property once: each access goes through the backend.
        is_trained = self.is_trained
        if is_trained is None:
            logger.warning("Could not get train state information.")
        elif not is_trained:
            raise NotInitializedException("Project is not trained.")

    def _suggest_with_backend(self, text, backend_params):
        """Pass already-transformed *text* to the backend and return its hits."""
        beparams = self._backend_params_for(backend_params)
        hits = self.backend.suggest(text, beparams)
        logger.debug("Got %d hits from backend %s", len(hits), self.backend.backend_id)
        return hits

    def _suggest_batch_with_backend(self, corpus, transform, backend_params):
        """Pass *corpus* and *transform* to the backend's batch suggestion."""
        beparams = self._backend_params_for(backend_params)
        return self.backend.suggest_batch(corpus, transform, beparams)

    @property
    def analyzer(self):
        """The analyzer built from ``analyzer_spec``, created on first access.

        Raises ConfigurationException if no analyzer is configured."""
        if self._analyzer is None:
            if self.analyzer_spec:
                self._analyzer = annif.analyzer.get_analyzer(self.analyzer_spec)
            else:
                raise ConfigurationException(
                    "analyzer setting is missing", project_id=self.project_id
                )
        return self._analyzer

    @property
    def transform(self):
        """The transform built from ``transform_spec``, created on first access."""
        if self._transform is None:
            self._transform = annif.transform.get_transform(
                self.transform_spec, project=self
            )
        return self._transform

    @property
    def backend(self):
        """The backend named by the ``backend`` setting, created on first access.

        Returns None (after logging a warning) if the backend class could not
        be obtained; creation is then retried on the next access. Raises
        ConfigurationException if the ``backend`` setting is missing."""
        if self._backend is None:
            if "backend" not in self.config:
                raise ConfigurationException(
                    "backend setting is missing", project_id=self.project_id
                )
            backend_id = self.config["backend"]
            try:
                backend_class = annif.backend.get_backend(backend_id)
                self._backend = backend_class(
                    backend_id, config_params=self.config, project=self
                )
            except ValueError:
                logger.warning(
                    "Could not create backend %s, "
                    "make sure you've installed optional dependencies",
                    backend_id,
                )
        return self._backend

    def _initialize_vocab(self):
        """Fetch the vocabulary and its language from the registry.

        Raises ConfigurationException if no vocab is configured."""
        if self.vocab_spec is None:
            raise ConfigurationException(
                "vocab setting is missing", project_id=self.project_id
            )
        self._vocab, self._vocab_lang = self.registry.get_vocab(
            self.vocab_spec, self.language
        )

    @property
    def vocab(self):
        """The project vocabulary, loaded on first access."""
        if self._vocab is None:
            self._initialize_vocab()
        return self._vocab

    @property
    def vocab_lang(self):
        """The vocabulary language returned by the registry, loaded on first
        access."""
        if self._vocab_lang is None:
            self._initialize_vocab()
        return self._vocab_lang

    @property
    def subjects(self):
        """The subjects of the project vocabulary."""
        return self.vocab.subjects

    def _get_info(self, key):
        """Return attribute *key* of the backend, or None if there is no
        backend or obtaining it raised an AnnifException (which is logged)."""
        try:
            be = self.backend
            if be is not None:
                return getattr(be, key)
        except AnnifException as err:
            logger.warning(err.format_message())
        return None

    @property
    def is_trained(self):
        """The backend's ``is_trained`` value, or None if unavailable."""
        return self._get_info("is_trained")

    @property
    def modification_time(self):
        """The backend's ``modification_time`` value, or None if unavailable."""
        return self._get_info("modification_time")

    def suggest(self, text, backend_params=None):
        """Suggest subjects the given text by passing it to the backend. Returns a
        list of SubjectSuggestion objects ordered by decreasing score.

        Raises NotInitializedException if the project is known to be untrained.
        """
        self._check_trained()
        logger.debug(
            'Suggesting subjects for text "%s..." (len=%d)', text[:20], len(text)
        )
        text = self.transform.transform_text(text)
        return self._suggest_with_backend(text, backend_params)

    def suggest_batch(self, corpus, backend_params=None):
        """Suggest subjects for the given documents using batches of documents in their
        operations when possible.

        Raises NotInitializedException if the project is known to be untrained.
        """
        self._check_trained()
        return self._suggest_batch_with_backend(corpus, self.transform, backend_params)

    def train(self, corpus, backend_params=None, jobs=0):
        """train the project using documents from a metadata source

        The special value ``"cached"`` is passed to the backend untransformed.
        """
        if corpus != "cached":
            corpus = self.transform.transform_corpus(corpus)
        beparams = self._backend_params_for(backend_params)
        self.backend.train(corpus, beparams, jobs)

    def learn(self, corpus, backend_params=None):
        """further train the project using documents from a metadata source

        Raises NotSupportedException if the backend does not support learning.
        """
        backend = self.backend
        # Check support first: no point transforming a corpus we can't use,
        # and a missing backend (None) must not crash on ``.backend_id``.
        if not isinstance(backend, annif.backend.backend.AnnifLearningBackend):
            raise NotSupportedException(
                "Learning not supported by backend", project_id=self.project_id
            )
        beparams = self._backend_params_for(backend_params)
        corpus = self.transform.transform_corpus(corpus)
        backend.learn(corpus, beparams)

    def hyperopt(self, corpus, trials, jobs, metric, results_file):
        """optimize the hyperparameters of the project using a validation
        corpus against a given metric

        Raises NotSupportedException if the backend does not support it.
        """
        if isinstance(self.backend, annif.backend.hyperopt.AnnifHyperoptBackend):
            optimizer = self.backend.get_hp_optimizer(corpus, metric)
            return optimizer.optimize(trials, jobs, results_file)

        raise NotSupportedException(
            "Hyperparameter optimization not supported by backend",
            project_id=self.project_id,
        )

    def dump(self):
        """return this project as a dict"""
        return {
            "project_id": self.project_id,
            "name": self.name,
            "language": self.language,
            "backend": {"backend_id": self.config.get("backend")},
            "is_trained": self.is_trained,
            "modification_time": self.modification_time,
        }

    def remove_model_data(self):
        """remove the data of this project"""
        datadir_path = self._datadir_path
        if os.path.isdir(datadir_path):
            rmtree(datadir_path)
            logger.info("Removed model data for project %s.", self.project_id)
        else:
            logger.warning("No model data to remove for project %s.", self.project_id)
295