FastTextBackend._normalize_text() — grade: A
Last analyzed

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 2
dl 0
loc 2
rs 10
c 0
b 0
f 0
cc 1
nop 2
1
"""Annif backend using the fastText classifier"""
2
3
from __future__ import annotations
4
5
import collections
6
import os.path
7
from typing import TYPE_CHECKING, Any
8
9
import fasttext
10
11
import annif.util
12
from annif.exception import NotInitializedException, NotSupportedException
13
from annif.suggestion import SubjectSuggestion
14
15
from . import backend, mixins
16
17
if TYPE_CHECKING:
18
    from fasttext.FastText import _FastText
19
    from numpy import ndarray
20
21
    from annif.corpus.document import DocumentCorpus
22
23
24
class FastTextBackend(mixins.ChunkingBackend, backend.AnnifBackend):
    """fastText backend for Annif"""

    name = "fasttext"

    # Hyperparameters accepted by fastText, mapped to the callable used to
    # coerce each value from the string-valued project configuration.
    FASTTEXT_PARAMS = {
        "lr": float,
        "lrUpdateRate": int,
        "dim": int,
        "ws": int,
        "epoch": int,
        "minCount": int,
        "neg": int,
        "wordNgrams": int,
        "loss": str,
        "bucket": int,
        "minn": int,
        "maxn": int,
        "thread": int,
        "t": float,
        "pretrainedVectors": str,
    }

    # Backend-specific defaults; merged on top of the base-class and
    # chunking-mixin defaults in default_params().
    DEFAULT_PARAMETERS = {
        "dim": 100,
        "lr": 0.25,
        "epoch": 5,
        "loss": "hs",
    }

    # File names used inside the backend's data directory.
    MODEL_FILE = "fasttext-model"
    TRAIN_FILE = "fasttext-train.txt"

    # defaults for uninitialized instances
    _model = None

    def default_params(self) -> dict[str, Any]:
        """Return the effective default parameters: base backend defaults,
        overlaid with the chunking mixin's defaults, overlaid with this
        backend's fastText-specific defaults."""
        params = backend.AnnifBackend.DEFAULT_PARAMETERS.copy()
        params.update(mixins.ChunkingBackend.DEFAULT_PARAMETERS)
        params.update(self.DEFAULT_PARAMETERS)
        return params

    @staticmethod
    def _load_model(path: str) -> _FastText:
        """Load a fastText model from *path*.

        Temporarily monkey-patches fasttext.FastText.eprint to silence a
        spurious warning emitted by load_model; see
        https://github.com/facebookresearch/fastText/issues/1067
        """
        orig_eprint = fasttext.FastText.eprint
        fasttext.FastText.eprint = lambda x: None
        try:
            return fasttext.load_model(path)
        finally:
            # restore the original eprint even if load_model raises,
            # so a failed load does not leave the patch in place
            fasttext.FastText.eprint = orig_eprint

    def initialize(self, parallel: bool = False) -> None:
        """Load the trained model from the data directory, if not already
        loaded.

        Raises NotInitializedException when no model file exists."""
        if self._model is not None:
            return
        path = os.path.join(self.datadir, self.MODEL_FILE)
        self.debug(f"loading fastText model from {path}")
        if not os.path.exists(path):
            raise NotInitializedException(
                f"model {path} not found", backend_id=self.backend_id
            )
        self._model = self._load_model(path)
        self.debug(f"loaded model {str(self._model)}")
        self.debug(f"dim: {self._model.get_dimension()}")

    @staticmethod
    def _id_to_label(subject_id: int) -> str:
        """Convert a numeric subject id into a fastText label string."""
        return f"__label__{subject_id:d}"

    @staticmethod
    def _label_to_subject_id(label: str) -> int:
        """Convert a fastText label string back into a numeric subject id."""
        return int(label.replace("__label__", ""))

    def _write_train_file(self, corpus: DocumentCorpus, filename: str) -> None:
        """Write *corpus* to *filename* in fastText supervised format: one
        line per document, labels first, then the normalized text.

        Documents whose text normalizes to the empty string are skipped
        silently; documents without any subjects are skipped with a warning.
        """
        with open(filename, "w", encoding="utf-8") as trainfile:
            for doc in corpus.documents:
                text = self._normalize_text(doc.text)
                if not text:
                    continue
                labels = [self._id_to_label(sid) for sid in doc.subject_set]
                if labels:
                    print(" ".join(labels), text, file=trainfile)
                else:
                    self.warning(f'no labels for document "{doc.text}"')

    def _normalize_text(self, text: str) -> str:
        """Tokenize *text* with the project analyzer and rejoin the tokens
        with single spaces."""
        return " ".join(self.project.analyzer.tokenize_words(text))

    def _create_train_file(
        self,
        corpus: DocumentCorpus,
    ) -> None:
        """Atomically write the fastText training file into the data
        directory."""
        self.info("creating fastText training file")

        annif.util.atomic_save(
            corpus, self.datadir, self.TRAIN_FILE, method=self._write_train_file
        )

    def _create_model(self, params: dict[str, Any], jobs: int) -> None:
        """Train a supervised fastText model from the training file and save
        it into the data directory.

        Only parameters listed in FASTTEXT_PARAMS are passed on, each coerced
        to its expected type; *jobs* overrides the thread count when non-zero.
        """
        self.info("creating fastText model")
        trainpath = os.path.join(self.datadir, self.TRAIN_FILE)
        modelpath = os.path.join(self.datadir, self.MODEL_FILE)
        # keep only recognized hyperparameters, coerced from config strings
        params = {
            param: self.FASTTEXT_PARAMS[param](val)
            for param, val in params.items()
            if param in self.FASTTEXT_PARAMS
        }
        if jobs != 0:  # jobs set by user to non-default value
            params["thread"] = jobs
        self.debug(f"Model parameters: {params}")
        self._model = fasttext.train_supervised(trainpath, **params)
        self._model.save_model(modelpath)

    def _train(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
        jobs: int = 0,
    ) -> None:
        """Train the backend on *corpus*, or reuse the cached training file
        when *corpus* is the sentinel string "cached".

        Raises NotSupportedException when the corpus has no documents."""
        if corpus == "cached":
            self.info("Reusing cached training data from previous run.")
        else:
            if corpus.is_empty():
                raise NotSupportedException(
                    f"training backend {self.backend_id} with no documents"
                )
            self._create_train_file(corpus)
        self._create_model(params, jobs)

    def _predict_chunks(
        self, chunktexts: list[str], limit: int
    ) -> tuple[list[list[str]], list[ndarray]]:
        """Predict up to *limit* labels (with scores) for each chunk text,
        after normalization; chunks that normalize to "" are dropped."""
        normalized = [self._normalize_text(chunktext) for chunktext in chunktexts]
        return self._model.predict([text for text in normalized if text], limit)

    def _suggest_chunks(
        self, chunktexts: list[str], params: dict[str, Any]
    ) -> list[SubjectSuggestion]:
        """Aggregate per-chunk predictions into subject suggestions.

        Scores for the same label are summed over all chunks and averaged
        over the total number of chunks; the top *limit* labels win."""
        limit = int(params["limit"])
        chunklabels, chunkscores = self._predict_chunks(chunktexts, limit)
        label_scores: collections.defaultdict = collections.defaultdict(float)
        for labels, scores in zip(chunklabels, chunkscores):
            for label, score in zip(labels, scores):
                label_scores[label] += score
        # sort by summed score descending (ties broken by label, descending)
        best_labels = sorted(
            ((score, label) for label, score in label_scores.items()), reverse=True
        )
        return [
            SubjectSuggestion(
                subject_id=self._label_to_subject_id(label),
                score=score / len(chunktexts),
            )
            for score, label in best_labels[:limit]
        ]