Passed
Pull Request — master (#409)
by Osma, created 02:15

annif.backend.fasttext   (Rating: A)

Complexity

Total Complexity 25

Size/Duplication

Total Lines 167
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 25
eloc 133
dl 0
loc 167
rs 10
c 0
b 0
f 0

12 Methods

Rating   Name   Duplication   Size   Complexity  
A FastTextBackend.default_params() 0 5 1
A FastTextBackend.initialize() 0 12 3
A FastTextBackend._load_model() 0 10 2
A FastTextBackend._predict_chunks() 0 4 1
A FastTextBackend._suggest_chunks() 0 21 4
A FastTextBackend._id_to_label() 0 3 1
A FastTextBackend._create_model() 0 10 1
A FastTextBackend._label_to_subject() 0 4 1
A FastTextBackend._create_train_file() 0 16 3
A FastTextBackend._normalize_text() 0 2 1
A FastTextBackend._train() 0 10 3
A FastTextBackend._write_train_file() 0 9 4
"""Annif backend using the fastText classifier"""

import collections
import os.path
import annif.util
from annif.suggestion import SubjectSuggestion, ListSuggestionResult
from annif.exception import NotInitializedException, NotSupportedException
import fasttext
from . import backend
from . import mixins


class FastTextBackend(mixins.ChunkingBackend, backend.AnnifBackend):
    """fastText backend for Annif"""

    name = "fasttext"
    needs_subject_index = True

    FASTTEXT_PARAMS = {
        'lr': float,
        'lrUpdateRate': int,
        'dim': int,
        'ws': int,
        'epoch': int,
        'minCount': int,
        'neg': int,
        'wordNgrams': int,
        'loss': str,
        'bucket': int,
        'minn': int,
        'maxn': int,
        'thread': int,
        't': float,
        'pretrainedVectors': str
    }

    DEFAULT_PARAMETERS = {
        'dim': 100,
        'lr': 0.25,
        'epoch': 5,
        'loss': 'hs',
    }

    MODEL_FILE = 'fasttext-model'
    TRAIN_FILE = 'fasttext-train.txt'

    # defaults for uninitialized instances
    _model = None

    def default_params(self):
        params = backend.AnnifBackend.DEFAULT_PARAMETERS.copy()
        params.update(mixins.ChunkingBackend.DEFAULT_PARAMETERS)
        params.update(self.DEFAULT_PARAMETERS)
        return params

    @staticmethod
    def _load_model(path):
        # monkey patch fasttext.FastText.eprint to avoid spurious warning
        # see https://github.com/facebookresearch/fastText/issues/1067
        orig_eprint = fasttext.FastText.eprint
        fasttext.FastText.eprint = lambda x: None
        model = fasttext.load_model(path)
        # restore the original eprint
        fasttext.FastText.eprint = orig_eprint
        return model

    def initialize(self):
        if self._model is None:
            path = os.path.join(self.datadir, self.MODEL_FILE)
            self.debug('loading fastText model from {}'.format(path))
            if os.path.exists(path):
                self._model = self._load_model(path)
                self.debug('loaded model {}'.format(str(self._model)))
                self.debug('dim: {}'.format(self._model.get_dimension()))
            else:
                raise NotInitializedException(
                    'model {} not found'.format(path),
                    backend_id=self.backend_id)

    @staticmethod
    def _id_to_label(subject_id):
        return "__label__{:d}".format(subject_id)

    def _label_to_subject(self, label):
        labelnum = label.replace('__label__', '')
        subject_id = int(labelnum)
        return self.project.subjects[subject_id]

    def _write_train_file(self, doc_subjects, filename):
        with open(filename, 'w', encoding='utf-8') as trainfile:
            for doc, subject_ids in doc_subjects.items():
                labels = [self._id_to_label(sid) for sid in subject_ids
                          if sid is not None]
                if labels:
                    print(' '.join(labels), doc, file=trainfile)
                else:
                    self.warning('no labels for document "{}"'.format(doc))

    def _normalize_text(self, text):
        return ' '.join(self.project.analyzer.tokenize_words(text))

    def _create_train_file(self, corpus):
        self.info('creating fastText training file')

        doc_subjects = collections.defaultdict(set)

        for doc in corpus.documents:
            text = self._normalize_text(doc.text)
            if text == '':
                continue
            doc_subjects[text] = [self.project.subjects.by_uri(uri)
                                  for uri in doc.uris]

        annif.util.atomic_save(doc_subjects,
                               self.datadir,
                               self.TRAIN_FILE,
                               method=self._write_train_file)

    def _create_model(self, params):
        self.info('creating fastText model')
        trainpath = os.path.join(self.datadir, self.TRAIN_FILE)
        modelpath = os.path.join(self.datadir, self.MODEL_FILE)
        params = {param: self.FASTTEXT_PARAMS[param](val)
                  for param, val in params.items()
                  if param in self.FASTTEXT_PARAMS}
        self.debug('Model parameters: {}'.format(params))
        self._model = fasttext.train_supervised(trainpath, **params)
        self._model.save_model(modelpath)

    def _train(self, corpus, params):
        if corpus != 'cached':
            if corpus.is_empty():
                raise NotSupportedException(
                    'training backend {} with no documents'.format(
                        self.backend_id))
            self._create_train_file(corpus)
        else:
            self.info("Reusing cached training data from previous run.")
        self._create_model(params)

    def _predict_chunks(self, chunktexts, limit):
        return self._model.predict(list(
            filter(None, [self._normalize_text(chunktext)
                          for chunktext in chunktexts])), limit)

    def _suggest_chunks(self, chunktexts, params):
        limit = int(params['limit'])
        chunklabels, chunkscores = self._predict_chunks(
            chunktexts, limit)
        label_scores = collections.defaultdict(float)
        for labels, scores in zip(chunklabels, chunkscores):
            for label, score in zip(labels, scores):
                label_scores[label] += score
        best_labels = sorted([(score, label)
                              for label, score in label_scores.items()],
                             reverse=True)

        results = []
        for score, label in best_labels[:limit]:
            subject = self._label_to_subject(label)
            results.append(SubjectSuggestion(
                uri=subject[0],
                label=subject[1],
                notation=subject[2],
                score=score / len(chunktexts)))
        return ListSuggestionResult(results, self.project.subjects)
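
For context, a minimal standalone sketch (not part of the reviewed file) of the training-file line format that _write_train_file produces and the "__label__<subject_id>" round-trip behind _id_to_label and _label_to_subject. The subject ids and text below are made-up example values, and label_to_id is a hypothetical helper standing in for the id-extraction step of _label_to_subject.

# Each training line is "<labels> <normalized text>", where every subject id
# is encoded as a "__label__<id>" token for fastText supervised training.

def id_to_label(subject_id):
    # mirrors FastTextBackend._id_to_label
    return "__label__{:d}".format(subject_id)

def label_to_id(label):
    # hypothetical helper: the integer-extraction step of _label_to_subject
    return int(label.replace('__label__', ''))

subject_ids = [42, 7]                              # made-up subject ids
text = "economic policy of the european union"     # already-normalized text
train_line = ' '.join(id_to_label(sid) for sid in subject_ids) + ' ' + text
print(train_line)
# __label__42 __label__7 economic policy of the european union
assert label_to_id('__label__42') == 42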