Passed
Push — issue631-rest-api-language-det... ( 34c253...1cd800 )
by Osma
04:27
created

annif.backend.omikuji.OmikujiBackend._suggest()   A

Complexity

Conditions 3

Size

Total Lines 15
Code Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 13
nop 3
dl 0
loc 15
rs 9.75
c 0
b 0
f 0
1
"""Annif backend using the Omikuji classifier"""
2
3
from __future__ import annotations
4
5
import os.path
6
import shutil
7
from typing import TYPE_CHECKING, Any
8
9
import omikuji
10
11
import annif.util
12
from annif.exception import (
13
    NotInitializedException,
14
    NotSupportedException,
15
    OperationFailedException,
16
)
17
from annif.suggestion import SubjectSuggestion, SuggestionBatch
18
19
from . import backend, mixins
20
21
if TYPE_CHECKING:
22
    from scipy.sparse._csr import csr_matrix
23
24
    from annif.corpus.document import DocumentCorpus
25
26
27
class OmikujiBackend(mixins.TfidfVectorizerMixin, backend.AnnifBackend):
    """Omikuji based backend for Annif"""

    name = "omikuji"

    # defaults for uninitialized instances
    _model = None

    # file names, relative to the backend's data directory
    TRAIN_FILE = "omikuji-train.txt"
    MODEL_FILE = "omikuji-model"

    DEFAULT_PARAMETERS = {
        "min_df": 1,
        "ngram": 1,
        "cluster_balanced": True,
        "cluster_k": 2,
        "max_depth": 20,
        "collapse_every_n_layers": 0,
    }

    def _initialize_model(self) -> None:
        """Load the trained Omikuji model from the data directory.

        Does nothing when a model is already held in ``self._model``.

        Raises:
            NotInitializedException: the model path does not exist.
            OperationFailedException: omikuji raised ``RuntimeError`` while
                loading the model.
        """
        if self._model is not None:
            return  # already loaded

        path = os.path.join(self.datadir, self.MODEL_FILE)
        self.debug(f"loading model from {path}")
        if not os.path.exists(path):
            raise NotInitializedException(
                f"model {path} not found", backend_id=self.backend_id
            )
        try:
            self._model = omikuji.Model.load(path)
        except RuntimeError as err:
            # Chain the loader's error so its details stay in the traceback.
            raise OperationFailedException(
                "Omikuji models trained on Annif versions older than "
                "0.56 cannot be loaded. Please retrain your project."
            ) from err

    def initialize(self, parallel: bool = False) -> None:
        """Initialize the vectorizer and load the model.

        Args:
            parallel: accepted for interface compatibility; not used here.
        """
        self.initialize_vectorizer()
        self._initialize_model()

    def _create_train_file(self, veccorpus: csr_matrix, corpus: DocumentCorpus) -> None:
        """Write the vectorized corpus to the training file.

        The file starts with a header line holding the sample count, the
        vocabulary size and the number of subjects. Each following line holds
        a comma-separated list of subject IDs and then space-separated
        ``column:value`` pairs for the nonzero features of one document.
        Documents with no subjects or no nonzero features are skipped.

        Args:
            veccorpus: vectorized documents, iterated row by row in step
                with ``corpus.documents``.
            corpus: the documents providing the subject sets.
        """
        self.info("creating train file")
        path = os.path.join(self.datadir, self.TRAIN_FILE)
        with open(path, "w", encoding="utf-8") as trainfile:
            # Extreme Classification Repository format header line
            # We don't yet know the number of samples, as some may be skipped,
            # so write a fixed-width placeholder that is patched afterwards.
            print(
                "00000000",
                len(self.vectorizer.vocabulary_),
                len(self.project.subjects),
                file=trainfile,
            )
            n_samples = 0
            for doc, vector in zip(corpus.documents, veccorpus):
                subject_ids = [str(subject_id) for subject_id in doc.subject_set]
                feature_values = [
                    f"{col}:{vector[row, col]}"
                    for row, col in zip(*vector.nonzero())
                ]
                if not subject_ids or not feature_values:
                    continue  # noqa
                print(",".join(subject_ids), " ".join(feature_values), file=trainfile)
                n_samples += 1
            # replace the number of samples value at the beginning; the
            # zero-padded width matches the placeholder written above
            trainfile.seek(0)
            print(f"{n_samples:08d}", end="", file=trainfile)

    def _create_model(self, params: dict[str, Any], jobs: int) -> None:
        """Train a model from the training file and save it.

        Any existing saved model at the model path is removed before saving.

        Args:
            params: backend parameters; reads ``cluster_balanced``,
                ``cluster_k``, ``max_depth`` and ``collapse_every_n_layers``.
            jobs: passed to omikuji's trainer; a falsy value (0) is passed
                on as ``None``.
        """
        train_path = os.path.join(self.datadir, self.TRAIN_FILE)
        model_path = os.path.join(self.datadir, self.MODEL_FILE)

        hyper_param = omikuji.Model.default_hyper_param()
        hyper_param.cluster_balanced = annif.util.boolean(params["cluster_balanced"])
        hyper_param.cluster_k = int(params["cluster_k"])
        hyper_param.max_depth = int(params["max_depth"])
        hyper_param.collapse_every_n_layers = int(params["collapse_every_n_layers"])

        self._model = omikuji.Model.train_on_data(train_path, hyper_param, jobs or None)
        if os.path.exists(model_path):
            shutil.rmtree(model_path)
        self._model.save(model_path)

    def _train(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
        jobs: int = 0,
    ) -> None:
        """Train the backend on a corpus, or on the cached training file.

        Args:
            corpus: the training documents, or the string ``"cached"`` to
                reuse the training file written by a previous run.
            params: backend parameters; reads ``min_df`` and ``ngram`` for
                the vectorizer in addition to the model hyperparameters.
            jobs: forwarded to the model training step.

        Raises:
            NotSupportedException: the given corpus is empty.
        """
        if corpus != "cached":
            if corpus.is_empty():
                raise NotSupportedException(
                    "Cannot train omikuji project with no documents"
                )
            texts = (doc.text for doc in corpus.documents)
            vecparams = {
                "min_df": int(params["min_df"]),
                "tokenizer": self.project.analyzer.tokenize_words,
                "ngram_range": (1, int(params["ngram"])),
            }
            veccorpus = self.create_vectorizer(texts, vecparams)
            self._create_train_file(veccorpus, corpus)
        else:
            self.info("Reusing cached training data from previous run.")
        self._create_model(params, jobs)

    def _suggest_batch(
        self, texts: list[str], params: dict[str, Any]
    ) -> SuggestionBatch:
        """Suggest subjects for a batch of texts.

        Args:
            texts: the input texts, vectorized with the backend's vectorizer.
            params: backend parameters; ``limit`` caps the number of
                suggestions requested per text.

        Returns:
            A SuggestionBatch with one result list per input text; texts
            whose vector has no nonzero features get an empty list.
        """
        vector = self.vectorizer.transform(texts)
        limit = int(params["limit"])

        batch_results = []
        for row in vector:
            if row.nnz == 0:  # All zero vector, empty result
                batch_results.append([])
                continue
            feature_values = [(col, row[0, col]) for col in row.nonzero()[1]]
            batch_results.append(
                [
                    SubjectSuggestion(subject_id=subj_id, score=score)
                    for subj_id, score in self._model.predict(
                        feature_values, top_k=limit
                    )
                ]
            )
        return SuggestionBatch.from_sequence(batch_results, self.project.subjects)
150