Passed
Pull Request — master (#606)
by Osma
02:57
created

  A

Complexity

Conditions 1

Size

Total Lines 6
Code Lines 6

Duplication

Lines 6
Ratio 100 %

Importance

Changes 0
Metric Value
cc 1
eloc 6
nop 2
dl 6
loc 6
rs 10
c 0
b 0
f 0
1
"""Annif backend using the Omikuji classifier"""
2
3
import omikuji
4
import os.path
5
import shutil
6
import annif.util
7
from annif.suggestion import SubjectSuggestion, ListSuggestionResult
8
from annif.exception import NotInitializedException, NotSupportedException, \
9
    OperationFailedException
10
from . import backend
11
from . import mixins
12
13
14 View Code Duplication
class OmikujiBackend(mixins.TfidfVectorizerMixin, backend.AnnifBackend):
    """Omikuji based backend for Annif"""
    name = "omikuji"

    # defaults for uninitialized instances
    _model = None

    # file names, joined onto self.datadir wherever they are used
    TRAIN_FILE = 'omikuji-train.txt'
    MODEL_FILE = 'omikuji-model'

    DEFAULT_PARAMETERS = {
        'min_df': 1,
        'ngram': 1,
        'cluster_balanced': True,
        'cluster_k': 2,
        'max_depth': 20,
        'collapse_every_n_layers': 0,
    }

    def default_params(self):
        """Return the generic backend defaults overlaid with this
        backend's DEFAULT_PARAMETERS."""
        params = backend.AnnifBackend.DEFAULT_PARAMETERS.copy()
        params.update(self.DEFAULT_PARAMETERS)
        return params

    def _initialize_model(self):
        """Load the saved Omikuji model into self._model, unless a model
        is already set.

        Raises NotInitializedException if the model path does not exist,
        and OperationFailedException if Omikuji fails to load it with a
        RuntimeError.
        """
        if self._model is not None:
            return  # already loaded
        path = os.path.join(self.datadir, self.MODEL_FILE)
        self.debug('loading model from {}'.format(path))
        if not os.path.exists(path):
            raise NotInitializedException(
                'model {} not found'.format(path),
                backend_id=self.backend_id)
        try:
            self._model = omikuji.Model.load(path)
        except RuntimeError as err:
            # chain explicitly so the underlying Omikuji error stays
            # attached as the cause of the user-facing exception
            raise OperationFailedException(
                "Omikuji models trained on Annif versions older than "
                "0.56 cannot be loaded. Please retrain your project.") from err

    def initialize(self, parallel=False):
        """Initialize the vectorizer and then the model.

        The parallel argument is accepted but not used here.
        """
        self.initialize_vectorizer()
        self._initialize_model()

    def _create_train_file(self, veccorpus, corpus):
        """Write the training data file TRAIN_FILE into the data directory.

        veccorpus is iterated in lockstep with corpus.documents, one
        vector per document. Documents that have no subjects or no
        nonzero features are left out of the file.
        """
        self.info('creating train file')
        path = os.path.join(self.datadir, self.TRAIN_FILE)
        with open(path, 'w', encoding='utf-8') as trainfile:
            # Extreme Classification Repository format header line
            # We don't yet know the number of samples, as some may be skipped
            # so a fixed-width placeholder is written and overwritten below
            print('00000000',
                  len(self.vectorizer.vocabulary_),
                  len(self.project.subjects),
                  file=trainfile)
            n_samples = 0
            for doc, vector in zip(corpus.documents, veccorpus):
                subject_ids = [str(subject_id)
                               for subject_id in doc.subject_set]
                feature_values = ['{}:{}'.format(col, vector[row, col])
                                  for row, col in zip(*vector.nonzero())]
                if not subject_ids or not feature_values:
                    continue  # noqa
                print(','.join(subject_ids),
                      ' '.join(feature_values),
                      file=trainfile)
                n_samples += 1
            # replace the number of samples value at the beginning
            # (same width as the placeholder, so nothing else is disturbed)
            trainfile.seek(0)
            print('{:08d}'.format(n_samples), end='', file=trainfile)

    def _create_model(self, params, jobs):
        """Train an Omikuji model on the train file and save it.

        Hyperparameters are taken from params; a falsy jobs value is
        passed to Omikuji as None. Any existing saved model is removed
        before the new one is written.
        """
        train_path = os.path.join(self.datadir, self.TRAIN_FILE)
        model_path = os.path.join(self.datadir, self.MODEL_FILE)
        hyper_param = omikuji.Model.default_hyper_param()

        hyper_param.cluster_balanced = annif.util.boolean(
            params['cluster_balanced'])
        hyper_param.cluster_k = int(params['cluster_k'])
        hyper_param.max_depth = int(params['max_depth'])
        hyper_param.collapse_every_n_layers = int(
            params['collapse_every_n_layers'])

        self._model = omikuji.Model.train_on_data(
            train_path, hyper_param, jobs or None)
        if os.path.exists(model_path):
            shutil.rmtree(model_path)
        self._model.save(model_path)

    def _train(self, corpus, params, jobs=0):
        """Train the backend on corpus.

        When corpus is the string 'cached', the existing train file is
        reused; otherwise the vectorizer is created and a new train file
        is written first.

        Raises NotSupportedException if the corpus is empty.
        """
        if corpus != 'cached':
            if corpus.is_empty():
                raise NotSupportedException(
                    'Cannot train omikuji project with no documents')
            texts = (doc.text for doc in corpus.documents)
            vecparams = {'min_df': int(params['min_df']),
                         'tokenizer': self.project.analyzer.tokenize_words,
                         'ngram_range': (1, int(params['ngram']))}
            veccorpus = self.create_vectorizer(texts, vecparams)
            self._create_train_file(veccorpus, corpus)
        else:
            self.info("Reusing cached training data from previous run.")
        self._create_model(params, jobs)

    def _suggest(self, text, params):
        """Return a ListSuggestionResult with at most params['limit']
        suggestions for text.

        The result is empty when the vectorized text has no nonzero
        features.
        """
        self.debug('Suggesting subjects for text "{}..." (len={})'.format(
            text[:20], len(text)))
        vector = self.vectorizer.transform([text])
        if vector.nnz == 0:  # All zero vector, empty result
            return ListSuggestionResult([])
        feature_values = [(col, vector[row, col])
                          for row, col in zip(*vector.nonzero())]
        limit = int(params['limit'])
        results = [
            SubjectSuggestion(subject_id=subj_id, score=score)
            for subj_id, score
            in self._model.predict(feature_values, top_k=limit)]
        return ListSuggestionResult(results)
132