Passed
Pull Request — master (#604)
by Osma
07:27 queued 11s
created

FastTextBackend._create_train_file()   A

Complexity

Conditions 1

Size

Total Lines 7
Code Lines 6

Duplication

Lines 7
Ratio 100 %

Importance

Changes 0
Metric Value
cc 1
eloc 6
nop 2
dl 7
loc 7
rs 10
c 0
b 0
f 0
1
"""Annif backend using the fastText classifier"""
2
3
import collections
4
import os.path
5
import annif.util
6
from annif.suggestion import SubjectSuggestion, ListSuggestionResult
7
from annif.exception import NotInitializedException, NotSupportedException
8
import fasttext
9
from . import backend
10
from . import mixins
11
12
13 View Code Duplication
class FastTextBackend(mixins.ChunkingBackend, backend.AnnifBackend):
    """fastText backend for Annif"""

    name = "fasttext"
    needs_subject_index = True

    # fastText training parameters that may be set in the project
    # configuration, mapped to the type each value is converted to before
    # being passed to fasttext.train_supervised
    FASTTEXT_PARAMS = {
        'lr': float,
        'lrUpdateRate': int,
        'dim': int,
        'ws': int,
        'epoch': int,
        'minCount': int,
        'neg': int,
        'wordNgrams': int,
        'loss': str,
        'bucket': int,
        'minn': int,
        'maxn': int,
        'thread': int,
        't': float,
        'pretrainedVectors': str
    }

    DEFAULT_PARAMETERS = {
        'dim': 100,
        'lr': 0.25,
        'epoch': 5,
        'loss': 'hs',
    }

    MODEL_FILE = 'fasttext-model'
    TRAIN_FILE = 'fasttext-train.txt'

    # defaults for uninitialized instances
    _model = None

    def default_params(self):
        """Return the default parameters for this backend.

        Base backend defaults are overlaid with the chunking mixin defaults
        and then with this backend's own DEFAULT_PARAMETERS; later sources
        take precedence.
        """
        params = backend.AnnifBackend.DEFAULT_PARAMETERS.copy()
        params.update(mixins.ChunkingBackend.DEFAULT_PARAMETERS)
        params.update(self.DEFAULT_PARAMETERS)
        return params

    @staticmethod
    def _load_model(path):
        """Load and return the fastText model stored at *path*.

        fasttext.FastText.eprint is temporarily replaced by a no-op while
        loading, to silence a spurious warning.
        """
        # monkey patch fasttext.FastText.eprint to avoid spurious warning
        # see https://github.com/facebookresearch/fastText/issues/1067
        orig_eprint = fasttext.FastText.eprint
        fasttext.FastText.eprint = lambda x: None
        try:
            return fasttext.load_model(path)
        finally:
            # restore the original eprint even if loading fails, so the
            # no-op patch does not leak into the rest of the process
            fasttext.FastText.eprint = orig_eprint

    def initialize(self, parallel=False):
        """Load the trained model from the data directory if not yet loaded.

        The *parallel* argument is accepted for interface compatibility and
        is not used here.

        Raises:
            NotInitializedException: if the model file does not exist.
        """
        if self._model is None:
            path = os.path.join(self.datadir, self.MODEL_FILE)
            self.debug('loading fastText model from {}'.format(path))
            if os.path.exists(path):
                self._model = self._load_model(path)
                self.debug('loaded model {}'.format(str(self._model)))
                self.debug('dim: {}'.format(self._model.get_dimension()))
            else:
                raise NotInitializedException(
                    'model {} not found'.format(path),
                    backend_id=self.backend_id)

    @staticmethod
    def _id_to_label(subject_id):
        """Convert an integer subject id into a fastText label string."""
        return "__label__{:d}".format(subject_id)

    def _label_to_subject_id(self, label):
        """Convert a fastText label string back into an integer subject id."""
        labelnum = label.replace('__label__', '')
        return int(labelnum)

    def _write_train_file(self, corpus, filename):
        """Write *corpus* to *filename* in fastText supervised format.

        Each output line holds the document's labels followed by its
        normalized text. Documents whose normalized text is empty are
        skipped. Documents with no URI that resolves to a subject id are
        skipped with a warning.
        """
        with open(filename, 'w', encoding='utf-8') as trainfile:
            for doc in corpus.documents:
                text = self._normalize_text(doc.text)
                if text == '':
                    continue
                subject_ids = [self.project.subjects.by_uri(uri)
                               for uri in doc.uris]
                # URIs not found in the subject index yield None; drop them
                labels = [self._id_to_label(sid) for sid in subject_ids
                          if sid is not None]
                if labels:
                    print(' '.join(labels), text, file=trainfile)
                else:
                    self.warning(f'no labels for document "{doc.text}"')

    def _normalize_text(self, text):
        """Tokenize *text* with the project analyzer and join with spaces."""
        return ' '.join(self.project.analyzer.tokenize_words(text))

    def _create_train_file(self, corpus):
        """Write the fastText training file for *corpus* into the data dir."""
        self.info('creating fastText training file')

        annif.util.atomic_save(corpus,
                               self.datadir,
                               self.TRAIN_FILE,
                               method=self._write_train_file)

    def _create_model(self, params, jobs):
        """Train a supervised fastText model and save it to the data dir.

        Only parameters listed in FASTTEXT_PARAMS are passed to fastText,
        each converted to its declared type. A non-zero *jobs* overrides
        the 'thread' parameter.
        """
        self.info('creating fastText model')
        trainpath = os.path.join(self.datadir, self.TRAIN_FILE)
        modelpath = os.path.join(self.datadir, self.MODEL_FILE)
        params = {param: self.FASTTEXT_PARAMS[param](val)
                  for param, val in params.items()
                  if param in self.FASTTEXT_PARAMS}
        if jobs != 0:  # jobs set by user to non-default value
            params['thread'] = jobs
        self.debug('Model parameters: {}'.format(params))
        self._model = fasttext.train_supervised(trainpath, **params)
        self._model.save_model(modelpath)

    def _train(self, corpus, params, jobs=0):
        """Train the backend on *corpus*.

        If *corpus* is the string 'cached', the training file left by a
        previous run is reused instead of being regenerated.

        Raises:
            NotSupportedException: if the corpus contains no documents.
        """
        if corpus != 'cached':
            if corpus.is_empty():
                raise NotSupportedException(
                    'training backend {} with no documents'.format(
                        self.backend_id))
            self._create_train_file(corpus)
        else:
            self.info("Reusing cached training data from previous run.")
        self._create_model(params, jobs)

    def _predict_chunks(self, chunktexts, limit):
        """Predict up to *limit* labels for each non-empty normalized chunk.

        Chunks whose normalized text is empty are dropped before prediction.
        """
        return self._model.predict(list(
            filter(None, [self._normalize_text(chunktext)
                          for chunktext in chunktexts])), limit)

    def _suggest_chunks(self, chunktexts, params):
        """Combine per-chunk predictions into a single suggestion result.

        Scores are summed per label across chunks. The *limit* best labels
        are kept, and each summed score is divided by the total number of
        chunks (including chunks dropped as empty during prediction).
        """
        limit = int(params['limit'])
        chunklabels, chunkscores = self._predict_chunks(
            chunktexts, limit)
        label_scores = collections.defaultdict(float)
        for labels, scores in zip(chunklabels, chunkscores):
            for label, score in zip(labels, scores):
                label_scores[label] += score
        best_labels = sorted([(score, label)
                              for label, score in label_scores.items()],
                             reverse=True)

        results = []
        for score, label in best_labels[:limit]:
            results.append(SubjectSuggestion(
                subject_id=self._label_to_subject_id(label),
                score=score / len(chunktexts)))
        return ListSuggestionResult(results)
161