Passed
Pull Request — master (#486)
by Osma
02:04
created

annif.backend.svc.SVCBackend.initialize()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
"""Annif backend using a SVM classifier"""
2
3
import os.path
4
import joblib
5
import numpy as np
6
import scipy.special
7
from sklearn.svm import LinearSVC
8
import annif.util
9
from annif.suggestion import SubjectSuggestion, ListSuggestionResult
10
from annif.exception import NotInitializedException, NotSupportedException
11
from . import backend
12
from . import mixins
13
14
15
class SVCBackend(mixins.TfidfVectorizerMixin, backend.AnnifBackend):
    """Support vector classifier backend for Annif.

    Trains a scikit-learn LinearSVC on TF-IDF vectors of the document texts,
    using a single subject URI per training document as its class label.
    """
    name = "svc"
    needs_subject_index = True

    # defaults for uninitialized instances
    _model = None

    MODEL_FILE = 'svc-model.gz'

    DEFAULT_PARAMETERS = {
        'min_df': 1,
        'ngram': 1
    }

    def default_params(self):
        """Return the generic backend defaults overridden by SVC defaults."""
        params = backend.AnnifBackend.DEFAULT_PARAMETERS.copy()
        params.update(self.DEFAULT_PARAMETERS)
        return params

    def _initialize_model(self):
        """Load the trained model from the data directory if not yet loaded.

        Raises:
            NotInitializedException: if the model file does not exist.
        """
        if self._model is None:
            path = os.path.join(self.datadir, self.MODEL_FILE)
            self.debug('loading model from {}'.format(path))
            if os.path.exists(path):
                self._model = joblib.load(path)
            else:
                raise NotInitializedException(
                    'model {} not found'.format(path),
                    backend_id=self.backend_id)

    def initialize(self):
        """Load both the TF-IDF vectorizer and the SVC model."""
        self.initialize_vectorizer()
        self._initialize_model()

    @staticmethod
    def _corpus_to_texts_and_classes(corpus):
        """Collect parallel lists of document texts and class labels.

        SVC is a single-label classifier, so only one subject URI per
        document is used as its class. Documents without any subject URI
        cannot be labelled and are skipped.

        Returns:
            A ``(texts, classes)`` tuple of equally long lists.
        """
        texts = []
        classes = []
        for doc in corpus.documents:
            # Iterate instead of indexing with doc.uris[0]: the URI
            # collection may be a set, which is not indexable. For a list
            # this still picks the first URI; for a set the choice is
            # arbitrary.
            uri = next(iter(doc.uris), None)
            if uri is None:
                continue  # no subjects -> no class label to train on
            texts.append(doc.text)
            classes.append(uri)
        return texts, classes

    def _train(self, corpus, params):
        """Fit the vectorizer and LinearSVC on the corpus and save the model.

        Raises:
            NotSupportedException: for cached training data, an empty
                corpus, or a corpus where no document has a subject.
        """
        if corpus == 'cached':
            raise NotSupportedException(
                'SVC backend does not support reuse of cached training data.')
        if corpus.is_empty():
            raise NotSupportedException(
                'Cannot train SVC project with no documents')
        texts, classes = self._corpus_to_texts_and_classes(corpus)
        if not texts:
            raise NotSupportedException(
                'Cannot train SVC project: no documents have subjects')
        vecparams = {'min_df': int(params['min_df']),
                     'tokenizer': self.project.analyzer.tokenize_words,
                     'ngram_range': (1, int(params['ngram']))}
        veccorpus = self.create_vectorizer(texts, vecparams)
        self._model = LinearSVC()
        self._model.fit(veccorpus, classes)
        annif.util.atomic_save(self._model,
                               self.datadir,
                               self.MODEL_FILE,
                               method=joblib.dump)

    @staticmethod
    def _confidences_to_scores(confidences):
        """Map one sample's decision values to per-class scores in 0..1.

        For a model with more than two classes, ``confidences`` holds one
        value per class. For a two-class model, scikit-learn returns a
        single signed distance that favours ``classes_[1]`` when positive.
        That value is expanded to ``[-d, d]`` so that index i still
        corresponds to ``classes_[i]``.

        The logistic function is applied to each value independently, so
        the resulting scores need not sum to 1.
        """
        confidences = np.asarray(confidences, dtype=float)
        if confidences.ndim == 0:
            confidences = np.array([-confidences, confidences])
        return scipy.special.expit(confidences)

    def _suggest(self, text, params):
        """Suggest up to ``params['limit']`` subjects for a single text.

        Classes whose URI is not found in the project's subject index are
        skipped.
        """
        self.debug('Suggesting subjects for text "{}..." (len={})'.format(
            text[:20], len(text)))
        vector = self.vectorizer.transform([text])
        if vector.nnz == 0:  # All zero vector, empty result
            return ListSuggestionResult([])
        confidences = self._model.decision_function(vector)[0]
        # convert to 0..1 score range using logistic function
        scores = self._confidences_to_scores(confidences)
        limit = int(params['limit'])
        results = []
        # highest-scoring classes first
        for class_id in np.argsort(scores)[::-1][:limit]:
            class_uri = self._model.classes_[class_id]
            subject_id = self.project.subjects.by_uri(class_uri)
            if subject_id is not None:
                uri, label, notation = self.project.subjects[subject_id]
                results.append(SubjectSuggestion(
                    uri=uri,
                    label=label,
                    notation=notation,
                    score=scores[class_id]))
        return ListSuggestionResult(results)
95