Passed
Pull Request — master (#491)
by unknown, created 03:22

annif.backend.yake.YakeBackend._train()   A

Complexity
  Conditions: 1

Size
  Total Lines: 3
  Code Lines: 3

Duplication
  Lines: 0
  Ratio: 0 %

Importance
  Changes: 0

Metric  Value
cc      1
eloc    3
nop     3
dl      0
loc     3
rs      10
c       0
b       0
f       0
"""Annif backend using Yake keyword extraction"""
# For license remarks of this backend see README.md:
# https://github.com/NatLibFi/Annif#license.

import yake
import joblib
import os.path
import re
from collections import defaultdict
from rdflib.namespace import SKOS
import annif.util
from . import backend
from annif.suggestion import SubjectSuggestion, ListSuggestionResult
from annif.exception import ConfigurationException
from annif.exception import NotSupportedException


class YakeBackend(backend.AnnifBackend):
    """Yake based backend for Annif"""
    name = "yake"
    needs_subject_index = False

    # defaults for uninitialized instances
    _index = None
    _graph = None
    INDEX_FILE = 'yake-index'

    DEFAULT_PARAMETERS = {
        'max_ngram_size': 4,
        'deduplication_threshold': 0.9,
        'deduplication_algo': 'levs',
        'window_size': 1,
        'num_keywords': 100,
        'features': None,
        'label_types': ['prefLabel', 'altLabel'],
        'remove_parentheses': False
    }

    def default_params(self):
        params = backend.AnnifBackend.DEFAULT_PARAMETERS.copy()
        params.update(self.DEFAULT_PARAMETERS)
        return params

    @property
    def is_trained(self):
        return True

    @property
    def label_types(self):
        if isinstance(self.params['label_types'], str):  # Set by user
            label_types = [lt.strip() for lt
                           in self.params['label_types'].split(',')]
            self._validate_label_types(label_types)
        else:
            label_types = self.params['label_types']  # The defaults
        return [getattr(SKOS, lt) for lt in label_types]
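
    # Note: SKOS here is rdflib's namespace object, so a configured value
    # like 'prefLabel, altLabel' ends up as [SKOS.prefLabel, SKOS.altLabel],
    # i.e. the full URIRefs of the SKOS label properties.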

    def _validate_label_types(self, label_types):
        for lt in label_types:
            if lt not in ('prefLabel', 'altLabel', 'hiddenLabel'):
                raise ConfigurationException(
                    f'invalid label type {lt}', backend_id=self.backend_id)

    def initialize(self):
        self._initialize_index()

    def _initialize_index(self):
        if self._index is None:
            path = os.path.join(self.datadir, self.INDEX_FILE)
            if os.path.exists(path):
                self._index = joblib.load(path)
                self.debug(
                    f'Loaded index from {path} with {len(self._index)} labels')
            else:
                self.info('Creating index')
                self._index = self._create_index()
                self._save_index(path)
                self.info(f'Created index with {len(self._index)} labels')

    def _save_index(self, path):
        annif.util.atomic_save(
            self._index,
            self.datadir,
            self.INDEX_FILE,
            method=joblib.dump)

    def _create_index(self):
        index = defaultdict(set)
        skos_vocab = self.project.vocab.skos
        for concept in skos_vocab.concepts:
            uri = str(concept)
            labels = skos_vocab.get_concept_labels(
                concept, self.label_types, self.params['language'])
            for label in labels:
                label = self._normalize_label(label)
                index[label].add(uri)
        index.pop('', None)  # Remove possible empty string entry
        return dict(index)
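
    # The index built above maps each normalized, word-sorted label to the
    # set of concept URIs carrying it, e.g. (with an illustrative URI):
    #   {'languages modelling': {'http://example.org/p123'}, ...}
    # Sorting the words within a label makes lookups order-insensitive.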

    def _normalize_label(self, label):
        label = str(label)
        if annif.util.boolean(self.params['remove_parentheses']):
            label = re.sub(r' \(.*\)', '', label)
        normalized_label = self._normalize_phrase(label)
        return self._sort_phrase(normalized_label)

    def _normalize_phrase(self, phrase):
        normalized = []
        for word in phrase.split():
            normalized.append(
                self.project.analyzer.normalize_word(word).lower())
        return ' '.join(normalized)

    def _sort_phrase(self, phrase):
        words = phrase.split()
        return ' '.join(sorted(words))
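
    # Illustrative walk-through of the normalization above, assuming an
    # analyzer whose normalize_word() only passes words through: with
    # remove_parentheses=True the label 'Modelling languages (systems)'
    # first becomes 'Modelling languages', is lowercased to
    # 'modelling languages' and stored as 'languages modelling', so the
    # extracted keyphrase 'Languages Modelling' maps to the same key.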

    def _suggest(self, text, params):
        self.debug(
            f'Suggesting subjects for text "{text[:20]}..." (len={len(text)})')
        limit = int(params['limit'])

        self._kw_extractor = yake.KeywordExtractor(
            lan=params['language'],
            n=int(params['max_ngram_size']),
            dedupLim=float(params['deduplication_threshold']),
            dedupFunc=params['deduplication_algo'],
            windowsSize=int(params['window_size']),
            top=int(params['num_keywords']),
            features=self.params['features'])
        keyphrases = self._kw_extractor.extract_keywords(text)
        suggestions = self._keyphrases2suggestions(keyphrases)

        subject_suggestions = [SubjectSuggestion(
                uri=uri,
                label=None,
                notation=None,
                score=score)
                for uri, score in suggestions[:limit] if score > 0.0]
        return ListSuggestionResult.create_from_index(subject_suggestions,
                                                      self.project.subjects)
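
    # yake.KeywordExtractor.extract_keywords() returns (keyphrase, score)
    # pairs in which a *lower* score means a better keyphrase; the pairs
    # are matched against the index and rescaled in the helpers below.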

    def _keyphrases2suggestions(self, keyphrases):
        suggestions = []
        not_matched = []
        for kp, score in keyphrases:
            uris = self._keyphrase2uris(kp)
            for uri in uris:
                suggestions.append(
                    (uri, self._transform_score(score)))
            if not uris:
                not_matched.append((kp, self._transform_score(score)))
        # Remove duplicate uris, conflating the scores
        suggestions = self._combine_suggestions(suggestions)
        self.debug('Keyphrases not matched:\n' + '\t'.join(
            [kp[0] + ' ' + str(kp[1]) for kp
             in sorted(not_matched, reverse=True, key=lambda kp: kp[1])]))
        return suggestions

    def _keyphrase2uris(self, keyphrase):
        keyphrase = self._normalize_phrase(keyphrase)
        keyphrase = self._sort_phrase(keyphrase)
        return self._index.get(keyphrase, [])

    def _transform_score(self, score):
        score = max(score, 0)
        return 1.0 / (score + 1)
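
    # 1 / (score + 1) maps YAKE's unbounded lower-is-better scores into
    # (0, 1] with higher-is-better semantics: a YAKE score of 0 becomes
    # 1.0, a score of 1 becomes 0.5, and large scores approach 0.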

    def _combine_suggestions(self, suggestions):
        combined_suggestions = {}
        for uri, score in suggestions:
            if uri not in combined_suggestions:
                combined_suggestions[uri] = score
            else:
                old_score = combined_suggestions[uri]
                combined_suggestions[uri] = self._combine_scores(
                    score, old_score)
        return list(combined_suggestions.items())

    def _combine_scores(self, score1, score2):
        # The result is never smaller than the greater input
        score1 = score1/2 + 0.5
        score2 = score2/2 + 0.5
        confl = score1 * score2 / (score1 * score2 + (1-score1) * (1-score2))
        return (confl-0.5) * 2
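
    # Worked example: combining two scores of 0.5 rescales both to 0.75,
    # giving confl = 0.5625 / (0.5625 + 0.0625) = 0.9, which maps back to
    # (0.9 - 0.5) * 2 = 0.8. Combining any score with 0.0 returns the score
    # unchanged, so repeated matches can only raise a subject's score.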

    def _train(self, corpus, params):
        raise NotSupportedException(
            'Training yake backend is not possible.')
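
For reference, a minimal sketch of how a project using this backend might be
configured in Annif's projects.cfg. The section name, vocab and analyzer
values are illustrative, not taken from this PR; the backend-specific options
mirror DEFAULT_PARAMETERS above:

    [yake-en]
    name=YAKE English
    backend=yake
    language=en
    vocab=my-vocab
    analyzer=snowball(english)
    limit=100
    label_types=prefLabel,altLabel
    remove_parentheses=True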