Passed
Pull Request — master (#366)
by Osma
03:53
created

annif.backend.omikuji   A

Complexity

Total Complexity 17

Size/Duplication

Total Lines 124
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 17
eloc 101
dl 0
loc 124
rs 10
c 0
b 0
f 0

8 Methods

Rating   Name   Duplication   Size   Complexity  
A OmikujiBackend._uris_to_subj_ids() 0 6 1
A OmikujiBackend._initialize_model() 0 10 3
B OmikujiBackend._create_train_file() 0 24 5
A OmikujiBackend._create_model() 0 14 2
A OmikujiBackend.train() 0 10 2
A OmikujiBackend.initialize() 0 3 1
A OmikujiBackend.default_params() 0 4 1
A OmikujiBackend._suggest() 0 15 2
1
"""Annif backend using the Omikuji classifier"""
2
3
import omikuji
4
import os.path
5
import shutil
6
import annif.util
7
from annif.suggestion import SubjectSuggestion, ListSuggestionResult
8
from annif.exception import NotInitializedException, NotSupportedException
9
from . import backend
10
from . import mixins
11
12
13
class OmikujiBackend(mixins.TfidfVectorizerMixin, backend.AnnifBackend):
    """Omikuji based backend for Annif

    Training vectorizes the corpus with the vectorizer provided by
    TfidfVectorizerMixin, writes the vectors and subject IDs to a training
    file in the data directory, and trains and saves an Omikuji model from
    that file. Suggesting vectorizes the input text and asks the model for
    the best matching subjects.
    """
    name = "omikuji"
    needs_subject_index = True

    # defaults for uninitialized instances
    _model = None

    # names of the artifacts stored under self.datadir
    TRAIN_FILE = 'omikuji-train.txt'
    # removed with shutil.rmtree() before saving, i.e. handled as a directory
    MODEL_FILE = 'omikuji-model'

    DEFAULT_PARAMS = {
        'min_df': 1,
        'cluster_balanced': True,
        'cluster_k': 2,
        'max_depth': 20,
    }

    def default_params(self):
        """Return the generic backend defaults overlaid with the
        Omikuji-specific DEFAULT_PARAMS, as a new dict."""
        params = backend.AnnifBackend.DEFAULT_PARAMS.copy()
        params.update(self.DEFAULT_PARAMS)
        return params

    def _initialize_model(self):
        """Load the saved model from the data directory unless a model is
        already held in memory.

        Raises NotInitializedException if the model path does not exist.
        """
        if self._model is not None:
            return
        path = os.path.join(self.datadir, self.MODEL_FILE)
        self.debug('loading model from {}'.format(path))
        if not os.path.exists(path):
            raise NotInitializedException(
                'model {} not found'.format(path),
                backend_id=self.backend_id)
        self._model = omikuji.Model.load(path)

    def initialize(self):
        """Initialize the vectorizer first, then the model."""
        self.initialize_vectorizer()
        self._initialize_model()

    def _uris_to_subj_ids(self, uris):
        """Look up each URI in the project's subject index and return the
        results as a list of strings; URIs for which the lookup yields None
        are left out."""
        subject_ids = (self.project.subjects.by_uri(uri) for uri in uris)
        return [str(subj_id)
                for subj_id in subject_ids
                if subj_id is not None]

    def _create_train_file(self, veccorpus, corpus):
        """Write the training file to the data directory.

        veccorpus is iterated in parallel with corpus.documents, so both
        must be in the same document order. Documents that end up with no
        subject IDs or no nonzero features are skipped.
        """
        self.info('creating train file')
        path = os.path.join(self.datadir, self.TRAIN_FILE)
        with open(path, 'w', encoding='utf-8') as trainfile:
            # Extreme Classification Repository format header line
            # We don't yet know the number of samples, as some may be
            # skipped, so write a fixed-width placeholder for now
            print('00000000',
                  len(self.vectorizer.vocabulary_),
                  len(self.project.subjects),
                  file=trainfile)
            n_samples = 0
            for doc, vector in zip(corpus.documents, veccorpus):
                subject_ids = self._uris_to_subj_ids(doc.uris)
                feature_values = ['{}:{}'.format(col, vector[row, col])
                                  for row, col in zip(*vector.nonzero())]
                if not subject_ids or not feature_values:
                    continue  # noqa
                print(','.join(subject_ids),
                      ' '.join(feature_values),
                      file=trainfile)
                n_samples += 1
            # replace the number of samples value at the beginning; the
            # zero-padded value has the same width as the placeholder (as
            # long as the count fits in 8 digits), so it overwrites the
            # placeholder in place without disturbing the rest of the header
            trainfile.seek(0)
            print('{:08d}'.format(n_samples), end='', file=trainfile)

    def _create_model(self):
        """Train an Omikuji model on the training file and save it,
        replacing any previously saved model."""
        train_path = os.path.join(self.datadir, self.TRAIN_FILE)
        model_path = os.path.join(self.datadir, self.MODEL_FILE)

        # start from Omikuji's defaults and override the configurable ones
        hyper_param = omikuji.Model.default_hyper_param()
        hyper_param.cluster_balanced = annif.util.boolean(
            self.params['cluster_balanced'])
        hyper_param.cluster_k = int(self.params['cluster_k'])
        hyper_param.max_depth = int(self.params['max_depth'])

        self._model = omikuji.Model.train_on_data(train_path, hyper_param)
        # clear out a previously saved model before writing the new one
        if os.path.exists(model_path):
            shutil.rmtree(model_path)
        self._model.save(model_path)

    def train(self, corpus):
        """Train the vectorizer and the Omikuji model on the given corpus.

        Raises NotSupportedException if the corpus is empty.
        """
        if corpus.is_empty():
            raise NotSupportedException(
                'Cannot train omikuji project with no documents')
        texts = (doc.text for doc in corpus.documents)
        vectorizer_params = {
            'min_df': int(self.params['min_df']),
            'tokenizer': self.project.analyzer.tokenize_words,
        }
        veccorpus = self.create_vectorizer(texts, vectorizer_params)
        self._create_train_file(veccorpus, corpus)
        self._create_model()

    def _suggest(self, text, params):
        """Return a ListSuggestionResult with the model's best subjects for
        the given text.

        The number of suggestions is capped by the 'limit' value in params
        when it is given there, otherwise by the backend's own 'limit'
        parameter.
        """
        self.debug('Suggesting subjects for text "{}..." (len={})'.format(
            text[:20], len(text)))
        vector = self.vectorizer.transform([text])
        feature_values = [(col, vector[row, col])
                          for row, col in zip(*vector.nonzero())]
        if not feature_values:
            # no nonzero features (e.g. no token of the text is in the
            # vocabulary): there is nothing to base a prediction on
            return ListSuggestionResult([], self.project.subjects)
        # honor a per-call override, falling back to the configured limit
        limit = int((params or {}).get('limit', self.params['limit']))
        predictions = self._model.predict(feature_values, top_k=limit)
        results = []
        for subj_id, score in predictions:
            subject = self.project.subjects[subj_id]
            results.append(SubjectSuggestion(
                uri=subject[0],
                label=subject[1],
                score=score))
        return ListSuggestionResult(results, self.project.subjects)
124