Passed
Push — issue703-python-3.11-support ( f59527...05d52a )
by Juho
04:06 queued 14s
created

  A

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 4
dl 0
loc 4
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
"""Annif backend using the Omikuji classifier"""
2
from __future__ import annotations
3
4
import os.path
5
import shutil
6
from typing import TYPE_CHECKING, Any
7
8
import omikuji
9
10
import annif.util
11
from annif.exception import (
12
    NotInitializedException,
13
    NotSupportedException,
14
    OperationFailedException,
15
)
16
from annif.suggestion import SubjectSuggestion, SuggestionBatch
17
18
from . import backend, mixins
19
20
if TYPE_CHECKING:
21
    from scipy.sparse._csr import csr_matrix
22
23
    from annif.corpus.document import DocumentCorpus
24
25
26
class OmikujiBackend(mixins.TfidfVectorizerMixin, backend.AnnifBackend):
    """Omikuji based backend for Annif"""

    name = "omikuji"

    # defaults for uninitialized instances
    _model = None

    # filenames used under the project data directory
    TRAIN_FILE = "omikuji-train.txt"
    MODEL_FILE = "omikuji-model"

    DEFAULT_PARAMETERS = {
        "min_df": 1,
        "ngram": 1,
        "cluster_balanced": True,
        "cluster_k": 2,
        "max_depth": 20,
        "collapse_every_n_layers": 0,
    }

    def _initialize_model(self) -> None:
        """Load the trained Omikuji model from the data directory, if not
        already loaded.

        Raises:
            NotInitializedException: if no saved model exists.
            OperationFailedException: if the saved model is in an old,
                incompatible format (pre-0.56).
        """
        if self._model is not None:
            return
        path = os.path.join(self.datadir, self.MODEL_FILE)
        self.debug("loading model from {}".format(path))
        if not os.path.exists(path):
            raise NotInitializedException(
                "model {} not found".format(path), backend_id=self.backend_id
            )
        try:
            self._model = omikuji.Model.load(path)
        except RuntimeError as err:
            # chain the original error so the underlying cause stays visible
            raise OperationFailedException(
                "Omikuji models trained on Annif versions older than "
                "0.56 cannot be loaded. Please retrain your project."
            ) from err

    def initialize(self, parallel: bool = False) -> None:
        """Initialize the vectorizer and the Omikuji model."""
        self.initialize_vectorizer()
        self._initialize_model()

    def _create_train_file(self, veccorpus: csr_matrix, corpus: DocumentCorpus) -> None:
        """Write the vectorized corpus as an Omikuji train file in the
        Extreme Classification Repository format.

        Documents with no subjects or no features are skipped.
        """
        self.info("creating train file")
        path = os.path.join(self.datadir, self.TRAIN_FILE)
        with open(path, "w", encoding="utf-8") as trainfile:
            # Extreme Classification Repository format header line.
            # We don't yet know the number of samples, as some may be
            # skipped, so write a fixed-width placeholder to patch later.
            print(
                "00000000",
                len(self.vectorizer.vocabulary_),
                len(self.project.subjects),
                file=trainfile,
            )
            n_samples = 0
            for doc, vector in zip(corpus.documents, veccorpus):
                subject_ids = [str(subject_id) for subject_id in doc.subject_set]
                feature_values = [
                    "{}:{}".format(col, vector[row, col])
                    for row, col in zip(*vector.nonzero())
                ]
                if not subject_ids or not feature_values:
                    continue  # noqa
                print(",".join(subject_ids), " ".join(feature_values), file=trainfile)
                n_samples += 1
            # replace the placeholder sample count at the start of the file;
            # the {:08d} format matches the placeholder width exactly
            trainfile.seek(0)
            print("{:08d}".format(n_samples), end="", file=trainfile)

    def _create_model(self, params: dict[str, Any], jobs: int) -> None:
        """Train an Omikuji model from the train file and save it to the
        data directory, replacing any previously saved model.

        :param params: backend hyperparameters (cluster_balanced, cluster_k,
            max_depth, collapse_every_n_layers)
        :param jobs: number of parallel jobs; 0 lets Omikuji decide
        """
        train_path = os.path.join(self.datadir, self.TRAIN_FILE)
        model_path = os.path.join(self.datadir, self.MODEL_FILE)
        hyper_param = omikuji.Model.default_hyper_param()

        hyper_param.cluster_balanced = annif.util.boolean(params["cluster_balanced"])
        hyper_param.cluster_k = int(params["cluster_k"])
        hyper_param.max_depth = int(params["max_depth"])
        hyper_param.collapse_every_n_layers = int(params["collapse_every_n_layers"])

        # jobs == 0 is passed as None ("use default parallelism")
        self._model = omikuji.Model.train_on_data(train_path, hyper_param, jobs or None)
        if os.path.exists(model_path):
            # remove the previous model directory before saving the new one
            shutil.rmtree(model_path)
        self._model.save(model_path)

    def _train(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
        jobs: int = 0,
    ) -> None:
        """Train the backend: vectorize the corpus and write the train file,
        unless cached training data from a previous run is being reused,
        then train the Omikuji model.

        Raises:
            NotSupportedException: if the corpus contains no documents.
        """
        if corpus != "cached":
            if corpus.is_empty():
                raise NotSupportedException(
                    "Cannot train omikuji project with no documents"
                )
            # named "texts" rather than shadowing the builtin "input"
            texts = (doc.text for doc in corpus.documents)
            vecparams = {
                "min_df": int(params["min_df"]),
                "tokenizer": self.project.analyzer.tokenize_words,
                "ngram_range": (1, int(params["ngram"])),
            }
            veccorpus = self.create_vectorizer(texts, vecparams)
            self._create_train_file(veccorpus, corpus)
        else:
            self.info("Reusing cached training data from previous run.")
        self._create_model(params, jobs)

    def _suggest_batch(
        self, texts: list[str], params: dict[str, Any]
    ) -> SuggestionBatch:
        """Return subject suggestions for a batch of texts.

        Texts that vectorize to an all-zero vector yield an empty result.
        """
        vector = self.vectorizer.transform(texts)
        limit = int(params["limit"])

        batch_results = []
        for row in vector:
            if row.nnz == 0:  # All zero vector, empty result
                batch_results.append([])
                continue
            feature_values = [(col, row[0, col]) for col in row.nonzero()[1]]
            results = [
                SubjectSuggestion(subject_id=subj_id, score=score)
                for subj_id, score in self._model.predict(feature_values, top_k=limit)
            ]
            batch_results.append(results)
        return SuggestionBatch.from_sequence(batch_results, self.project.subjects)