Passed
Pull Request — master (#614)
by Osma
02:59
created

annif.project.AnnifProject._initialize_vocab()   A

Complexity

Conditions 2

Size

Total Lines 6
Code Lines 6

Duplication

Lines 6
Ratio 100 %

Importance

Changes 0
Metric Value
cc 2
eloc 6
nop 1
dl 6
loc 6
rs 10
c 0
b 0
f 0
1
"""Project management functionality for Annif"""
2
3
import enum
4
import os.path
5
from shutil import rmtree
6
import annif
7
import annif.transform
8
import annif.analyzer
9
import annif.corpus
10
import annif.suggestion
11
import annif.backend
12
from annif.datadir import DatadirMixin
13
from annif.exception import AnnifException, ConfigurationException, \
14
    NotSupportedException, NotInitializedException
15
16
# module-level shortcut to the logger object exposed by the annif package
logger = annif.logger
17
18
19
class Access(enum.IntEnum):
    """Access levels that can be assigned to a project."""

    # members are listed in increasing numeric order
    private = 1
    hidden = 2
    public = 3
24
25
26 View Code Duplication
class AnnifProject(DatadirMixin):
0 ignored issues
show
Duplication introduced by
This code seems to be duplicated in your project.
Loading history...
27
    """Class representing the configuration of a single Annif project."""
28
29
    # class-level fallbacks; instances overwrite these once the
    # corresponding object has been created / initialize() has run
    _transform = None
    _analyzer = None
    _backend = None
    _vocab = None
    _vocab_lang = None
    initialized = False

    # fallback used when the configuration has no 'access' setting
    DEFAULT_ACCESS = 'public'
39
40
    def __init__(self, project_id, config, datadir, registry):
        """Set up a project from its configuration.

        Stores the identifiers, the raw configuration, the registry and the
        analyzer/transform/vocab specs read from config, then resolves the
        access level via _init_access(). The 'language' setting is read by
        subscription, so a missing key propagates the mapping's lookup
        error.
        """
        DatadirMixin.__init__(self, datadir, 'projects', project_id)

        # identity and raw inputs
        self.project_id = project_id
        self.config = config
        self.registry = registry
        self._base_datadir = datadir

        # settings taken from the configuration
        self.name = config.get('name', project_id)
        self.language = config['language']
        self.analyzer_spec = config.get('analyzer', None)
        self.transform_spec = config.get('transform', 'pass')
        self.vocab_spec = config.get('vocab', None)

        self._init_access()
52
53
    def _init_access(self):
        """Resolve the 'access' setting into an Access member.

        Reads 'access' from the project configuration (falling back to
        DEFAULT_ACCESS) and stores the matching Access member in
        self.access.

        Raises:
            ConfigurationException: if the value is not the name of an
                Access member.
        """
        access = self.config.get('access', self.DEFAULT_ACCESS)
        try:
            # Look up by member name only. getattr() would also accept any
            # other attribute of the enum class (e.g. '__doc__' or 'mro')
            # and silently store a non-Access value.
            self.access = Access[access]
        except KeyError:
            raise ConfigurationException(
                "'{}' is not a valid access setting".format(access),
                project_id=self.project_id) from None
61
62
    def _initialize_analyzer(self):
        """Create the analyzer eagerly, if one is configured, and log it.

        Without an analyzer setting this is a no-op: the analyzer is
        assumed not to be needed.
        """
        if self.analyzer_spec:
            # reading the property triggers creation of the analyzer
            analyzer_obj = self.analyzer
            logger.debug("Project '%s': initialized analyzer: %s",
                         self.project_id,
                         str(analyzer_obj))
69
70
    def _initialize_subjects(self):
        """Load the project's subjects eagerly and log them.

        An AnnifException raised along the way is reported with
        logger.warning and not propagated.
        """
        try:
            loaded = self.subjects
            logger.debug("Project '%s': initialized subjects: %s",
                         self.project_id,
                         str(loaded))
        except AnnifException as exc:
            logger.warning(exc.format_message())
78
79
    def _initialize_backend(self, parallel):
        """Initialize the backend, forwarding the parallel flag to it.

        If the backend property yields a falsy value, only a debug message
        is logged. An AnnifException is reported with logger.warning and
        not propagated.
        """
        logger.debug("Project '%s': initializing backend", self.project_id)
        try:
            if self.backend:
                self.backend.initialize(parallel)
            else:
                logger.debug("Cannot initialize backend: does not exist")
        except AnnifException as exc:
            logger.warning(exc.format_message())
88
89
    def initialize(self, parallel=False):
        """Prepare the project (analyzer, subjects and backend) for use.

        Does nothing when the project has already been initialized. Pass
        parallel=True when the project is going to be used for parallel
        processing; the flag is forwarded to the backend.
        """
        if not self.initialized:
            logger.debug("Initializing project '%s'", self.project_id)

            self._initialize_analyzer()
            self._initialize_subjects()
            self._initialize_backend(parallel)

            self.initialized = True
104
105
    def _suggest_with_backend(self, text, backend_params):
        """Run the backend's suggest() on text and return its hits.

        backend_params maps backend ids to parameter dicts; only the entry
        for this project's backend is passed on (an empty dict when the
        entry is absent or backend_params is None).
        """
        params_by_backend = {} if backend_params is None else backend_params
        backend = self.backend
        own_params = params_by_backend.get(backend.backend_id, {})
        hits = backend.suggest(text, own_params)
        logger.debug(
            'Got %d hits from backend %s',
            len(hits), backend.backend_id)
        return hits
114
115
    @property
    def analyzer(self):
        """The analyzer built from analyzer_spec, created on first access.

        Raises ConfigurationException if no analyzer is configured.
        """
        if self._analyzer is not None:
            return self._analyzer
        if not self.analyzer_spec:
            raise ConfigurationException(
                "analyzer setting is missing", project_id=self.project_id)
        self._analyzer = annif.analyzer.get_analyzer(self.analyzer_spec)
        return self._analyzer
125
126
    @property
    def transform(self):
        """The transform built from transform_spec, created on first access."""
        if self._transform is None:
            spec = self.transform_spec
            self._transform = annif.transform.get_transform(
                spec, project=self)
        return self._transform
132
133
    @property
    def backend(self):
        """The backend instance for this project, created on first access.

        Raises ConfigurationException when the 'backend' setting is missing.
        If creating the backend raises ValueError, a warning is logged and
        None is returned; since nothing is cached in that case, creation is
        attempted again on the next access.
        """
        if self._backend is not None:
            return self._backend

        if 'backend' not in self.config:
            raise ConfigurationException(
                "backend setting is missing", project_id=self.project_id)

        backend_id = self.config['backend']
        try:
            backend_cls = annif.backend.get_backend(backend_id)
            self._backend = backend_cls(
                backend_id, config_params=self.config, project=self)
        except ValueError:
            logger.warning(
                "Could not create backend %s, "
                "make sure you've installed optional dependencies",
                backend_id)
        return self._backend
151
152
    def _initialize_vocab(self):
        """Fetch the vocabulary and its language from the registry.

        Stores the pair returned by registry.get_vocab(vocab_spec, language)
        in _vocab and _vocab_lang. Raises ConfigurationException if no
        vocab is configured.
        """
        spec = self.vocab_spec
        if spec is None:
            raise ConfigurationException("vocab setting is missing",
                                         project_id=self.project_id)
        vocab, vocab_lang = self.registry.get_vocab(spec, self.language)
        self._vocab = vocab
        self._vocab_lang = vocab_lang
158
159
    @property
    def vocab(self):
        """The project vocabulary, loaded on first access."""
        if self._vocab is not None:
            return self._vocab
        self._initialize_vocab()
        return self._vocab
164
165
    @property
    def vocab_lang(self):
        """The vocabulary language, determined on first access."""
        if self._vocab_lang is not None:
            return self._vocab_lang
        self._initialize_vocab()
        return self._vocab_lang
170
171
    @property
    def subjects(self):
        """The subjects attribute of the project vocabulary."""
        vocabulary = self.vocab
        return vocabulary.subjects
174
175
    def _get_info(self, key):
        """Return attribute `key` of the backend, or None if unavailable.

        None is returned when there is no backend, or when an
        AnnifException is raised (which is logged as a warning).
        """
        try:
            backend = self.backend
            if backend is None:
                return None
            return getattr(backend, key)
        except AnnifException as exc:
            logger.warning(exc.format_message())
            return None
183
184
    @property
    def is_trained(self):
        """The backend's is_trained value, or None if it cannot be read."""
        trained = self._get_info('is_trained')
        return trained
187
188
    @property
    def modification_time(self):
        """The backend's modification_time, or None if it cannot be read."""
        mtime = self._get_info('modification_time')
        return mtime
191
192
    def suggest(self, text, backend_params=None):
        """Suggest subjects for the given text by passing it to the backend.

        The text is run through the project transform before it is handed
        to the backend. backend_params optionally maps backend ids to
        parameter dicts.

        Returns a list of SubjectSuggestion objects ordered by decreasing
        score.

        Raises:
            NotInitializedException: if the backend reports that the
                project is not trained. When the training state cannot be
                determined (None), only a warning is logged.
        """
        # Read the property once: every access queries the backend again
        # and may log a warning of its own.
        trained = self.is_trained
        if trained is None:
            logger.warning('Could not get train state information.')
        elif not trained:
            raise NotInitializedException('Project is not trained.')
        logger.debug('Suggesting subjects for text "%s..." (len=%d)',
                     text[:20], len(text))
        text = self.transform.transform_text(text)
        hits = self._suggest_with_backend(text, backend_params)
        logger.debug('%d hits from backend', len(hits))
        return hits
206
207
    def train(self, corpus, backend_params=None, jobs=0):
        """Train the project's backend on a document corpus.

        Unless corpus is the string 'cached', it is first passed through
        the project transform. backend_params optionally maps backend ids
        to parameter dicts; the entry for this project's backend (or an
        empty dict) is forwarded to the backend together with jobs.
        """
        if corpus != 'cached':
            corpus = self.transform.transform_corpus(corpus)
        params_by_backend = {} if backend_params is None else backend_params
        backend = self.backend
        own_params = params_by_backend.get(backend.backend_id, {})
        backend.train(corpus, own_params, jobs)
215
216
    def learn(self, corpus, backend_params=None):
        """Further train the project using documents from a metadata source.

        The corpus is passed through the project transform before it is
        handed to the backend. backend_params optionally maps backend ids
        to parameter dicts.

        Raises:
            NotSupportedException: if the backend is not an
                AnnifLearningBackend (this includes the case where no
                backend could be created).
        """
        backend = self.backend
        # Fail fast, before touching the corpus or the backend's attributes,
        # in the same way hyperopt() checks capability up front.
        if not isinstance(
                backend,
                annif.backend.backend.AnnifLearningBackend):
            raise NotSupportedException("Learning not supported by backend",
                                        project_id=self.project_id)
        if backend_params is None:
            backend_params = {}
        beparams = backend_params.get(backend.backend_id, {})
        corpus = self.transform.transform_corpus(corpus)
        backend.learn(corpus, beparams)
229
230
    def hyperopt(self, corpus, trials, jobs, metric, results_file):
        """Optimize the project's hyperparameters on a validation corpus.

        Obtains an optimizer from the backend for the given corpus and
        metric and returns the result of its optimize() call. Raises
        NotSupportedException if the backend is not an
        AnnifHyperoptBackend.
        """
        if not isinstance(
                self.backend,
                annif.backend.hyperopt.AnnifHyperoptBackend):
            raise NotSupportedException(
                "Hyperparameter optimization not supported "
                "by backend", project_id=self.project_id)

        optimizer = self.backend.get_hp_optimizer(corpus, metric)
        return optimizer.optimize(trials, jobs, results_file)
242
243
    def dump(self):
        """Return a dict describing this project.

        Keys: project_id, name, language, backend (a dict holding the
        configured backend_id), is_trained and modification_time.
        """
        info = {
            'project_id': self.project_id,
            'name': self.name,
            'language': self.language,
            'backend': {'backend_id': self.config.get('backend')},
        }
        # both values are read from the backend via properties
        info['is_trained'] = self.is_trained
        info['modification_time'] = self.modification_time
        return info
252
253
    def remove_model_data(self):
        """Remove the data directory of this project, if it exists.

        Logs at info level when the directory was removed and at warning
        level when there was no directory to remove.
        """
        datadir_path = self._datadir_path
        if os.path.isdir(datadir_path):
            rmtree(datadir_path)
            # pass the id as a logger argument so formatting is deferred,
            # matching the other logging calls in this module
            logger.info('Removed model data for project %s.',
                        self.project_id)
        else:
            logger.warning('No model data to remove for project %s.',
                           self.project_id)
263