Passed
Pull Request — master (#486)
by Osma
03:11
created

annif.backend.svc.SVCBackend._train()   A

Complexity

Conditions 4

Size

Total Lines 23
Code Lines 23

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 23
nop 3
dl 0
loc 23
rs 9.328
c 0
b 0
f 0
1
"""Annif backend using an SVM classifier"""
2
3
import os.path
4
import joblib
5
import numpy as np
6
import scipy.special
7
from sklearn.svm import LinearSVC
8
import annif.util
9
from annif.suggestion import SubjectSuggestion, ListSuggestionResult
10
from annif.exception import NotInitializedException, NotSupportedException
11
from . import backend
12
from . import mixins
13
14
15
class SVCBackend(mixins.TfidfVectorizerMixin, backend.AnnifBackend):
    """Support vector classifier backend for Annif.

    Trains a scikit-learn ``LinearSVC`` on document vectors built through
    the TfidfVectorizerMixin. Each training document contributes exactly
    one class label (a subject URI), so this is a single-label classifier.
    At suggestion time the decision-function values are mapped into the
    0..1 range with the logistic function and used as scores.
    """

    name = "svc"
    needs_subject_index = True

    # defaults for uninitialized instances
    _model = None

    MODEL_FILE = 'svc-model.gz'

    DEFAULT_PARAMETERS = {
        'min_df': 1,
        'ngram': 1
    }

    def default_params(self):
        """Return the generic backend defaults overlaid with SVC defaults."""
        params = backend.AnnifBackend.DEFAULT_PARAMETERS.copy()
        params.update(self.DEFAULT_PARAMETERS)
        return params

    def _initialize_model(self):
        """Load the trained classifier from the data directory, once.

        Raises:
            NotInitializedException: if the model file does not exist.
        """
        if self._model is None:
            path = os.path.join(self.datadir, self.MODEL_FILE)
            self.debug('loading model from {}'.format(path))
            if os.path.exists(path):
                self._model = joblib.load(path)
            else:
                raise NotInitializedException(
                    'model {} not found'.format(path),
                    backend_id=self.backend_id)

    def initialize(self):
        """Load both the vectorizer and the classifier model."""
        self.initialize_vectorizer()
        self._initialize_model()

    def _corpus_to_texts_and_classes(self, corpus):
        """Collect training texts and one subject URI per document.

        Documents without any subject URI are skipped, since they cannot
        provide a class label. For documents with several subjects only
        one URI is used; ``next(iter(...))`` works whether ``doc.uris`` is
        a sequence or a set (for a set, which URI is picked is arbitrary).

        Returns:
            A ``(texts, classes)`` pair of equally long lists.
        """
        texts = []
        classes = []
        for doc in corpus.documents:
            if not doc.uris:
                continue  # no label available for this document
            texts.append(doc.text)
            classes.append(next(iter(doc.uris)))
        return texts, classes

    def _train(self, corpus, params):
        """Fit the vectorizer and the LinearSVC model and save the model.

        Args:
            corpus: a document corpus, or the string 'cached' (unsupported).
            params: backend parameters; reads 'min_df' and 'ngram'.

        Raises:
            NotSupportedException: for cached training data, an empty
                corpus, or a corpus in which no document has a subject.
        """
        if corpus == 'cached':
            raise NotSupportedException(
                'SVC backend does not support reuse of cached training data.')
        if corpus.is_empty():
            raise NotSupportedException(
                'Cannot train SVC project with no documents')
        texts, classes = self._corpus_to_texts_and_classes(corpus)
        if not texts:
            raise NotSupportedException(
                'Cannot train SVC project with no documents that have '
                'subjects')
        vecparams = {'min_df': int(params['min_df']),
                     'tokenizer': self.project.analyzer.tokenize_words,
                     'ngram_range': (1, int(params['ngram']))}
        veccorpus = self.create_vectorizer(texts, vecparams)
        self.info('creating classifier')
        self._model = LinearSVC()
        self._model.fit(veccorpus, classes)
        annif.util.atomic_save(self._model,
                               self.datadir,
                               self.MODEL_FILE,
                               method=joblib.dump)

    def _suggest(self, text, params):
        """Return up to ``params['limit']`` subject suggestions for *text*.

        Classes whose URI is not found in the project's subject index are
        dropped, so fewer than ``limit`` results may be returned.
        """
        self.debug('Suggesting subjects for text "{}..." (len={})'.format(
            text[:20], len(text)))
        vector = self.vectorizer.transform([text])
        if vector.nnz == 0:  # All zero vector, empty result
            return ListSuggestionResult([])
        confidences = self._model.decision_function(vector)[0]
        if np.ndim(confidences) == 0:
            # With exactly two classes, decision_function yields a single
            # value per sample, positive in favour of classes_[1]. Expand it
            # into one confidence per class so it aligns with classes_.
            confidences = np.array([-confidences, confidences])
        # convert to 0..1 score range using logistic function
        scores = scipy.special.expit(confidences)
        results = []
        limit = int(params['limit'])
        for class_id in np.argsort(scores)[::-1][:limit]:
            class_uri = self._model.classes_[class_id]
            subject_id = self.project.subjects.by_uri(class_uri)
            if subject_id is not None:
                uri, label, notation = self.project.subjects[subject_id]
                results.append(SubjectSuggestion(
                    uri=uri,
                    label=label,
                    notation=notation,
                    score=scores[class_id]))
        return ListSuggestionResult(results)
96