Completed
Push — master ( cc6dfc...29e4cb )
by Osma
19s queued 13s
created

annif.backend.svc.SVCBackend._suggest_batch()   A

Complexity

Conditions 2

Size

Total Lines 10
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 9
nop 3
dl 0
loc 10
rs 9.95
c 0
b 0
f 0
1
"""Annif backend using a SVM classifier"""
2
3
import os.path
4
5
import joblib
6
import numpy as np
7
import scipy.special
8
from sklearn.svm import LinearSVC
9
10
import annif.util
11
from annif.exception import NotInitializedException, NotSupportedException
12
from annif.suggestion import ListSuggestionResult, SubjectSuggestion
13
14
from . import backend, mixins
15
16
17
class SVCBackend(mixins.TfidfVectorizerMixin, backend.AnnifBackend):
    """Support vector classifier backend for Annif.

    Trains a scikit-learn ``LinearSVC`` on vectorized document texts, using a
    single subject per training document as its class label, and converts the
    classifier's decision scores into 0..1 suggestion scores with the logistic
    function.
    """

    name = "svc"

    # defaults for uninitialized instances
    _model = None

    MODEL_FILE = "svc-model.gz"

    DEFAULT_PARAMETERS = {"min_df": 1, "ngram": 1}

    def default_params(self):
        """Return the AnnifBackend default parameters overlaid with this
        backend's DEFAULT_PARAMETERS (``min_df`` and ``ngram``)."""
        params = backend.AnnifBackend.DEFAULT_PARAMETERS.copy()
        params.update(self.DEFAULT_PARAMETERS)
        return params

    def _initialize_model(self):
        """Load the trained classifier from MODEL_FILE in the data directory.

        Does nothing if a model is already loaded.

        Raises:
            NotInitializedException: if the model file does not exist.
        """
        if self._model is None:
            path = os.path.join(self.datadir, self.MODEL_FILE)
            self.debug("loading model from {}".format(path))
            if os.path.exists(path):
                self._model = joblib.load(path)
            else:
                raise NotInitializedException(
                    "model {} not found".format(path), backend_id=self.backend_id
                )

    def initialize(self, parallel=False):
        """Load the vectorizer and the classifier so the backend can suggest.

        ``parallel`` is accepted for interface compatibility but not used here.
        """
        self.initialize_vectorizer()
        self._initialize_model()

    def _corpus_to_texts_and_classes(self, corpus):
        """Extract parallel lists of document texts and class labels.

        Documents without subjects are skipped. SVC assigns exactly one class
        per document, so for a document with several subjects only
        ``subject_set[0]`` is used and a warning is logged.

        Returns:
            tuple: ``(texts, classes)``, two lists of equal length.
        """
        texts = []
        classes = []
        for doc in corpus.documents:
            if not doc.subject_set:
                continue  # skip documents with no subjects
            if len(doc.subject_set) > 1:
                # NOTE(review): the message says "random", but the code takes
                # whatever subject_set yields first.
                self.warning(
                    "training on a document with multiple subjects is not "
                    + "supported by SVC; selecting one random subject."
                )
            texts.append(doc.text)
            classes.append(doc.subject_set[0])
        return texts, classes

    def _train_classifier(self, veccorpus, classes):
        """Fit a LinearSVC (default hyperparameters) on the vectorized corpus
        and save it to MODEL_FILE in the data directory using joblib."""
        self.info("creating classifier")
        self._model = LinearSVC()
        self._model.fit(veccorpus, classes)
        annif.util.atomic_save(
            self._model, self.datadir, self.MODEL_FILE, method=joblib.dump
        )

    def _train(self, corpus, params, jobs=0):
        """Train the vectorizer and classifier on ``corpus``.

        Args:
            corpus: the training corpus, or the string "cached" (unsupported).
            params: backend parameters; ``min_df`` and ``ngram`` configure the
                vectorizer.
            jobs: accepted for interface compatibility; not used here.

        Raises:
            NotSupportedException: if asked to reuse cached training data, if
                the corpus is empty, or if no document has any subjects.
        """
        if corpus == "cached":
            raise NotSupportedException(
                "SVC backend does not support reuse of cached training data."
            )
        if corpus.is_empty():
            raise NotSupportedException("Cannot train SVC project with no documents")
        texts, classes = self._corpus_to_texts_and_classes(corpus)
        if not texts:
            # Every document lacked subjects; fail here with a clear message
            # instead of passing empty training data to the vectorizer and
            # classifier.
            raise NotSupportedException(
                "Cannot train SVC project with no documents that have subjects"
            )
        vecparams = {
            "min_df": int(params["min_df"]),
            "tokenizer": self.project.analyzer.tokenize_words,
            "ngram_range": (1, int(params["ngram"])),
        }
        veccorpus = self.create_vectorizer(texts, vecparams)
        self._train_classifier(veccorpus, classes)

    def _scores_to_suggestions(self, scores, params):
        """Convert one row of per-class scores into a suggestion result.

        Args:
            scores: 1-D array with one score per entry of
                ``self._model.classes_``.
            params: backend parameters; ``params["limit"]`` caps the number of
                suggestions returned.

        Returns:
            ListSuggestionResult holding up to ``limit`` subjects, in
            descending score order.
        """
        results = []
        limit = int(params["limit"])
        for class_id in np.argsort(scores)[::-1][:limit]:
            subject_id = self._model.classes_[class_id]
            if subject_id is not None:
                results.append(
                    SubjectSuggestion(subject_id=subject_id, score=scores[class_id])
                )
        return ListSuggestionResult(results)

    def _suggest_batch(self, texts, params):
        """Suggest subjects for each text in ``texts``.

        Returns:
            list: one ListSuggestionResult per input text, in input order. A
            text whose vector has no non-zero features gets an empty result.
        """
        vector = self.vectorizer.transform(texts)
        confidences = self._model.decision_function(vector)
        if confidences.ndim == 1:
            # A two-class LinearSVC returns a single score per sample, where a
            # positive value means classes_[1]. Expand it to one column per
            # class so each row can be indexed by class position.
            confidences = np.column_stack((-confidences, confidences))
        # convert to 0..1 score range using logistic function
        scores_list = scipy.special.expit(confidences)
        return [
            ListSuggestionResult([])
            if row.nnz == 0
            else self._scores_to_suggestions(scores, params)
            for scores, row in zip(scores_list, vector)
        ]
110