Passed
Pull Request — master (#604)
by Osma
02:54
created

annif.backend.fasttext.FastTextBackend._train()   A

Complexity

Conditions 3

Size

Total Lines 10
Code Lines 9

Duplication

Lines 10
Ratio 100 %

Importance

Changes 0
Metric Value
cc 3
eloc 9
nop 4
dl 10
loc 10
rs 9.95
c 0
b 0
f 0
1
"""Annif backend using the fastText classifier"""
2
3
import collections
4
import os.path
5
import annif.util
6
from annif.suggestion import SubjectSuggestion, ListSuggestionResult
7
from annif.exception import NotInitializedException, NotSupportedException
8
import fasttext
9
from . import backend
10
from . import mixins
11
12
13 View Code Duplication
class FastTextBackend(mixins.ChunkingBackend, backend.AnnifBackend):
    """fastText backend for Annif.

    Trains a supervised fastText classifier on the project corpus and uses
    it to suggest subjects for documents.  Long documents are handled in
    chunks via ChunkingBackend.
    """

    name = "fasttext"

    # Known fastText hyperparameters, mapped to the callable used to
    # convert the string value from project configuration into the type
    # fastText expects.
    FASTTEXT_PARAMS = {
        'lr': float,
        'lrUpdateRate': int,
        'dim': int,
        'ws': int,
        'epoch': int,
        'minCount': int,
        'neg': int,
        'wordNgrams': int,
        'loss': str,
        'bucket': int,
        'minn': int,
        'maxn': int,
        'thread': int,
        't': float,
        'pretrainedVectors': str
    }

    DEFAULT_PARAMETERS = {
        'dim': 100,
        'lr': 0.25,
        'epoch': 5,
        'loss': 'hs',
    }

    # filenames used inside the backend's data directory
    MODEL_FILE = 'fasttext-model'
    TRAIN_FILE = 'fasttext-train.txt'

    # defaults for uninitialized instances
    _model = None

    def default_params(self):
        """Return the default parameters for this backend, combining the
        generic backend defaults, the chunking defaults and the
        fastText-specific defaults (later sources override earlier ones)."""
        params = backend.AnnifBackend.DEFAULT_PARAMETERS.copy()
        params.update(mixins.ChunkingBackend.DEFAULT_PARAMETERS)
        params.update(self.DEFAULT_PARAMETERS)
        return params

    @staticmethod
    def _load_model(path):
        """Load and return a fastText model from the given file path,
        suppressing the spurious load-time warning printed by fastText."""
        # monkey patch fasttext.FastText.eprint to avoid spurious warning
        # see https://github.com/facebookresearch/fastText/issues/1067
        orig_eprint = fasttext.FastText.eprint
        fasttext.FastText.eprint = lambda x: None
        try:
            model = fasttext.load_model(path)
        finally:
            # restore the original eprint even if loading fails, so the
            # monkey patch never leaks out of this method
            fasttext.FastText.eprint = orig_eprint
        return model

    def initialize(self, parallel=False):
        """Load the trained model from the data directory if it is not
        already loaded.

        Raises NotInitializedException if no model file exists."""
        if self._model is not None:
            return
        path = os.path.join(self.datadir, self.MODEL_FILE)
        self.debug('loading fastText model from {}'.format(path))
        if not os.path.exists(path):
            raise NotInitializedException(
                'model {} not found'.format(path),
                backend_id=self.backend_id)
        self._model = self._load_model(path)
        self.debug('loaded model {}'.format(str(self._model)))
        self.debug('dim: {}'.format(self._model.get_dimension()))

    @staticmethod
    def _id_to_label(subject_id):
        """Convert an integer subject id into a fastText label string."""
        return "__label__{:d}".format(subject_id)

    @staticmethod
    def _label_to_subject_id(label):
        """Convert a fastText label string back into an integer subject id.

        Inverse of _id_to_label; made static for consistency since it uses
        no instance state."""
        return int(label.replace('__label__', ''))

    def _write_train_file(self, corpus, filename):
        """Write the corpus as a fastText training file: one document per
        line, labels first, followed by the normalized text.  Documents
        with no text or no resolvable subjects are skipped (the latter
        with a warning)."""
        with open(filename, 'w', encoding='utf-8') as trainfile:
            for doc in corpus.documents:
                text = self._normalize_text(doc.text)
                if text == '':
                    continue
                subject_ids = [self.project.subjects.by_uri(uri)
                               for uri in doc.uris]
                labels = [self._id_to_label(sid) for sid in subject_ids
                          if sid is not None]
                if labels:
                    print(' '.join(labels), text, file=trainfile)
                else:
                    self.warning(f'no labels for document "{doc.text}"')

    def _normalize_text(self, text):
        """Tokenize the text with the project analyzer and rejoin the
        tokens with single spaces."""
        return ' '.join(self.project.analyzer.tokenize_words(text))

    def _create_train_file(self, corpus):
        """Atomically write the fastText training file into the data
        directory."""
        self.info('creating fastText training file')

        annif.util.atomic_save(corpus,
                               self.datadir,
                               self.TRAIN_FILE,
                               method=self._write_train_file)

    def _create_model(self, params, jobs):
        """Train a fastText model from the training file and save it into
        the data directory.

        Unknown parameters are silently dropped; known ones are converted
        to the types fastText expects."""
        self.info('creating fastText model')
        trainpath = os.path.join(self.datadir, self.TRAIN_FILE)
        modelpath = os.path.join(self.datadir, self.MODEL_FILE)
        params = {param: self.FASTTEXT_PARAMS[param](val)
                  for param, val in params.items()
                  if param in self.FASTTEXT_PARAMS}
        if jobs != 0:  # jobs set by user to non-default value
            params['thread'] = jobs
        self.debug('Model parameters: {}'.format(params))
        self._model = fasttext.train_supervised(trainpath, **params)
        self._model.save_model(modelpath)

    def _train(self, corpus, params, jobs=0):
        """Train the backend.  With the special corpus value 'cached',
        reuse the training file from a previous run instead of rebuilding
        it.

        Raises NotSupportedException when the corpus has no documents."""
        if corpus == 'cached':
            self.info("Reusing cached training data from previous run.")
        else:
            if corpus.is_empty():
                raise NotSupportedException(
                    'training backend {} with no documents'.format(
                        self.backend_id))
            self._create_train_file(corpus)
        self._create_model(params, jobs)

    def _predict_chunks(self, chunktexts, limit):
        """Run fastText prediction on the normalized, non-empty chunk
        texts, returning up to `limit` (labels, scores) per chunk."""
        return self._model.predict(list(
            filter(None, [self._normalize_text(chunktext)
                          for chunktext in chunktexts])), limit)

    def _suggest_chunks(self, chunktexts, params):
        """Suggest subjects for a document given as a list of text chunks.

        Scores for the same label are summed across chunks and then
        averaged over the number of chunks."""
        limit = int(params['limit'])
        chunklabels, chunkscores = self._predict_chunks(
            chunktexts, limit)
        # accumulate per-label scores across all chunks
        label_scores = collections.defaultdict(float)
        for labels, scores in zip(chunklabels, chunkscores):
            for label, score in zip(labels, scores):
                label_scores[label] += score
        best_labels = sorted([(score, label)
                              for label, score in label_scores.items()],
                             reverse=True)

        results = []
        for score, label in best_labels[:limit]:
            results.append(SubjectSuggestion(
                subject_id=self._label_to_subject_id(label),
                score=score / len(chunktexts)))
        return ListSuggestionResult(results)
160