annif.backend.stwfsa.StwfsaBackend._suggest()   A
last analyzed

Complexity

Conditions 3

Size

Total Lines 11
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 10
nop 3
dl 0
loc 11
rs 9.9
c 0
b 0
f 0
1
from __future__ import annotations
2
3
import os
4
from typing import TYPE_CHECKING, Any
5
6
from stwfsapy.predictor import StwfsapyPredictor
7
8
from annif.exception import NotInitializedException, NotSupportedException
9
from annif.suggestion import SubjectSuggestion
10
from annif.util import atomic_save, boolean
11
12
from . import backend
13
14
if TYPE_CHECKING:
15
    from annif.corpus.document import DocumentCorpus
16
17
# Names of the backend parameters that are forwarded as keyword arguments to
# StwfsapyPredictor when a model is trained.

# Parameters whose values are kept as strings (URIs).
_KEY_CONCEPT_TYPE_URI = "concept_type_uri"
_KEY_SUBTHESAURUS_TYPE_URI = "sub_thesaurus_type_uri"
_KEY_THESAURUS_RELATION_TYPE_URI = "thesaurus_relation_type_uri"

# Parameters that act as on/off switches.
_KEY_THESAURUS_RELATION_IS_SPECIALISATION = "thesaurus_relation_is_specialisation"
_KEY_REMOVE_DEPRECATED = "remove_deprecated"
_KEY_HANDLE_TITLE_CASE = "handle_title_case"
_KEY_EXTRACT_UPPER_CASE_FROM_BRACES = "extract_upper_case_from_braces"
_KEY_EXTRACT_ANY_CASE_FROM_BRACES = "extract_any_case_from_braces"
_KEY_EXPAND_AMPERSAND_WITH_SPACES = "expand_ampersand_with_spaces"
_KEY_EXPAND_ABBREVIATION_WITH_PUNCTUATION = "expand_abbreviation_with_punctuation"
_KEY_SIMPLE_ENGLISH_PLURAL_RULES = "simple_english_plural_rules"
_KEY_USE_TXT_VEC = "use_txt_vec"
29
30
31
class StwfsaBackend(backend.AnnifBackend):
    """Annif backend that wraps a ``StwfsapyPredictor`` model.

    The trained predictor is stored in the backend's data directory under
    ``MODEL_FILE`` and loaded again by :meth:`initialize`.
    """

    name = "stwfsa"

    # Converter applied to each recognised parameter before it is passed to
    # StwfsapyPredictor in _train(). Parameters not listed here are not
    # forwarded to the predictor.
    STWFSA_PARAMETERS = {
        _KEY_CONCEPT_TYPE_URI: str,
        _KEY_SUBTHESAURUS_TYPE_URI: str,
        _KEY_THESAURUS_RELATION_TYPE_URI: str,
        _KEY_THESAURUS_RELATION_IS_SPECIALISATION: boolean,
        _KEY_REMOVE_DEPRECATED: boolean,
        _KEY_HANDLE_TITLE_CASE: boolean,
        _KEY_EXTRACT_UPPER_CASE_FROM_BRACES: boolean,
        _KEY_EXTRACT_ANY_CASE_FROM_BRACES: boolean,
        _KEY_EXPAND_AMPERSAND_WITH_SPACES: boolean,
        _KEY_EXPAND_ABBREVIATION_WITH_PUNCTUATION: boolean,
        _KEY_SIMPLE_ENGLISH_PLURAL_RULES: boolean,
        # Use the same converter as every other flag. The builtin ``bool``
        # treats any non-empty string as true, so a textual value such as
        # "False" would switch the option on.
        # NOTE(review): assumes parameter values may arrive as strings and
        # that annif.util.boolean parses them -- confirm against annif.util.
        _KEY_USE_TXT_VEC: boolean,
    }

    # Values used for the recognised parameters when none are configured.
    DEFAULT_PARAMETERS = {
        _KEY_CONCEPT_TYPE_URI: "http://www.w3.org/2004/02/skos/core#Concept",
        _KEY_SUBTHESAURUS_TYPE_URI: "http://www.w3.org/2004/02/skos/core#Collection",
        _KEY_THESAURUS_RELATION_TYPE_URI: "http://www.w3.org/2004/02/skos/core#member",
        _KEY_THESAURUS_RELATION_IS_SPECIALISATION: True,
        _KEY_REMOVE_DEPRECATED: True,
        _KEY_HANDLE_TITLE_CASE: True,
        _KEY_EXTRACT_UPPER_CASE_FROM_BRACES: True,
        _KEY_EXTRACT_ANY_CASE_FROM_BRACES: False,
        _KEY_EXPAND_AMPERSAND_WITH_SPACES: True,
        _KEY_EXPAND_ABBREVIATION_WITH_PUNCTUATION: True,
        _KEY_SIMPLE_ENGLISH_PLURAL_RULES: False,
        _KEY_USE_TXT_VEC: False,
    }

    # File name of the stored predictor, relative to the backend's datadir.
    MODEL_FILE = "stwfsa_predictor.zip"

    # Loaded or freshly trained predictor; None until initialize() or
    # _train() has set it.
    _model = None

    def initialize(self, parallel: bool = False) -> None:
        """Load the stored predictor from disk unless one is already set.

        Args:
            parallel: Not used by this backend.

        Raises:
            NotInitializedException: If no model file exists in the data
                directory.
        """
        if self._model is None:
            path = os.path.join(self.datadir, self.MODEL_FILE)
            self.debug(f"Loading STWFSA model from {path}.")
            if os.path.exists(path):
                self._model = StwfsapyPredictor.load(path)
                self.debug("Loaded model.")
            else:
                raise NotInitializedException(
                    f"Model not found at {path}", backend_id=self.backend_id
                )

    def _load_data(self, corpus: DocumentCorpus) -> tuple[list[str], list[list[str]]]:
        """Convert a document corpus into parallel text and label lists.

        Args:
            corpus: The training corpus. The string ``"cached"`` is rejected.

        Returns:
            A pair ``(X, y)`` where ``X`` holds each document's text and
            ``y`` holds, at the same index, the URIs of that document's
            subjects.

        Raises:
            NotSupportedException: If ``corpus`` is ``"cached"`` or empty.
        """
        if corpus == "cached":
            raise NotSupportedException(
                "Training stwfsa project from cached data not supported."
            )
        if corpus.is_empty():
            raise NotSupportedException(
                "Cannot train stwfsa project with no documents."
            )
        self.debug("Transforming training data.")
        X = []
        y = []
        # Single pass over the corpus so texts and labels stay aligned.
        for doc in corpus.documents:
            X.append(doc.text)
            y.append(
                [
                    self.project.subjects[subject_id].uri
                    for subject_id in doc.subject_set
                ]
            )
        return X, y

    def _train(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
        jobs: int = 0,
    ) -> None:
        """Fit a new predictor on ``corpus`` and store it in the data directory.

        Args:
            corpus: The training corpus, passed to :meth:`_load_data`.
            params: Backend parameters. Must contain ``"language"``; entries
                whose key is in ``STWFSA_PARAMETERS`` are converted and
                forwarded to ``StwfsapyPredictor``.
            jobs: Not used by this backend.

        Raises:
            NotSupportedException: Propagated from :meth:`_load_data`.
        """
        X, y = self._load_data(corpus)
        # Keep only the predictor's own parameters, converted to their
        # expected types.
        new_params = {
            key: self.STWFSA_PARAMETERS[key](val)
            for key, val in params.items()
            if key in self.STWFSA_PARAMETERS
        }
        p = StwfsapyPredictor(
            graph=self.project.vocab.as_graph(),
            langs=frozenset([params["language"]]),
            **new_params,
        )
        p.fit(X, y)
        self._model = p
        atomic_save(
            p,
            self.datadir,
            self.MODEL_FILE,
            lambda model, store_path: model.store(store_path),
        )

    def _suggest(self, text: str, params: dict[str, Any]) -> list[SubjectSuggestion]:
        """Return subject suggestions for a single text.

        Args:
            text: The text to analyse.
            params: Not used by this backend.

        Returns:
            One ``SubjectSuggestion`` per predicted URI that the project's
            subject index can resolve; URIs for which ``by_uri`` returns
            ``None`` are skipped.
        """
        self.debug(f'Suggesting subjects for text "{text[:20]}..." (len={len(text)})')
        # suggest_proba takes a batch; send one text and take its result.
        result = self._model.suggest_proba([text])[0]
        suggestions = []
        for uri, score in result:
            subject_id = self.project.subjects.by_uri(uri)
            if subject_id is not None:
                suggestions.append(
                    SubjectSuggestion(subject_id=subject_id, score=score)
                )
        return suggestions
139