annif.backend.yake.YakeBackend.label_types()   grade: A
last analyzed

Complexity
    Conditions: 2

Size
    Total Lines: 8
    Code Lines: 7

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric   Value
cc       2
eloc     7
nop      1
dl       0
loc      8
rs       10
c        0
b        0
f        0
"""Annif backend using Yake keyword extraction"""

# For license remarks of this backend see README.md:
# https://github.com/NatLibFi/Annif#license.
from __future__ import annotations

import os.path
import re
from collections import defaultdict
from typing import TYPE_CHECKING, Any

import joblib
import yake
from rdflib.namespace import SKOS

import annif.util
from annif.exception import ConfigurationException, NotSupportedException
from annif.suggestion import SubjectSuggestion

from . import backend

if TYPE_CHECKING:
    from rdflib.term import URIRef

    from annif.corpus import Document, DocumentCorpus


class YakeBackend(backend.AnnifBackend):
    """Yake based backend for Annif"""

    name = "yake"

    # defaults for uninitialized instances
    _index = None
    _graph = None
    INDEX_FILE = "yake-index"

    DEFAULT_PARAMETERS = {
        "max_ngram_size": 4,
        "deduplication_threshold": 0.9,
        "deduplication_algo": "levs",
        "window_size": 1,
        "num_keywords": 100,
        "features": None,
        "label_types": ["prefLabel", "altLabel"],
        "remove_parentheses": False,
    }

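    # Any of these defaults can be overridden per project in projects.cfg.
    # A minimal sketch of such a configuration (project id, analyzer and
    # vocab are illustrative, not prescribed by this backend):
    #
    #   [yake-en]
    #   name=YAKE English
    #   language=en
    #   backend=yake
    #   analyzer=snowball(english)
    #   vocab=yso
    #   max_ngram_size=3
    #   label_types=prefLabel,altLabel
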
    @property
    def is_trained(self):
        return True

    @property
    def label_types(self) -> list[URIRef]:
        if isinstance(self.params["label_types"], str):  # Label types set by user
            label_types = [lt.strip() for lt in self.params["label_types"].split(",")]
            self._validate_label_types(label_types)
        else:
            label_types = self.params["label_types"]  # The defaults
        return [getattr(SKOS, lt) for lt in label_types]

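    # For example, label_types=prefLabel,altLabel in projects.cfg arrives here
    # as a comma-separated string and becomes [SKOS.prefLabel, SKOS.altLabel];
    # the list-valued default above is used as-is.
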
    def _validate_label_types(self, label_types: list[str]) -> None:
        for lt in label_types:
            if lt not in ("prefLabel", "altLabel", "hiddenLabel"):
                raise ConfigurationException(
                    f"invalid label type {lt}", backend_id=self.backend_id
                )

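    # Only the three SKOS label properties are accepted; e.g.
    # label_types=prefLabel,wrongLabel raises a ConfigurationException.
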
    def initialize(self, parallel: bool = False) -> None:
        self._initialize_index()

    def _initialize_index(self) -> None:
        if self._index is None:
            path = os.path.join(self.datadir, self.INDEX_FILE)
            if os.path.exists(path):
                self._index = joblib.load(path)
                self.debug(f"Loaded index from {path} with {len(self._index)} labels")
            else:
                self.info("Creating index")
                self._index = self._create_index()
                self._save_index(path)
                self.info(f"Created index with {len(self._index)} labels")

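    # The index is built lazily and persisted with joblib, so only the first
    # initialization of a project's data directory pays the construction
    # cost; later runs load the serialized index from disk.
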
    def _save_index(self, path: str) -> None:
        annif.util.atomic_save(
            self._index, self.datadir, self.INDEX_FILE, method=joblib.dump
        )

    def _create_index(self) -> dict[str, set[str]]:
        index = defaultdict(set)
        skos_vocab = self.project.vocab.skos
        for concept in skos_vocab.concepts:
            uri = str(concept)
            labels_by_lang = skos_vocab.get_concept_labels(concept, self.label_types)
            for label in labels_by_lang[self.params["language"]]:
                label = self._normalize_label(label)
                index[label].add(uri)
        index.pop("", None)  # Remove possible empty string entry
        return dict(index)

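    # The resulting index maps each normalized label to the set of concept
    # URIs carrying it, e.g. (with made-up URIs):
    #   {"cat": {"http://example.org/c1"},
    #    "dog war": {"http://example.org/c2", "http://example.org/c3"}}
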
    def _normalize_label(self, label: str) -> str:
        label = str(label)
        if annif.util.boolean(self.params["remove_parentheses"]):
            label = re.sub(r" \(.*\)", "", label)
        normalized_label = self._normalize_phrase(label)
        return self._sort_phrase(normalized_label)

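    # Example, assuming remove_parentheses=True and a stemming analyzer such
    # as snowball(english):
    #   "Dogs (domestic animals)" -> "Dogs" -> "dog" (tokenized and stemmed)
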
    def _normalize_phrase(self, phrase: str) -> str:
        return " ".join(self.project.analyzer.tokenize_words(phrase, filter=False))

    def _sort_phrase(self, phrase: str) -> str:
        words = phrase.split()
        return " ".join(sorted(words))

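    # Sorting the words makes index lookups insensitive to word order: both
    # "war dog" and "dog war" map to the same key "dog war" (alphabetical).
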
    def _suggest(
        self, doc: Document, params: dict[str, Any]
    ) -> list[SubjectSuggestion]:
        self.debug(
            f'Suggesting subjects for text "{doc.text[:20]}..." (len={len(doc.text)})'
        )
        limit = int(params["limit"])

        self._kw_extractor = yake.KeywordExtractor(
            lan=params["language"],
            n=int(params["max_ngram_size"]),
            dedupLim=float(params["deduplication_threshold"]),
            dedupFunc=params["deduplication_algo"],
            windowsSize=int(params["window_size"]),
            top=int(params["num_keywords"]),
            features=self.params["features"],
        )
        keyphrases = self._kw_extractor.extract_keywords(doc.text)
        suggestions = self._keyphrases2suggestions(keyphrases)

        subject_suggestions = [
            SubjectSuggestion(subject_id=self.project.subjects.by_uri(uri), score=score)
            for uri, score in suggestions[:limit]
            if score > 0.0
        ]
        return subject_suggestions

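    # yake.KeywordExtractor.extract_keywords returns (keyphrase, score) pairs
    # in which a *lower* score marks a better keyphrase; _transform_score
    # below inverts this into the 0..1 higher-is-better range Annif expects.
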
    def _keyphrases2suggestions(
        self, keyphrases: list[tuple[str, float]]
    ) -> list[tuple[str, float]]:
        suggestions = []
        not_matched = []
        for kp, score in keyphrases:
            uris = self._keyphrase2uris(kp)
            for uri in uris:
                suggestions.append((uri, self._transform_score(score)))
            if not uris:
                not_matched.append((kp, self._transform_score(score)))
        # Remove duplicate uris, conflating the scores
        suggestions = self._combine_suggestions(suggestions)
        self.debug(
            "Keyphrases not matched:\n"
            + "\t".join(
                [
                    kp[0] + " " + str(kp[1])
                    for kp in sorted(not_matched, reverse=True, key=lambda kp: kp[1])
                ]
            )
        )
        return suggestions

    def _keyphrase2uris(self, keyphrase: str) -> set[str]:
        keyphrase = self._normalize_phrase(keyphrase)
        keyphrase = self._sort_phrase(keyphrase)
        return self._index.get(keyphrase, set())

    def _transform_score(self, score: float) -> float:
        score = max(score, 0)
        return 1.0 / (score + 1)

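    # Maps YAKE's lower-is-better scores into (0, 1] with higher meaning
    # better: 0.0 -> 1.0, 1.0 -> 0.5, 9.0 -> 0.1. Negative scores, should
    # they ever occur, are clamped to 0 first.
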
    def _combine_suggestions(
        self, suggestions: list[tuple[str, float]]
    ) -> list[tuple[str, float]]:
        combined_suggestions = {}
        for uri, score in suggestions:
            if uri not in combined_suggestions:
                combined_suggestions[uri] = score
            else:
                old_score = combined_suggestions[uri]
                combined_suggestions[uri] = self._combine_scores(score, old_score)
        return list(combined_suggestions.items())

    def _combine_scores(self, score1: float, score2: float) -> float:
        # The result is never smaller than the greater input
        score1 = score1 / 2 + 0.5
        score2 = score2 / 2 + 0.5
        confl = score1 * score2 / (score1 * score2 + (1 - score1) * (1 - score2))
        return (confl - 0.5) * 2

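    # Worked example: combining two scores of 0.8 maps both to 0.9, conflates
    # them to 0.81 / (0.81 + 0.01) ~= 0.988, and rescales to ~= 0.976, which
    # is indeed not smaller than the greater input (0.8).
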
    def _train(self, corpus: DocumentCorpus, params: dict[str, Any], jobs: int = 0):
        raise NotSupportedException("Training yake backend is not possible.")
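
# Usage sketch (assuming a project "yake-en" configured as in the example
# near DEFAULT_PARAMETERS above); the yake backend needs no training:
#
#   cat document.txt | annif suggest yake-en
#
# Attempting `annif train yake-en ...` raises NotSupportedException.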