annif.backend.fasttext   A
last analyzed

Complexity

Total Complexity 24

Size/Duplication

Total Lines 181
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 132
dl 0
loc 181
rs 10
c 0
b 0
f 0
wmc 24

12 Methods

Rating   Name   Duplication   Size   Complexity  
A FastTextBackend.default_params() 0 5 1
A FastTextBackend.initialize() 0 11 3
A FastTextBackend._load_model() 0 3 1
A FastTextBackend._predict_chunks() 0 10 1
A FastTextBackend._suggest_chunks() 0 22 4
A FastTextBackend._id_to_label() 0 3 1
A FastTextBackend._create_model() 0 14 2
A FastTextBackend._label_to_subject_id() 0 3 1
A FastTextBackend._create_train_file() 0 8 1
A FastTextBackend._normalize_text() 0 2 1
A FastTextBackend._train() 0 15 3
A FastTextBackend._write_train_file() 0 11 5
1
"""Annif backend using the fastText classifier"""
2
3
from __future__ import annotations
4
5
import collections
6
import os.path
7
from typing import TYPE_CHECKING, Any
8
9
import fasttext
10
11
import annif.util
12
from annif.exception import NotInitializedException, NotSupportedException
13
from annif.suggestion import SubjectSuggestion
14
15
from . import backend, mixins
16
17
if TYPE_CHECKING:
18
    from fasttext.FastText import _FastText
19
    from numpy import ndarray
20
21
    from annif.corpus.document import DocumentCorpus
