Passed
Pull Request — master (#486)
by Osma
03:35
created

annif.backend.svc.SVCBackend.initialize()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
"""Annif backend using a SVM classifier"""
2
3
import os.path
4
import joblib
5
import numpy as np
6
from sklearn.svm import LinearSVC
7
from sklearn.calibration import CalibratedClassifierCV
8
import annif.util
9
from annif.suggestion import SubjectSuggestion, ListSuggestionResult
10
from annif.exception import NotInitializedException, NotSupportedException
11
from . import backend
12
from . import mixins
13
14
15
class SVCBackend(mixins.TfidfVectorizerMixin, backend.AnnifBackend):
    """Support vector classifier backend for Annif.

    Texts are turned into vectors with ``self.vectorizer`` and classified
    with a scikit-learn ``LinearSVC`` wrapped in ``CalibratedClassifierCV``;
    the wrapper is what provides ``predict_proba()``, which ``LinearSVC``
    itself does not offer. Each training document contributes exactly one
    class: the first of its subject URIs.
    """

    name = "svc"
    needs_subject_index = True

    # defaults for uninitialized instances
    _model = None

    # file name of the persisted classifier inside self.datadir
    MODEL_FILE = 'svc-model.gz'

    DEFAULT_PARAMETERS = {
        'min_df': 1,
    }

    def default_params(self):
        """Return the generic backend defaults overlaid with SVC defaults."""
        params = backend.AnnifBackend.DEFAULT_PARAMETERS.copy()
        params.update(self.DEFAULT_PARAMETERS)
        return params

    def _initialize_model(self):
        """Load the trained classifier from disk unless already in memory.

        Raises:
            NotInitializedException: if the model file does not exist in
                the data directory.
        """
        if self._model is not None:
            return
        path = os.path.join(self.datadir, self.MODEL_FILE)
        self.debug('loading model from {}'.format(path))
        if not os.path.exists(path):
            raise NotInitializedException(
                'model {} not found'.format(path),
                backend_id=self.backend_id)
        self._model = joblib.load(path)

    def initialize(self):
        """Prepare the vectorizer and the classifier for use."""
        self.initialize_vectorizer()
        self._initialize_model()

    def _corpus_to_texts_and_classes(self, corpus):
        """Flatten a document corpus into parallel text and class lists.

        Only the first subject URI of each document is used as its class.
        Documents without any subject URI cannot serve as training
        examples and are skipped instead of aborting the whole training
        run with an IndexError.

        Returns:
            A ``(texts, classes)`` tuple of equally long lists.
        """
        texts = []
        classes = []
        for doc in corpus.documents:
            if not doc.uris:
                self.debug('skipping training document without subjects')
                continue
            texts.append(doc.text)
            # assumes doc.uris is an indexable sequence - TODO confirm
            classes.append(doc.uris[0])
        return texts, classes

    def _train(self, corpus, params):
        """Train the classifier on the corpus and store it on disk.

        Args:
            corpus: the training documents, or the string ``'cached'``
                (which this backend rejects).
            params: backend parameters; ``'min_df'`` is read here.

        Raises:
            NotSupportedException: if cached training data is requested,
                the corpus is empty, or no document has a subject.
        """
        if corpus == 'cached':
            raise NotSupportedException(
                'SVC backend does not support reuse of cached training data.')
        if corpus.is_empty():
            raise NotSupportedException(
                'Cannot train SVC project with no documents')
        texts, classes = self._corpus_to_texts_and_classes(corpus)
        if not texts:
            raise NotSupportedException(
                'Cannot train SVC project: no documents with subjects')
        vecparams = {'min_df': int(params['min_df']),
                     'tokenizer': self.project.analyzer.tokenize_words}
        veccorpus = self.create_vectorizer(texts, vecparams)
        # LinearSVC alone has no predict_proba(); the calibration wrapper
        # adds probability estimates on top of it.
        self._model = CalibratedClassifierCV(LinearSVC())
        self._model.fit(veccorpus, classes)
        annif.util.atomic_save(self._model,
                               self.datadir,
                               self.MODEL_FILE,
                               method=joblib.dump)

    def _suggest(self, text, params):
        """Suggest subjects for a text, ranked by predicted probability.

        Args:
            text: the text to analyze.
            params: backend parameters; ``'limit'`` caps how many of the
                top-ranked classes are considered.

        Returns:
            A ListSuggestionResult; empty when the vectorized text has no
            nonzero features.
        """
        self.debug('Suggesting subjects for text "{}..." (len={})'.format(
            text[:20], len(text)))
        vector = self.vectorizer.transform([text])
        if vector.nnz == 0:  # All zero vector, empty result
            return ListSuggestionResult([])
        predictions = self._model.predict_proba(vector)[0]
        limit = int(params['limit'])
        results = []
        # argsort is ascending, so reverse it to visit the best classes first
        for class_id in np.argsort(predictions)[::-1][:limit]:
            class_uri = self._model.classes_[class_id]
            subject_id = self.project.subjects.by_uri(class_uri)
            if subject_id is None:
                # class URI has no match in the project's subject index
                continue
            uri, label, notation = self.project.subjects[subject_id]
            results.append(SubjectSuggestion(
                uri=uri,
                label=label,
                notation=notation,
                score=predictions[class_id]))
        return ListSuggestionResult(results)
92