Passed
Push — master ( c8c370...dee89b )
by Osma
03:14
created

FastTextBackend._suggest_chunks()   A

Complexity

Conditions 4

Size

Total Lines 20
Code Lines 19

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 19
dl 0
loc 20
rs 9.45
c 0
b 0
f 0
cc 4
nop 3
1
"""Annif backend using the fastText classifier"""
2
3
import collections
4
import os.path
5
import annif.util
6
from annif.suggestion import SubjectSuggestion, ListSuggestionResult
7
from annif.exception import NotInitializedException
8
import fastText
9
from . import backend
10
from . import mixins
11
12
13
class FastTextBackend(mixins.ChunkingBackend, backend.AnnifBackend):
    """fastText backend for Annif

    Trains a supervised fastText model from a document corpus and uses it
    to suggest subjects for chunks of text. Both the training file and the
    model file are stored under ``self.datadir``.
    """

    name = "fasttext"
    needs_subject_index = True

    # Backend parameters that are forwarded to fastText.train_supervised(),
    # mapped to the callable used to convert each configured value.
    FASTTEXT_PARAMS = {
        'lr': float,
        'lrUpdateRate': int,
        'dim': int,
        'ws': int,
        'epoch': int,
        'minCount': int,
        'neg': int,
        'wordNgrams': int,
        'loss': str,
        'bucket': int,
        'minn': int,
        'maxn': int,
        'thread': int,
        't': float
    }

    # File names (relative to self.datadir) of the model and training data
    MODEL_FILE = 'fasttext-model'
    TRAIN_FILE = 'fasttext-train.txt'

    # defaults for uninitialized instances
    _model = None

    def initialize(self):
        """Load the fastText model from the data directory, unless a model
        is already held in memory.

        Raises NotInitializedException if the model file does not exist."""
        if self._model is None:
            path = os.path.join(self.datadir, self.MODEL_FILE)
            self.debug('loading fastText model from {}'.format(path))
            if os.path.exists(path):
                self._model = fastText.load_model(path)
                self.debug('loaded model {}'.format(str(self._model)))
                self.debug('dim: {}'.format(self._model.get_dimension()))
            else:
                raise NotInitializedException(
                    'model {} not found'.format(path),
                    backend_id=self.backend_id)

    @staticmethod
    def _id_to_label(subject_id):
        """Return the fastText label string for an integer subject ID."""
        return "__label__{:d}".format(subject_id)

    @staticmethod
    def _label_to_subject(project, label):
        """Return the entry of project.subjects that corresponds to the
        given fastText label string (the inverse of _id_to_label)."""
        labelnum = label.replace('__label__', '')
        subject_id = int(labelnum)
        return project.subjects[subject_id]

    def _write_train_file(self, doc_subjects, filename):
        """Write the training data to the given file, one document per
        line: the space-separated labels followed by the document text.

        doc_subjects maps document text to an iterable of subject IDs.
        IDs that are None are skipped; a document left without any labels
        is not written and a warning is logged instead."""
        with open(filename, 'w', encoding='utf-8') as trainfile:
            for doc, subject_ids in doc_subjects.items():
                labels = [self._id_to_label(sid) for sid in subject_ids
                          if sid is not None]
                if labels:
                    print(' '.join(labels), doc, file=trainfile)
                else:
                    self.warning('no labels for document "{}"'.format(doc))

    @staticmethod
    def _normalize_text(project, text):
        """Return the text as a single string of space-separated tokens
        produced by the project's analyzer."""
        return ' '.join(project.analyzer.tokenize_words(text))

    def _create_train_file(self, corpus, project):
        """Build the fastText training file from the documents of the
        corpus and save it as TRAIN_FILE in the data directory.

        Documents whose normalized text is empty are skipped. Documents
        that share the same normalized text are merged into a single
        training line carrying the subjects of all of them."""
        self.info('creating fastText training file')

        # Keyed by normalized text. Subject IDs are accumulated (not
        # overwritten) so that a later document with identical text does
        # not discard the subjects of an earlier one. A list with a
        # membership check keeps the IDs unique and in first-seen order.
        doc_subjects = collections.defaultdict(list)

        for doc in corpus.documents:
            text = self._normalize_text(project, doc.text)
            if text == '':
                continue
            subject_ids = doc_subjects[text]
            for uri in doc.uris:
                subject_id = project.subjects.by_uri(uri)
                if subject_id not in subject_ids:
                    subject_ids.append(subject_id)

        # _write_train_file does the actual writing; atomic_save is given
        # the target directory and file name.
        annif.util.atomic_save(doc_subjects,
                               self.datadir,
                               self.TRAIN_FILE,
                               method=self._write_train_file)

    def _create_model(self):
        """Train a supervised fastText model on TRAIN_FILE and save it as
        MODEL_FILE in the data directory.

        Only the backend parameters listed in FASTTEXT_PARAMS are passed
        on to fastText, each converted to its declared type."""
        self.info('creating fastText model')
        trainpath = os.path.join(self.datadir, self.TRAIN_FILE)
        modelpath = os.path.join(self.datadir, self.MODEL_FILE)
        params = {param: self.FASTTEXT_PARAMS[param](val)
                  for param, val in self.params.items()
                  if param in self.FASTTEXT_PARAMS}
        self._model = fastText.train_supervised(trainpath, **params)
        self._model.save_model(modelpath)

    def train(self, corpus, project):
        """Train the backend: write the training file from the corpus and
        then create the model from it."""
        self._create_train_file(corpus, project)
        self._create_model()

    def _predict_chunks(self, chunktexts, project, limit):
        """Return the model's predictions (at most limit labels per chunk)
        for the normalized chunk texts.

        Chunks whose normalized text is empty are left out, so the result
        may cover fewer chunks than were passed in."""
        return self._model.predict(list(
            filter(None, [self._normalize_text(project, chunktext)
                          for chunktext in chunktexts])), limit)

    def _suggest_chunks(self, chunktexts, project):
        """Suggest subjects for a text given as a list of chunks.

        The scores predicted for each label are summed over the chunks and
        the 'limit' (backend parameter) best labels are returned as a
        ListSuggestionResult. Each score is divided by the total number of
        chunks, including any chunks skipped by _predict_chunks."""
        limit = int(self.params['limit'])
        chunklabels, chunkscores = self._predict_chunks(
            chunktexts, project, limit)
        label_scores = collections.defaultdict(float)
        for labels, scores in zip(chunklabels, chunkscores):
            for label, score in zip(labels, scores):
                label_scores[label] += score
        # (score, label) tuples so that sorting orders by score first
        best_labels = sorted([(score, label)
                              for label, score in label_scores.items()],
                             reverse=True)

        results = []
        for score, label in best_labels[:limit]:
            subject = self._label_to_subject(project, label)
            results.append(SubjectSuggestion(
                uri=subject[0],
                label=subject[1],
                score=score / len(chunktexts)))
        return ListSuggestionResult(results, project.subjects)
135