Passed
Push — issue703-python-3.11-support ( f59527...05d52a )
by Juho
04:06 queued 14s
created

annif.backend.yake.YakeBackend.default_params()   A

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 4
dl 0
loc 4
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
"""Annif backend using Yake keyword extraction"""
2
# For license remarks of this backend see README.md:
3
# https://github.com/NatLibFi/Annif#license.
4
from __future__ import annotations
5
6
import os.path
7
import re
8
from collections import defaultdict
9
from typing import TYPE_CHECKING, Any
10
11
import joblib
12
import yake
13
from rdflib.namespace import SKOS
14
15
import annif.util
16
from annif.exception import ConfigurationException, NotSupportedException
17
from annif.suggestion import SubjectSuggestion
18
19
from . import backend
20
21
if TYPE_CHECKING:
22
    from rdflib.term import URIRef
23
24
    from annif.corpus.document import DocumentCorpus
25
26
27
class YakeBackend(backend.AnnifBackend):
    """Yake based backend for Annif.

    Uses the unsupervised YAKE keyword extraction algorithm and matches the
    extracted keyphrases against normalized SKOS concept labels from the
    project vocabulary.
    """

    name = "yake"

    # defaults for uninitialized instances
    _index = None  # dict mapping normalized label -> set of concept URIs
    _graph = None
    INDEX_FILE = "yake-index"

    DEFAULT_PARAMETERS = {
        "max_ngram_size": 4,
        "deduplication_threshold": 0.9,
        "deduplication_algo": "levs",
        "window_size": 1,
        "num_keywords": 100,
        "features": None,
        "label_types": ["prefLabel", "altLabel"],
        "remove_parentheses": False,
    }

    @property
    def is_trained(self) -> bool:
        # YAKE is unsupervised; the backend is always ready to suggest.
        return True

    @property
    def label_types(self) -> list[URIRef]:
        """Return the SKOS label properties to index as rdflib URIRefs.

        A user-supplied value arrives as a comma-separated string and is
        validated; otherwise the list from DEFAULT_PARAMETERS is used.

        Raises ConfigurationException for an unknown label type.
        """
        label_types = self.params["label_types"]
        if isinstance(label_types, str):  # label types set by user
            label_types = [lt.strip() for lt in label_types.split(",")]
            self._validate_label_types(label_types)
        return [getattr(SKOS, lt) for lt in label_types]

    def _validate_label_types(self, label_types: list[str]) -> None:
        """Raise ConfigurationException if any label type is not a SKOS one."""
        for lt in label_types:
            if lt not in ("prefLabel", "altLabel", "hiddenLabel"):
                raise ConfigurationException(
                    f"invalid label type {lt}", backend_id=self.backend_id
                )

    def initialize(self, parallel: bool = False) -> None:
        self._initialize_index()

    def _initialize_index(self) -> None:
        """Load the label index from disk, creating and saving it if absent."""
        if self._index is None:
            path = os.path.join(self.datadir, self.INDEX_FILE)
            if os.path.exists(path):
                self._index = joblib.load(path)
                self.debug(f"Loaded index from {path} with {len(self._index)} labels")
            else:
                self.info("Creating index")
                self._index = self._create_index()
                self._save_index(path)
                self.info(f"Created index with {len(self._index)} labels")

    def _save_index(self, path: str) -> None:
        # NOTE: the path argument is unused; atomic_save derives the target
        # from datadir + INDEX_FILE. Kept for interface compatibility.
        annif.util.atomic_save(
            self._index, self.datadir, self.INDEX_FILE, method=joblib.dump
        )

    def _create_index(self) -> dict[str, set[str]]:
        """Build the index mapping normalized labels to sets of concept URIs."""
        index = defaultdict(set)
        skos_vocab = self.project.vocab.skos
        for concept in skos_vocab.concepts:
            uri = str(concept)
            labels_by_lang = skos_vocab.get_concept_labels(concept, self.label_types)
            for label in labels_by_lang[self.params["language"]]:
                label = self._normalize_label(label)
                index[label].add(uri)
        index.pop("", None)  # Remove possible empty string entry
        return dict(index)

    def _normalize_label(self, label: str) -> str:
        """Normalize a concept label for word-order-insensitive matching."""
        label = str(label)
        if annif.util.boolean(self.params["remove_parentheses"]):
            # Drop parenthesized qualifiers, e.g. "bank (institution)" -> "bank"
            label = re.sub(r" \(.*\)", "", label)
        normalized_label = self._normalize_phrase(label)
        return self._sort_phrase(normalized_label)

    def _normalize_phrase(self, phrase: str) -> str:
        """Tokenize a phrase with the project analyzer and rejoin the words."""
        return " ".join(self.project.analyzer.tokenize_words(phrase, filter=False))

    def _sort_phrase(self, phrase: str) -> str:
        """Sort the words of a phrase so matching ignores word order."""
        words = phrase.split()
        return " ".join(sorted(words))

    def _suggest(self, text: str, params: dict[str, Any]) -> list[SubjectSuggestion]:
        """Extract keyphrases from text and map them to subject suggestions."""
        self.debug(f'Suggesting subjects for text "{text[:20]}..." (len={len(text)})')
        limit = int(params["limit"])

        # The extractor is cheap to construct and depends on per-call params,
        # so it is created locally for each suggest call.
        kw_extractor = yake.KeywordExtractor(
            lan=params["language"],
            n=int(params["max_ngram_size"]),
            dedupLim=float(params["deduplication_threshold"]),
            dedupFunc=params["deduplication_algo"],
            windowsSize=int(params["window_size"]),
            top=int(params["num_keywords"]),
            # Read from the per-call params like the other settings, so that
            # per-call overrides take effect (previously read self.params).
            features=params["features"],
        )
        keyphrases = kw_extractor.extract_keywords(text)
        suggestions = self._keyphrases2suggestions(keyphrases)

        subject_suggestions = [
            SubjectSuggestion(subject_id=self.project.subjects.by_uri(uri), score=score)
            for uri, score in suggestions[:limit]
            if score > 0.0
        ]
        return subject_suggestions

    def _keyphrases2suggestions(
        self, keyphrases: list[tuple[str, float]]
    ) -> list[tuple[str, float]]:
        """Map (keyphrase, yake-score) pairs to (uri, score) suggestions.

        Unmatched keyphrases are logged at debug level for diagnostics.
        """
        suggestions = []
        not_matched = []
        for kp, score in keyphrases:
            uris = self._keyphrase2uris(kp)
            for uri in uris:
                suggestions.append((uri, self._transform_score(score)))
            if not uris:
                not_matched.append((kp, self._transform_score(score)))
        # Remove duplicate uris, conflating the scores
        suggestions = self._combine_suggestions(suggestions)
        self.debug(
            "Keyphrases not matched:\n"
            + "\t".join(
                [
                    kp[0] + " " + str(kp[1])
                    for kp in sorted(not_matched, reverse=True, key=lambda kp: kp[1])
                ]
            )
        )
        return suggestions

    def _keyphrase2uris(self, keyphrase: str) -> set[str]:
        """Return the set of concept URIs whose label matches the keyphrase."""
        keyphrase = self._normalize_phrase(keyphrase)
        keyphrase = self._sort_phrase(keyphrase)
        # Default must be a set to match the index values and the return type
        # annotation (was previously an empty list).
        return self._index.get(keyphrase, set())

    def _transform_score(self, score: float) -> float:
        """Map a YAKE score (lower is better, >= 0) to a (0, 1] relevance."""
        score = max(score, 0)  # guard against tiny negative values from yake
        return 1.0 / (score + 1)

    def _combine_suggestions(
        self, suggestions: list[tuple[str, float]]
    ) -> list[tuple[str, float]]:
        """Deduplicate suggestions by URI, conflating scores of duplicates."""
        combined_suggestions = {}
        for uri, score in suggestions:
            if uri not in combined_suggestions:
                combined_suggestions[uri] = score
            else:
                old_score = combined_suggestions[uri]
                combined_suggestions[uri] = self._combine_scores(score, old_score)
        return list(combined_suggestions.items())

    def _combine_scores(self, score1: float, score2: float) -> float:
        """Conflate two scores in (0, 1); the result never decreases.

        Uses the conflation of probabilities after rescaling from (0, 1)
        to (0.5, 1) and back.
        """
        # The result is never smaller than the greater input
        score1 = score1 / 2 + 0.5
        score2 = score2 / 2 + 0.5
        confl = score1 * score2 / (score1 * score2 + (1 - score1) * (1 - score2))
        return (confl - 0.5) * 2

    def _train(self, corpus: DocumentCorpus, params: dict[str, Any], jobs: int = 0):
        """YAKE is unsupervised; training is not supported."""
        raise NotSupportedException("Training yake backend is not possible.")
191