annif.backend.yake.YakeBackend._normalize_phrase()  (grade A)

Complexity: 1 condition
Size: 2 total lines, 2 code lines
Duplication: 0 lines (0 %)
Importance: 0 changes

Metric  Value
cc      1
eloc    2
nop     2
dl      0
loc     2
rs      10
c       0
b       0
f       0
"""Annif backend using Yake keyword extraction"""

# For license remarks of this backend see README.md:
# https://github.com/NatLibFi/Annif#license.
from __future__ import annotations

import os.path
import re
from collections import defaultdict
from typing import TYPE_CHECKING, Any

import joblib
import yake
from rdflib.namespace import SKOS

import annif.util
from annif.exception import ConfigurationException, NotSupportedException
from annif.suggestion import SubjectSuggestion

from . import backend

if TYPE_CHECKING:
    from rdflib.term import URIRef

    from annif.corpus.document import DocumentCorpus


class YakeBackend(backend.AnnifBackend):
    """Yake based backend for Annif"""

    name = "yake"

    # defaults for uninitialized instances
    _index = None
    _graph = None
    INDEX_FILE = "yake-index"

    DEFAULT_PARAMETERS = {
        "max_ngram_size": 4,
        "deduplication_threshold": 0.9,
        "deduplication_algo": "levs",
        "window_size": 1,
        "num_keywords": 100,
        "features": None,
        "label_types": ["prefLabel", "altLabel"],
        "remove_parentheses": False,
    }
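
    # Illustrative note (added commentary, not in the original source): each
    # of these defaults can be overridden per project in Annif's projects.cfg,
    # where backend parameters appear as plain keys in the project section.
    # A hypothetical configuration might look like:
    #
    #   [yake-en]
    #   name=YAKE English
    #   backend=yake
    #   language=en
    #   vocab=yso
    #   max_ngram_size=3
    #   remove_parentheses=True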

    @property
    def is_trained(self):
        return True

    @property
    def label_types(self) -> list[URIRef]:
        if isinstance(self.params["label_types"], str):  # Label types set by user
            label_types = [lt.strip() for lt in self.params["label_types"].split(",")]
            self._validate_label_types(label_types)
        else:
            label_types = self.params["label_types"]  # The defaults
        return [getattr(SKOS, lt) for lt in label_types]
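
    # Illustrative example (added commentary): with the setting
    # label_types="prefLabel, altLabel" the property above strips whitespace,
    # validates each name and resolves it against the SKOS namespace, i.e.
    # ["prefLabel", "altLabel"] -> [SKOS.prefLabel, SKOS.altLabel].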

    def _validate_label_types(self, label_types: list[str]) -> None:
        for lt in label_types:
            if lt not in ("prefLabel", "altLabel", "hiddenLabel"):
                raise ConfigurationException(
                    f"invalid label type {lt}", backend_id=self.backend_id
                )

    def initialize(self, parallel: bool = False) -> None:
        self._initialize_index()

    def _initialize_index(self) -> None:
        if self._index is None:
            path = os.path.join(self.datadir, self.INDEX_FILE)
            if os.path.exists(path):
                self._index = joblib.load(path)
                self.debug(f"Loaded index from {path} with {len(self._index)} labels")
            else:
                self.info("Creating index")
                self._index = self._create_index()
                self._save_index(path)
                self.info(f"Created index with {len(self._index)} labels")

    def _save_index(self, path: str) -> None:
        annif.util.atomic_save(
            self._index, self.datadir, self.INDEX_FILE, method=joblib.dump
        )

    def _create_index(self) -> dict[str, set[str]]:
        index = defaultdict(set)
        skos_vocab = self.project.vocab.skos
        for concept in skos_vocab.concepts:
            uri = str(concept)
            labels_by_lang = skos_vocab.get_concept_labels(concept, self.label_types)
            for label in labels_by_lang[self.params["language"]]:
                label = self._normalize_label(label)
                index[label].add(uri)
        index.pop("", None)  # Remove possible empty string entry
        return dict(index)
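
    # Illustrative example (added commentary, hypothetical URIs): the index
    # maps a normalized, alphabetically sorted label string to the set of
    # concept URIs carrying that label, e.g.
    #   {"cat": {"http://example.org/c1"},
    #    "cat domestic": {"http://example.org/c1", "http://example.org/c2"}}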

    def _normalize_label(self, label: str) -> str:
        label = str(label)
        if annif.util.boolean(self.params["remove_parentheses"]):
            label = re.sub(r" \(.*\)", "", label)
        normalized_label = self._normalize_phrase(label)
        return self._sort_phrase(normalized_label)

    def _normalize_phrase(self, phrase: str) -> str:
        return " ".join(self.project.analyzer.tokenize_words(phrase, filter=False))

    def _sort_phrase(self, phrase: str) -> str:
        words = phrase.split()
        return " ".join(sorted(words))

    def _suggest(self, text: str, params: dict[str, Any]) -> list[SubjectSuggestion]:
        self.debug(f'Suggesting subjects for text "{text[:20]}..." (len={len(text)})')
        limit = int(params["limit"])

        self._kw_extractor = yake.KeywordExtractor(
            lan=params["language"],
            n=int(params["max_ngram_size"]),
            dedupLim=float(params["deduplication_threshold"]),
            dedupFunc=params["deduplication_algo"],
            windowsSize=int(params["window_size"]),
            top=int(params["num_keywords"]),
            features=self.params["features"],
        )
        keyphrases = self._kw_extractor.extract_keywords(text)
        suggestions = self._keyphrases2suggestions(keyphrases)

        subject_suggestions = [
            SubjectSuggestion(subject_id=self.project.subjects.by_uri(uri), score=score)
            for uri, score in suggestions[:limit]
            if score > 0.0
        ]
        return subject_suggestions
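
    # Added commentary: yake.KeywordExtractor returns (keyphrase, score)
    # pairs where a *lower* score means a better keyphrase. The pairs are
    # matched against the label index, and the scores are inverted into a
    # higher-is-better range by _transform_score below.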

    def _keyphrases2suggestions(
        self, keyphrases: list[tuple[str, float]]
    ) -> list[tuple[str, float]]:
        suggestions = []
        not_matched = []
        for kp, score in keyphrases:
            uris = self._keyphrase2uris(kp)
            for uri in uris:
                suggestions.append((uri, self._transform_score(score)))
            if not uris:
                not_matched.append((kp, self._transform_score(score)))
        # Remove duplicate uris, conflating the scores
        suggestions = self._combine_suggestions(suggestions)
        self.debug(
            "Keyphrases not matched:\n"
            + "\t".join(
                [
                    kp[0] + " " + str(kp[1])
                    for kp in sorted(not_matched, reverse=True, key=lambda kp: kp[1])
                ]
            )
        )
        return suggestions
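
    # Illustrative example (added commentary, hypothetical values): if both
    # "cats" and "felines" match the same concept URI, the two transformed
    # scores are merged by _combine_suggestions/_combine_scores rather than
    # the URI being suggested twice.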

    def _keyphrase2uris(self, keyphrase: str) -> set[str]:
        keyphrase = self._normalize_phrase(keyphrase)
        keyphrase = self._sort_phrase(keyphrase)
        return self._index.get(keyphrase, set())

    def _transform_score(self, score: float) -> float:
        score = max(score, 0)
        return 1.0 / (score + 1)
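
    # Worked example (added commentary): YAKE scores are clamped to >= 0 and
    # inverted, so a score of 0 (best) maps to 1.0, 1 maps to 0.5 and 3 maps
    # to 0.25; all outputs fall in the interval (0, 1].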

    def _combine_suggestions(
        self, suggestions: list[tuple[str, float]]
    ) -> list[tuple[str, float]]:
        combined_suggestions = {}
        for uri, score in suggestions:
            if uri not in combined_suggestions:
                combined_suggestions[uri] = score
            else:
                old_score = combined_suggestions[uri]
                combined_suggestions[uri] = self._combine_scores(score, old_score)
        return list(combined_suggestions.items())

    def _combine_scores(self, score1: float, score2: float) -> float:
        # The result is never smaller than the greater input
        score1 = score1 / 2 + 0.5
        score2 = score2 / 2 + 0.5
        confl = score1 * score2 / (score1 * score2 + (1 - score1) * (1 - score2))
        return (confl - 0.5) * 2
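
    # Worked example (added commentary): combining 0.5 with 0.5 rescales both
    # to 0.75; confl = 0.5625 / (0.5625 + 0.0625) = 0.9, and the result is
    # (0.9 - 0.5) * 2 = 0.8, which indeed exceeds either input. The formula
    # conflates the two rescaled scores like independent probabilities and
    # maps the result back from (0.5, 1] to (0, 1].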

    def _train(self, corpus: DocumentCorpus, params: dict[str, Any], jobs: int = 0):
        raise NotSupportedException("Training yake backend is not possible.")