22
23
24
class FastTextBackend(mixins.ChunkingBackend, backend.AnnifBackend):
    """fastText backend for Annif"""

    name = "fasttext"

    # Supported fastText hyperparameters, mapped to the callables used to
    # coerce their (string) configuration values into the types that
    # fasttext.train_supervised expects. Unknown parameters are dropped.
    FASTTEXT_PARAMS = {
        "lr": float,
        "lrUpdateRate": int,
        "dim": int,
        "ws": int,
        "epoch": int,
        "minCount": int,
        "neg": int,
        "wordNgrams": int,
        "loss": str,
        "bucket": int,
        "minn": int,
        "maxn": int,
        "thread": int,
        "t": float,
        "pretrainedVectors": str,
    }

    # Backend-specific defaults; merged on top of the generic backend and
    # chunking defaults in default_params().
    DEFAULT_PARAMETERS = {
        "dim": 100,
        "lr": 0.25,
        "epoch": 5,
        "loss": "hs",
    }

    # File names used inside the project data directory.
    MODEL_FILE = "fasttext-model"
    TRAIN_FILE = "fasttext-train.txt"

    # defaults for uninitialized instances
    _model = None

    def default_params(self) -> dict[str, Any]:
        """Return the effective default parameters for this backend.

        Generic backend defaults are overlaid with chunking defaults and
        finally with the fastText-specific DEFAULT_PARAMETERS, so later
        layers take precedence.
        """
        params = backend.AnnifBackend.DEFAULT_PARAMETERS.copy()
        params.update(mixins.ChunkingBackend.DEFAULT_PARAMETERS)
        params.update(self.DEFAULT_PARAMETERS)
        return params

    @staticmethod
    def _load_model(path: str) -> _FastText:
        """Load a trained fastText model from the given file path."""
        return fasttext.load_model(path)

    def initialize(self, parallel: bool = False) -> None:
        """Load the trained model from the data directory, if not yet loaded.

        Raises NotInitializedException when no model file exists.
        """
        if self._model is None:
            path = os.path.join(self.datadir, self.MODEL_FILE)
            self.debug("loading fastText model from {}".format(path))
            if os.path.exists(path):
                self._model = self._load_model(path)
                self.debug("loaded model {}".format(str(self._model)))
                self.debug("dim: {}".format(self._model.get_dimension()))
            else:
                raise NotInitializedException(
                    "model {} not found".format(path), backend_id=self.backend_id
                )

    @staticmethod
    def _id_to_label(subject_id: int) -> str:
        """Format a numeric subject id as a fastText label string."""
        return "__label__{:d}".format(subject_id)

    @staticmethod
    def _label_to_subject_id(label: str) -> int:
        """Parse the numeric subject id back out of a fastText label string."""
        # inverse of _id_to_label: strip the "__label__" prefix
        labelnum = label.replace("__label__", "")
        return int(labelnum)

    def _write_train_file(self, corpus: DocumentCorpus, filename: str) -> None:
        """Write the corpus to *filename* in fastText training format.

        Each line contains the document's labels followed by its normalized
        text. Documents whose normalized text is empty are skipped; documents
        without any labels are skipped with a warning.
        """
        with open(filename, "w", encoding="utf-8") as trainfile:
            for doc in corpus.documents:
                text = self._normalize_text(doc.text)
                if text == "":
                    continue
                labels = [self._id_to_label(sid) for sid in doc.subject_set]
                if labels:
                    print(" ".join(labels), text, file=trainfile)
                else:
                    self.warning(f'no labels for document "{doc.text}"')

    def _normalize_text(self, text: str) -> str:
        """Tokenize *text* with the project analyzer and rejoin with spaces."""
        return " ".join(self.project.analyzer.tokenize_words(text))

    def _create_train_file(
        self,
        corpus: DocumentCorpus,
    ) -> None:
        """Atomically write the fastText training file into the data dir."""
        self.info("creating fastText training file")

        annif.util.atomic_save(
            corpus, self.datadir, self.TRAIN_FILE, method=self._write_train_file
        )

    def _create_model(self, params: dict[str, Any], jobs: int) -> None:
        """Train a fastText model from the training file and save it.

        Only parameters listed in FASTTEXT_PARAMS are passed on, coerced to
        their expected types; everything else is silently ignored.
        """
        self.info("creating fastText model")
        trainpath = os.path.join(self.datadir, self.TRAIN_FILE)
        modelpath = os.path.join(self.datadir, self.MODEL_FILE)
        params = {
            param: self.FASTTEXT_PARAMS[param](val)
            for param, val in params.items()
            if param in self.FASTTEXT_PARAMS
        }
        if jobs != 0:  # jobs set by user to non-default value
            params["thread"] = jobs
        self.debug("Model parameters: {}".format(params))
        self._model = fasttext.train_supervised(trainpath, **params)
        self._model.save_model(modelpath)

    def _train(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
        jobs: int = 0,
    ) -> None:
        """Train the backend on *corpus*.

        The special corpus value "cached" reuses the training file written by
        a previous run. Raises NotSupportedException on an empty corpus.
        """
        if corpus != "cached":
            if corpus.is_empty():
                raise NotSupportedException(
                    "training backend {} with no documents".format(self.backend_id)
                )
            self._create_train_file(corpus)
        else:
            self.info("Reusing cached training data from previous run.")
        self._create_model(params, jobs)

    def _predict_chunks(
        self, chunktexts: list[str], limit: int
    ) -> tuple[list[list[str]], list[ndarray]]:
        """Return per-chunk (labels, scores) predictions from the model.

        Chunks that normalize to an empty string are filtered out before
        prediction.
        """
        return self._model.predict(
            list(
                filter(
                    None, [self._normalize_text(chunktext) for chunktext in chunktexts]
                )
            ),
            limit,
        )

    def _suggest_chunks(
        self, chunktexts: list[str], params: dict[str, Any]
    ) -> list[SubjectSuggestion]:
        """Aggregate per-chunk predictions into subject suggestions.

        Scores for the same label are summed across chunks, the labels are
        sorted by total score (descending), and the top *limit* labels are
        returned with their score averaged over the number of chunks.
        """
        if not chunktexts:
            # fix: avoid ZeroDivisionError when there are no chunks to score
            return []
        limit = int(params["limit"])
        chunklabels, chunkscores = self._predict_chunks(chunktexts, limit)
        label_scores = collections.defaultdict(float)
        for labels, scores in zip(chunklabels, chunkscores):
            for label, score in zip(labels, scores):
                label_scores[label] += score
        # reverse sort on (score, label) tuples: highest total score first,
        # ties broken by label string (descending)
        best_labels = sorted(
            [(score, label) for label, score in label_scores.items()], reverse=True
        )

        results = []
        for score, label in best_labels[:limit]:
            # NOTE(review): the average divides by the full chunk count even
            # though empty-normalized chunks were filtered out before
            # prediction — preserved as-is; confirm this is intentional.
            results.append(
                SubjectSuggestion(
                    subject_id=self._label_to_subject_id(label),
                    score=score / len(chunktexts),
                )
            )
        return results
181