Passed
Pull Request — master (#496)
by
unknown
02:00
created

annif.project.AnnifProject.transformer()   A

Complexity

Conditions 3

Size

Total Lines 9
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 8
nop 1
dl 0
loc 9
rs 10
c 0
b 0
f 0
1
"""Project management functionality for Annif"""
2
3
import enum
4
import os.path
5
from shutil import rmtree
6
import annif
7
import annif.transformer
8
import annif.analyzer
9
import annif.corpus
10
import annif.suggestion
11
import annif.backend
12
import annif.vocab
13
from annif.datadir import DatadirMixin
14
from annif.exception import AnnifException, ConfigurationException, \
15
    NotSupportedException, NotInitializedException
16
17
logger = annif.logger
18
19
20
class Access(enum.IntEnum):
    """Access levels that can be assigned to a project.

    The members are integers, so two levels can be compared with the
    ordinary ordering operators (private < hidden < public).
    """

    private = 1  # lowest value
    hidden = 2
    public = 3  # highest value
25
26
27
class AnnifProject(DatadirMixin):
28
    """Class representing the configuration of a single Annif project."""
29
30
    # defaults for uninitialized instances
31
    _transformer = None
32
    _analyzer = None
33
    _backend = None
34
    _vocab = None
35
    initialized = False
36
37
    # default values for configuration settings
38
    DEFAULT_ACCESS = 'public'
39
40
    def __init__(self, project_id, config, datadir, registry):
        """Set up the project from its configuration.

        project_id -- identifier of the project; also the default name
        config -- mapping of configuration settings; must contain 'language'
        datadir -- base data directory, also handed to DatadirMixin
        registry -- stored unchanged on the instance
        """
        DatadirMixin.__init__(self, datadir, 'projects', project_id)
        self.project_id = project_id
        self.name = config.get('name', project_id)
        self.language = config['language']
        # optional settings: the attribute is None when the key is absent
        for attribute, setting in (('analyzer_spec', 'analyzer'),
                                   ('transform_spec', 'input_transform'),
                                   ('vocab_id', 'vocab')):
            setattr(self, attribute, config.get(setting, None))
        self.config = config
        self._base_datadir = datadir
        self.registry = registry
        # needs self.config and self.project_id, so it has to come last
        self._init_access()
52
53
    def _init_access(self):
54
        access = self.config.get('access', self.DEFAULT_ACCESS)
55
        try:
56
            self.access = getattr(Access, access)
57
        except AttributeError:
58
            raise ConfigurationException(
59
                "'{}' is not a valid access setting".format(access),
60
                project_id=self.project_id)
61
62
    def _initialize_analyzer(self):
63
        if not self.analyzer_spec:
64
            return  # not configured, so assume it's not needed
65
        analyzer = self.analyzer
66
        logger.debug("Project '%s': initialized analyzer: %s",
67
                     self.project_id,
68
                     str(analyzer))
69
70
    def _initialize_subjects(self):
        """Access the subjects of the project and log them.

        An AnnifException raised while doing so is logged as a warning
        instead of being propagated.
        """
        try:
            found = self.subjects
            logger.debug(
                "Project '%s': initialized subjects: %s",
                self.project_id, str(found))
        except AnnifException as exc:
            logger.warning(exc.format_message())
78
79
    def _initialize_backend(self):
        """Call initialize() on the backend, provided one is available.

        An AnnifException raised along the way is logged as a warning
        instead of being propagated.
        """
        logger.debug("Project '%s': initializing backend", self.project_id)
        try:
            if self.backend:
                self.backend.initialize()
            else:
                logger.debug("Cannot initialize backend: does not exist")
        except AnnifException as exc:
            logger.warning(exc.format_message())
88
89
    def initialize(self):
        """Prepare the project for use: set up the analyzer, the subjects
        and the backend, then flag the project as initialized."""

        logger.debug("Initializing project '%s'", self.project_id)

        # the steps run in this fixed order
        for step in (self._initialize_analyzer,
                     self._initialize_subjects,
                     self._initialize_backend):
            step()

        self.initialized = True
100
101
    def _suggest_with_backend(self, text, backend_params):
        """Hand the text to the backend and return the hits it produces.

        backend_params maps backend ids to parameter dicts (None is treated
        as empty); only the entry for this project's backend is passed on.
        """
        params_by_backend = {} if backend_params is None else backend_params
        own_params = params_by_backend.get(self.backend.backend_id, {})
        hits = self.backend.suggest(text, own_params)
        logger.debug('Got %d hits from backend %s',
                     len(hits), self.backend.backend_id)
        return hits
110
111
    @property
112
    def analyzer(self):
113
        if self._analyzer is None:
114
            if self.analyzer_spec:
115
                self._analyzer = annif.analyzer.get_analyzer(
116
                    self.analyzer_spec)
117
            else:
118
                raise ConfigurationException(
119
                    "analyzer setting is missing (and needed by the backend)",
120
                    project_id=self.project_id)
121
        return self._analyzer
122
123
    @property
124
    def transformer(self):
125
        if self._transformer is None:
126
            if self.transform_spec:
127
                self._transformer = annif.transformer.get_transform(
128
                    self.transform_spec, project=self)
129
            else:
130
                self._transformer = annif.transformer.IdentityTransform(self)
131
        return self._transformer
132
133
    @property
134
    def backend(self):
135
        if self._backend is None:
136
            if 'backend' not in self.config:
137
                raise ConfigurationException(
138
                    "backend setting is missing", project_id=self.project_id)
139
            backend_id = self.config['backend']
140
            try:
141
                backend_class = annif.backend.get_backend(backend_id)
142
                self._backend = backend_class(
143
                    backend_id, config_params=self.config,
144
                    project=self)
145
            except ValueError:
146
                logger.warning(
147
                    "Could not create backend %s, "
148
                    "make sure you've installed optional dependencies",
149
                    backend_id)
150
        return self._backend
151
152
    @property
153
    def vocab(self):
154
        if self._vocab is None:
155
            if self.vocab_id is None:
156
                raise ConfigurationException("vocab setting is missing",
157
                                             project_id=self.project_id)
158
            self._vocab = annif.vocab.AnnifVocabulary(self.vocab_id,
159
                                                      self._base_datadir,
160
                                                      self.language)
161
        return self._vocab
162
163
    @property
164
    def subjects(self):
165
        return self.vocab.subjects
166
167
    def _get_info(self, key):
168
        try:
169
            be = self.backend
170
            if be is not None:
171
                return getattr(be, key)
172
        except AnnifException as err:
173
            logger.warning(err.format_message())
174
            return None
175
176
    @property
177
    def is_trained(self):
178
        return self._get_info('is_trained')
179
180
    @property
181
    def modification_time(self):
182
        return self._get_info('modification_time')
183
184
    def suggest(self, text, backend_params=None):
185
        """Suggest subjects the given text by passing it to the backend. Returns a
186
        list of SubjectSuggestion objects ordered by decreasing score."""
187
        if not self.is_trained:
188
            if self.is_trained is None:
189
                logger.warning('Could not get train state information.')
190
            else:
191
                raise NotInitializedException('Project is not trained.')
192
        logger.debug('Suggesting subjects for text "%s..." (len=%d)',
193
                     text[:20], len(text))
194
        text = self.transformer.transform_text(text)
195
        hits = self._suggest_with_backend(text, backend_params)
196
        logger.debug('%d hits from backend', len(hits))
197
        return hits
198
199
    def train(self, corpus, backend_params=None):
200
        """train the project using documents from a metadata source"""
201
        if corpus != 'cached':
202
            corpus.set_subject_index(self.subjects)
203
            corpus = self.transformer.transform_corpus(corpus)
204
        if backend_params is None:
205
            backend_params = {}
206
        beparams = backend_params.get(self.backend.backend_id, {})
207
        self.backend.train(corpus, beparams)
208
209
    def learn(self, corpus, backend_params=None):
        """Further train the backend of the project on a document corpus.

        The corpus gets the project subjects as its subject index and is
        passed through the transformer. backend_params optionally maps
        backend ids to parameter dicts. Raises NotSupportedException when
        the backend is not an AnnifLearningBackend.
        """
        corpus.set_subject_index(self.subjects)
        params_by_backend = {} if backend_params is None else backend_params
        own_params = params_by_backend.get(self.backend.backend_id, {})
        corpus = self.transformer.transform_corpus(corpus)
        learning_backend = annif.backend.backend.AnnifLearningBackend
        if not isinstance(self.backend, learning_backend):
            raise NotSupportedException("Learning not supported by backend",
                                        project_id=self.project_id)
        self.backend.learn(corpus, own_params)
223
224
    def hyperopt(self, corpus, trials, jobs, metric, results_file):
        """Optimize the hyperparameters of the project using a validation
        corpus against a given metric.

        Returns whatever the optimizer obtained from the backend returns
        from optimize(trials, jobs, results_file). Raises
        NotSupportedException when the backend is not an
        AnnifHyperoptBackend.
        """
        hyperopt_backend = annif.backend.hyperopt.AnnifHyperoptBackend
        if not isinstance(self.backend, hyperopt_backend):
            raise NotSupportedException(
                "Hyperparameter optimization not supported by backend",
                project_id=self.project_id)

        optimizer = self.backend.get_hp_optimizer(corpus, metric)
        return optimizer.optimize(trials, jobs, results_file)
236
237
    def dump(self):
238
        """return this project as a dict"""
239
        return {'project_id': self.project_id,
240
                'name': self.name,
241
                'language': self.language,
242
                'backend': {'backend_id': self.config.get('backend')},
243
                'is_trained': self.is_trained,
244
                'modification_time': self.modification_time
245
                }
246
247
    def remove_model_data(self):
        """Remove the data directory of this project, if it exists.

        Logs an info message after removing the directory, or a warning
        when there is no directory to remove.
        """
        datadir_path = self._datadir_path
        if os.path.isdir(datadir_path):
            rmtree(datadir_path)
            # Pass the id as a logging argument, like the other log calls in
            # this file, so the message is only built if it is emitted.
            logger.info('Removed model data for project %s.',
                        self.project_id)
        else:
            logger.warning('No model data to remove for project %s.',
                           self.project_id)
257