Completed
Push — master ( 09c408...852165 )
by Osma
16s queued 11s
created

annif.backend.vw_base   A

Complexity

Total Complexity 19

Size/Duplication

Total Lines 113
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 19
eloc 87
dl 0
loc 113
rs 10
c 0
b 0
f 0

9 Methods

Rating   Name   Duplication   Size   Complexity  
A VWBaseBackend._convert_param() 0 14 4
A VWBaseBackend.initialize() 0 15 4
A VWBaseBackend._create_examples() 0 6 1
A VWBaseBackend.train() 0 4 1
A VWBaseBackend._create_model() 0 12 2
A VWBaseBackend._create_train_file() 0 7 1
A VWBaseBackend._write_train_file() 0 5 3
A VWBaseBackend.learn() 0 6 2
A VWBaseBackend._create_params() 0 9 1
1
"""Base class for Vowpal Wabbit based Annif backends"""
2
3
import abc
4
import os
5
from vowpalwabbit import pyvw
6
import annif.util
7
from annif.exception import ConfigurationException
8
from annif.exception import NotInitializedException
9
from . import backend
10
11
12
class VWBaseBackend(backend.AnnifLearningBackend, metaclass=abc.ABCMeta):
    """Base class for Vowpal Wabbit based Annif backends"""

    # Parameters for VW based backends
    # each param specifier is a pair (allowed_values, default_value)
    # where allowed_values is either a type or a list of allowed values
    # and default_value may be None, to let VW decide by itself
    VW_PARAMS = {}  # needs to be specified in subclasses

    MODEL_FILE = 'vw-model'
    TRAIN_FILE = 'vw-train.txt'

    # defaults for uninitialized instances
    _model = None

    def initialize(self):
        """Load the VW model from the data directory unless already loaded.

        :raises NotInitializedException: if the model file does not exist
        """
        if self._model is not None:
            return
        path = os.path.join(self.datadir, self.MODEL_FILE)
        if not os.path.exists(path):
            raise NotInitializedException(
                'model {} not found'.format(path),
                backend_id=self.backend_id)
        self.debug('loading VW model from {}'.format(path))
        params = self._create_params({'i': path, 'quiet': True})
        # don't confuse the model with passes (only meaningful for training)
        params.pop('passes', None)
        self.debug("model parameters: {}".format(params))
        self._model = pyvw.vw(**params)
        self.debug('loaded model {}'.format(str(self._model)))

    def _convert_param(self, param, val):
        """Validate/convert a configured value for the VW parameter `param`.

        If the specifier in VW_PARAMS is a list, `val` must be one of its
        members and is returned unchanged; otherwise the specifier is called
        as a converter (e.g. a type) on `val`.

        :raises ConfigurationException: if the value is not allowed or
            cannot be converted
        """
        pspec, _ = self.VW_PARAMS[param]
        if isinstance(pspec, list):
            if val in pspec:
                return val
            # str() each entry so non-string allowed values don't make
            # join() raise a TypeError instead of the intended error
            raise ConfigurationException(
                "{} is not a valid value for {} (allowed: {})".format(
                    val, param, ', '.join(str(v) for v in pspec)),
                backend_id=self.backend_id)
        try:
            return pspec(val)
        except (ValueError, TypeError) as err:
            raise ConfigurationException(
                "The {} value {} cannot be converted to {}".format(
                    param, val, pspec), backend_id=self.backend_id) from err

    def _create_params(self, params):
        """Build the keyword arguments for a pyvw.vw instance.

        Precedence (later wins): the given `params`, then non-None defaults
        from VW_PARAMS, then values from the backend configuration
        (self.params) for keys that are listed in VW_PARAMS. The input dict
        is not modified.
        """
        params = params.copy()  # don't mutate the original dict
        params.update({param: defaultval
                       for param, (_, defaultval) in self.VW_PARAMS.items()
                       if defaultval is not None})
        params.update({param: self._convert_param(param, val)
                       for param, val in self.params.items()
                       if param in self.VW_PARAMS})
        return params

    @staticmethod
    def _write_train_file(examples, filename):
        """Write each example string on its own line to `filename`."""
        with open(filename, 'w', encoding='utf-8') as trainfile:
            for ex in examples:
                print(ex, file=trainfile)

    def _create_train_file(self, corpus, project):
        """Create the VW training file TRAIN_FILE in the data directory.

        The examples come from _create_examples() and are written with
        _write_train_file() through annif.util.atomic_save.
        """
        self.info('creating VW train file')
        examples = self._create_examples(corpus, project)
        annif.util.atomic_save(examples,
                               self.datadir,
                               self.TRAIN_FILE,
                               method=self._write_train_file)

    @abc.abstractmethod
    def _create_examples(self, corpus, project):
        """This method should be implemented by concrete backends. It
        should return a sequence of strings formatted according to the VW
        input format."""
        pass  # pragma: no cover

    def _create_model(self, project, initial_params=None):
        """Train a VW model from TRAIN_FILE and save it as MODEL_FILE.

        :param initial_params: optional extra VW parameters; the dict is
            copied and never mutated. Defaults to no extra parameters.
        """
        # None sentinel instead of a shared mutable default dict
        initial_params = {} if initial_params is None \
            else initial_params.copy()
        trainpath = os.path.join(self.datadir, self.TRAIN_FILE)
        initial_params['data'] = trainpath
        params = self._create_params(initial_params)
        if params.get('passes', 1) > 1:
            # need a cache file when there are multiple passes
            params.update({'cache': True, 'kill_cache': True})
        self.debug("model parameters: {}".format(params))
        self._model = pyvw.vw(**params)
        modelpath = os.path.join(self.datadir, self.MODEL_FILE)
        self._model.save(modelpath)

    def train(self, corpus, project):
        """Create the training file from `corpus` and train a new model."""
        self.info("creating VW model")
        self._create_train_file(corpus, project)
        self._create_model(project)

    def learn(self, corpus, project):
        """Incrementally update the loaded model with `corpus` and save it.

        :raises NotInitializedException: if no saved model exists
        """
        self.initialize()
        for example in self._create_examples(corpus, project):
            self._model.learn(example)
        modelpath = os.path.join(self.datadir, self.MODEL_FILE)
        self._model.save(modelpath)
113