Completed
Push — master ( ac148f...299d84 )
by Osma
06:16 queued 25s
created

FastTextBackend.load_corpus()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 3
dl 0
loc 3
rs 10
c 0
b 0
f 0
cc 1
nop 3
1
"""Annif backend using the fastText classifier"""
2
3
import collections
4
import os.path
5
import annif.util
6
from annif.hit import AnalysisHit, ListAnalysisResult
7
from annif.exception import NotInitializedException
8
import fastText
9
from . import backend
10
11
12
class FastTextBackend(backend.AnnifBackend):
    """Annif backend that trains and queries a fastText supervised
    classifier. Documents are mapped to ``__label__<subject_id>`` labels
    for training; predictions are aggregated over text chunks."""

    name = "fasttext"
    needs_subject_index = True

    # Hyperparameters accepted by fastText.train_supervised, mapped to the
    # callable used to coerce the string values from the configuration.
    FASTTEXT_PARAMS = {
        'lr': float,
        'lrUpdateRate': int,
        'dim': int,
        'ws': int,
        'epoch': int,
        'minCount': int,
        'neg': int,
        'wordNgrams': int,
        'loss': str,
        'bucket': int,
        'minn': int,
        'maxn': int,
        'thread': int,
        't': float
    }

    MODEL_FILE = 'fasttext-model'
    TRAIN_FILE = 'fasttext-train.txt'

    # defaults for uninitialized instances
    _model = None

    def initialize(self):
        """Load the trained fastText model from the data directory.

        Raises NotInitializedException if the model file does not exist."""
        if self._model is not None:
            return  # already initialized
        path = os.path.join(self._get_datadir(), self.MODEL_FILE)
        self.debug('loading fastText model from {}'.format(path))
        if not os.path.exists(path):
            raise NotInitializedException(
                'model {} not found'.format(path),
                backend_id=self.backend_id)
        self._model = fastText.load_model(path)
        self.debug('loaded model {}'.format(str(self._model)))
        self.debug('dim: {}'.format(self._model.get_dimension()))

    @classmethod
    def _id_to_label(cls, subject_id):
        """Convert a numeric subject id to a fastText label string."""
        return "__label__{:d}".format(subject_id)

    @classmethod
    def _label_to_subject(cls, project, label):
        """Convert a fastText label string back to a subject tuple."""
        labelnum = label.replace('__label__', '')
        subject_id = int(labelnum)
        return project.subjects[subject_id]

    def _write_train_file(self, doc_subjects, filename):
        """Write the fastText training file: one line per document,
        labels first, then the normalized document text."""
        with open(filename, 'w') as trainfile:
            for doc, subject_ids in doc_subjects.items():
                # sort for deterministic output (subject_ids is a set)
                labels = sorted(self._id_to_label(sid)
                                for sid in subject_ids
                                if sid is not None)
                if labels:
                    print(' '.join(labels), doc, file=trainfile)
                else:
                    self.warning('no labels for document "{}"'.format(doc))

    @classmethod
    def _normalize_text(cls, project, text):
        """Normalize text into a single whitespace-joined token string
        using the project analyzer."""
        return ' '.join(project.analyzer.tokenize_words(text))

    def _create_train_file(self, corpus, project):
        """Collect (normalized text -> subject id set) pairs from the
        corpus and atomically write the fastText training file."""
        self.info('creating fastText training file')

        doc_subjects = collections.defaultdict(set)

        for doc in corpus.documents:
            text = self._normalize_text(project, doc.text)
            if text == '':
                continue
            # update (not assign) so that subjects of documents that
            # normalize to the same text are merged instead of the last
            # one silently overwriting the earlier ones
            doc_subjects[text].update(
                project.subjects.by_uri(uri) for uri in doc.uris)

        annif.util.atomic_save(doc_subjects,
                               self._get_datadir(),
                               self.TRAIN_FILE,
                               method=self._write_train_file)

    def _create_model(self):
        """Train a supervised fastText model from the training file and
        save it into the data directory."""
        self.info('creating fastText model')
        trainpath = os.path.join(self._get_datadir(), self.TRAIN_FILE)
        modelpath = os.path.join(self._get_datadir(), self.MODEL_FILE)
        # coerce only the recognized hyperparameters to their proper types
        params = {param: self.FASTTEXT_PARAMS[param](val)
                  for param, val in self.params.items()
                  if param in self.FASTTEXT_PARAMS}
        self._model = fastText.train_supervised(trainpath, **params)
        self._model.save_model(modelpath)

    def train(self, corpus, project):
        """Create the training file from the corpus, then train and save
        the fastText model."""
        self._create_train_file(corpus, project)
        self._create_model()

    def _analyze_chunks(self, chunktexts, project):
        """Predict subjects for each chunk, sum the scores per label and
        return the top hits averaged over the number of chunks."""
        limit = int(self.params['limit'])
        if not chunktexts:
            # nothing to predict on (empty or all-stopword input text);
            # avoid predict() on an empty list and division by zero below
            return ListAnalysisResult([], project.subjects)
        chunklabels, chunkscores = self._model.predict(chunktexts, limit)
        label_scores = collections.defaultdict(float)
        for labels, scores in zip(chunklabels, chunkscores):
            for label, score in zip(labels, scores):
                label_scores[label] += score
        best_labels = sorted([(score, label)
                              for label, score in label_scores.items()],
                             reverse=True)

        results = []
        for score, label in best_labels[:limit]:
            subject = self._label_to_subject(project, label)
            results.append(AnalysisHit(
                uri=subject[0],
                label=subject[1],
                score=score / len(chunktexts)))  # average over chunks
        return ListAnalysisResult(results, project.subjects)

    def _analyze(self, text, project, params):
        """Split the text into sentence chunks, normalize them and
        delegate the prediction to _analyze_chunks."""
        self.initialize()
        self.debug('Analyzing text "{}..." (len={})'.format(
            text[:20], len(text)))
        sentences = project.analyzer.tokenize_sentences(text)
        self.debug('Found {} sentences'.format(len(sentences)))
        chunksize = int(params['chunksize'])
        chunktexts = []
        for i in range(0, len(sentences), chunksize):
            chunktext = ' '.join(sentences[i:i + chunksize])
            normalized = self._normalize_text(project, chunktext)
            if normalized != '':
                chunktexts.append(normalized)
        self.debug('Split sentences into {} chunks'.format(len(chunktexts)))

        return self._analyze_chunks(chunktexts, project)
145