Passed
Pull Request — master (#496)
by
unknown
03:24
created

annif.project.AnnifProject.transformer()   A

Complexity

Conditions 3

Size

Total Lines 9
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 8
nop 1
dl 0
loc 9
rs 10
c 0
b 0
f 0
1
"""Project management functionality for Annif"""
2
3
import enum
4
import os.path
5
from shutil import rmtree
6
import annif
7
import annif.transformer
8
import annif.analyzer
9
import annif.corpus
10
import annif.suggestion
11
import annif.backend
12
import annif.vocab
13
from annif.datadir import DatadirMixin
14
from annif.exception import AnnifException, ConfigurationException, \
15
    NotSupportedException, NotInitializedException
16
17
# Reuse the logger object exposed by the top-level annif package rather than
# creating a separate module-level logger.
logger = annif.logger
18
19
20
class Access(enum.IntEnum):
    """Access levels that a project can be configured with.

    As an IntEnum the levels compare as plain integers:
    private (1) < hidden (2) < public (3).
    """

    # auto() numbers IntEnum members 1, 2, 3 in definition order, so the
    # order of these lines fixes the numeric values - do not reorder.
    private = enum.auto()
    hidden = enum.auto()
    public = enum.auto()
25
26
27
class AnnifProject(DatadirMixin):
    """Class representing the configuration of a single Annif project."""

    # defaults for uninitialized instances; the properties below create the
    # real objects lazily on first access and cache them in these slots
    _transformer = None
    _analyzer = None
    _backend = None
    _vocab = None
    initialized = False  # set to True at the end of initialize()

    # default values for configuration settings
    DEFAULT_ACCESS = 'public'

    def __init__(self, project_id, config, datadir, registry):
        """Create a project from its configuration settings.

        :param project_id: identifier of the project
        :param config: mapping of configuration settings; ``language`` is
            required (looked up with ``[]``), the rest are optional here
        :param datadir: base data directory; passed to DatadirMixin along
            with ``'projects'`` and the project id, and kept for the vocab
        :param registry: registry object, stored as ``self.registry``
        :raises ConfigurationException: if the ``access`` setting is invalid
        """
        DatadirMixin.__init__(self, datadir, 'projects', project_id)
        self.project_id = project_id
        self.name = config.get('name', project_id)
        self.language = config['language']
        self.analyzer_spec = config.get('analyzer', None)
        self.transformer_spec = config.get('input_transform', None)
        self.vocab_id = config.get('vocab', None)
        self.config = config
        self._base_datadir = datadir
        self.registry = registry
        self._init_access()

    def _init_access(self):
        """Set ``self.access`` from the ``access`` setting.

        Falls back to DEFAULT_ACCESS when the setting is absent.

        :raises ConfigurationException: if the value is not the name of an
            Access member
        """
        access = self.config.get('access', self.DEFAULT_ACCESS)
        try:
            # Access[...] only accepts member names; getattr() would also
            # have accepted unrelated class attributes such as 'mro'.
            self.access = Access[access]
        except KeyError:
            raise ConfigurationException(
                "'{}' is not a valid access setting".format(access),
                project_id=self.project_id)

    def _initialize_analyzer(self):
        """Eagerly create the analyzer, if one is configured."""
        if not self.analyzer_spec:
            return  # not configured, so assume it's not needed
        analyzer = self.analyzer
        logger.debug("Project '%s': initialized analyzer: %s",
                     self.project_id,
                     str(analyzer))

    def _initialize_transformer(self):
        """Eagerly create the input transformer, if one is configured."""
        # TODO: Is this needed?
        if not self.transformer_spec:
            return  # not configured, so assume it's not needed
        transformer = self.transformer
        logger.debug("Project '%s': initialized input-transform: %s",
                     self.project_id,
                     str(transformer))

    def _initialize_subjects(self):
        """Eagerly load the subjects; AnnifExceptions are logged as
        warnings instead of being raised."""
        try:
            subjects = self.subjects
            logger.debug("Project '%s': initialized subjects: %s",
                         self.project_id,
                         str(subjects))
        except AnnifException as err:
            logger.warning(err.format_message())

    def _initialize_backend(self):
        """Create the backend and call its initialize(); AnnifExceptions
        are logged as warnings instead of being raised."""
        logger.debug("Project '%s': initializing backend", self.project_id)
        try:
            if not self.backend:
                logger.debug("Cannot initialize backend: does not exist")
                return
            self.backend.initialize()
        except AnnifException as err:
            logger.warning(err.format_message())

    def initialize(self):
        """initialize this project and its backend so that they are ready to
        be used"""

        logger.debug("Initializing project '%s'", self.project_id)

        self._initialize_analyzer()
        self._initialize_transformer()  # TODO: Is this needed?
        self._initialize_subjects()
        self._initialize_backend()

        self.initialized = True

    def _suggest_with_backend(self, text, backend_params):
        """Run the backend's suggest() on ``text``.

        ``backend_params`` maps backend ids to parameter dicts; only the
        entry for this project's backend is passed on (empty if absent).
        """
        if backend_params is None:
            backend_params = {}
        beparams = backend_params.get(self.backend.backend_id, {})
        hits = self.backend.suggest(text, beparams)
        logger.debug(
            'Got %d hits from backend %s',
            len(hits), self.backend.backend_id)
        return hits

    @property
    def analyzer(self):
        """The analyzer for this project, created lazily from the
        ``analyzer`` setting.

        :raises ConfigurationException: if no analyzer is configured
        """
        if self._analyzer is None:
            if self.analyzer_spec:
                self._analyzer = annif.analyzer.get_analyzer(
                    self.analyzer_spec)
            else:
                raise ConfigurationException(
                    "analyzer setting is missing (and needed by the backend)",
                    project_id=self.project_id)
        return self._analyzer

    @property
    def transformer(self):
        """The input transformer for this project, or None when no
        ``input_transform`` setting is configured.

        Created lazily on first access and cached afterwards.
        """
        if self._transformer is None and self.transformer_spec:
            self._transformer = annif.transformer.get_transformer(
                self.transformer_spec, project=self)
        return self._transformer

    @property
    def backend(self):
        """The backend object for this project, created lazily.

        Returns None (after logging a warning) if annif.backend.get_backend
        raises ValueError for the configured backend id.

        :raises ConfigurationException: if no ``backend`` setting exists
        """
        if self._backend is None:
            if 'backend' not in self.config:
                raise ConfigurationException(
                    "backend setting is missing", project_id=self.project_id)
            backend_id = self.config['backend']
            try:
                backend_class = annif.backend.get_backend(backend_id)
                self._backend = backend_class(
                    backend_id, config_params=self.config,
                    project=self)
            except ValueError:
                logger.warning(
                    "Could not create backend %s, "
                    "make sure you've installed optional dependencies",
                    backend_id)
        return self._backend

    @property
    def vocab(self):
        """The AnnifVocabulary for this project, created lazily.

        :raises ConfigurationException: if no ``vocab`` setting exists
        """
        if self._vocab is None:
            if self.vocab_id is None:
                raise ConfigurationException("vocab setting is missing",
                                             project_id=self.project_id)
            self._vocab = annif.vocab.AnnifVocabulary(self.vocab_id,
                                                      self._base_datadir,
                                                      self.language)
        return self._vocab

    @property
    def subjects(self):
        """The subjects of this project's vocabulary."""
        return self.vocab.subjects

    def _get_info(self, key):
        """Return attribute ``key`` of the backend.

        Returns None if the backend could not be created, or if an
        AnnifException was raised (it is logged as a warning).
        """
        try:
            be = self.backend
            if be is not None:
                return getattr(be, key)
        except AnnifException as err:
            logger.warning(err.format_message())
        return None

    @property
    def is_trained(self):
        """The backend's ``is_trained`` value, or None if unavailable."""
        return self._get_info('is_trained')

    @property
    def modification_time(self):
        """The backend's ``modification_time``, or None if unavailable."""
        return self._get_info('modification_time')

    def suggest(self, text, backend_params=None):
        """Suggest subjects for the given text by passing it to the backend.
        Returns a list of SubjectSuggestion objects ordered by decreasing
        score.

        If the train state cannot be determined, a warning is logged and
        suggestion is attempted anyway.

        :raises NotInitializedException: if the backend reports that it is
            not trained
        """
        # Read the train state only once: each access goes through the
        # backend property and may log its own warning.
        is_trained = self.is_trained
        if is_trained is None:
            logger.warning('Could not get train state information.')
        elif not is_trained:
            raise NotInitializedException('Project is not trained.')
        logger.debug('Suggesting subjects for text "%s..." (len=%d)',
                     text[:20], len(text))
        if self.transformer is not None:
            text = self.transformer.transform_text(text)
        return self._suggest_with_backend(text, backend_params)

    def train(self, corpus, backend_params=None):
        """train the project using documents from a metadata source

        ``corpus`` may also be the string ``'cached'``, which is passed to
        the backend unchanged (presumably asking it to reuse previously
        prepared training data - the backend decides).
        """
        if corpus != 'cached':
            corpus.set_subject_index(self.subjects)
            # Only a real corpus can be transformed; the 'cached' sentinel
            # string must reach the backend untouched.
            if self.transformer is not None:
                corpus = self.transformer.transform_corpus(corpus)
        if backend_params is None:
            backend_params = {}
        beparams = backend_params.get(self.backend.backend_id, {})
        self.backend.train(corpus, beparams)

    def learn(self, corpus, backend_params=None):
        """further train the project using documents from a metadata source

        :raises NotSupportedException: if the backend is not an
            AnnifLearningBackend
        """
        corpus.set_subject_index(self.subjects)
        if backend_params is None:
            backend_params = {}
        beparams = backend_params.get(self.backend.backend_id, {})
        if self.transformer is not None:
            corpus = self.transformer.transform_corpus(corpus)
        # NOTE(review): relies on the annif.backend.backend submodule having
        # been imported (not done in this module) - confirm.
        if isinstance(
                self.backend,
                annif.backend.backend.AnnifLearningBackend):
            self.backend.learn(corpus, beparams)
        else:
            raise NotSupportedException("Learning not supported by backend",
                                        project_id=self.project_id)

    def hyperopt(self, corpus, trials, jobs, metric, results_file):
        """optimize the hyperparameters of the project using a validation
        corpus against a given metric

        Returns whatever the backend's optimizer returns from optimize().

        :raises NotSupportedException: if the backend is not an
            AnnifHyperoptBackend
        """
        # NOTE(review): relies on the annif.backend.hyperopt submodule having
        # been imported (not done in this module) - confirm.
        if isinstance(
                self.backend,
                annif.backend.hyperopt.AnnifHyperoptBackend):
            optimizer = self.backend.get_hp_optimizer(corpus, metric)
            return optimizer.optimize(trials, jobs, results_file)

        raise NotSupportedException(
            "Hyperparameter optimization not supported "
            "by backend", project_id=self.project_id)

    def dump(self):
        """return this project as a dict

        Keys: project_id, name, language, backend (a dict holding the
        configured backend_id), is_trained and modification_time.
        """
        return {'project_id': self.project_id,
                'name': self.name,
                'language': self.language,
                'backend': {'backend_id': self.config.get('backend')},
                'is_trained': self.is_trained,
                'modification_time': self.modification_time
                }

    def remove_model_data(self):
        """remove the data of this project

        Deletes the project's data directory tree if it exists; otherwise
        only logs a warning.
        """
        datadir_path = self._datadir_path
        if os.path.isdir(datadir_path):
            rmtree(datadir_path)
            logger.info('Removed model data for project %s.',
                        self.project_id)
        else:
            logger.warning('No model data to remove for project %s.',
                           self.project_id)
270