Passed
Pull Request — master (#438)
created by unknown at 01:38

annif.backend.stwfsapy.StwfsapyBackend._suggest()   A

Complexity

Conditions 3

Size

Total Lines 18
Code Lines 17

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 17
dl 0
loc 18
rs 9.55
c 0
b 0
f 0
cc 3
nop 3
1
import os
2
from stwfsapy.predictor import StwfsapyPredictor
3
from annif.exception import NotInitializedException, NotSupportedException
4
from annif.suggestion import ListSuggestionResult, SubjectSuggestion
5
from . import backend
6
from annif.util import boolean
7
8
9
# Names of the stwfsapy-specific backend configuration options.
# StwfsapyBackend._train() forwards every one of these that appears in the
# backend parameters as a keyword argument to StwfsapyPredictor, so each
# string must match the corresponding constructor argument name there.
_KEY_CONCEPT_TYPE_URI = 'concept_type_uri'
_KEY_SUBTHESAURUS_TYPE_URI = 'sub_thesaurus_type_uri'
_KEY_THESAURUS_RELATION_TYPE_URI = 'thesaurus_relation_type_uri'
_KEY_THESAURUS_RELATION_IS_SPECIALISATION = (
    'thesaurus_relation_is_specialisation')
_KEY_REMOVE_DEPRECATED = 'remove_deprecated'
_KEY_HANDLE_TITLE_CASE = 'handle_title_case'
_KEY_EXTRACT_UPPER_CASE_FROM_BRACES = 'extract_upper_case_from_braces'
_KEY_EXTRACT_ANY_CASE_FROM_BRACES = 'extract_any_case_from_braces'
_KEY_EXPAND_AMPERSAND_WITH_SPACES = 'expand_ampersand_with_spaces'
_KEY_EXPAND_ABBREVIATION_WITH_PUNCTUATION = (
    'expand_abbreviation_with_punctuation')
_KEY_SIMPLE_ENGLISH_PLURAL_RULES = 'simple_english_plural_rules'
22
23
24
class StwfsapyBackend(backend.AnnifBackend):
    """Annif backend wrapping the stwfsapy ``StwfsapyPredictor``.

    The predictor is built from the project vocabulary graph, fitted on a
    document corpus and stored under MODEL_FILE in the backend data
    directory, from where ``initialize()`` loads it again.
    """

    name = "stwfsapy"
    needs_subject_index = True

    # Options forwarded to StwfsapyPredictor, each mapped to the converter
    # applied to its configured value before it is passed on.
    STWFSAPY_PARAMETERS = {
        _KEY_CONCEPT_TYPE_URI: str,
        _KEY_SUBTHESAURUS_TYPE_URI: str,
        _KEY_THESAURUS_RELATION_TYPE_URI: str,
        _KEY_THESAURUS_RELATION_IS_SPECIALISATION: boolean,
        _KEY_REMOVE_DEPRECATED: boolean,
        _KEY_HANDLE_TITLE_CASE: boolean,
        _KEY_EXTRACT_UPPER_CASE_FROM_BRACES: boolean,
        _KEY_EXTRACT_ANY_CASE_FROM_BRACES: boolean,
        _KEY_EXPAND_AMPERSAND_WITH_SPACES: boolean,
        _KEY_EXPAND_ABBREVIATION_WITH_PUNCTUATION: boolean,
        _KEY_SIMPLE_ENGLISH_PLURAL_RULES: boolean,
    }

    # Backend-level defaults for the boolean options. The three type-URI
    # options have no default here; if they are not configured they are not
    # forwarded at all (assumes the base class merges these defaults into
    # ``params`` — not visible in this file).
    DEFAULT_PARAMETERS = {
        _KEY_THESAURUS_RELATION_IS_SPECIALISATION: False,
        _KEY_REMOVE_DEPRECATED: True,
        _KEY_HANDLE_TITLE_CASE: True,
        _KEY_EXTRACT_UPPER_CASE_FROM_BRACES: True,
        _KEY_EXTRACT_ANY_CASE_FROM_BRACES: False,
        _KEY_EXPAND_AMPERSAND_WITH_SPACES: True,
        _KEY_EXPAND_ABBREVIATION_WITH_PUNCTUATION: True,
        _KEY_SIMPLE_ENGLISH_PLURAL_RULES: False,
    }

    MODEL_FILE = 'stwfsapy_predictor.zip'

    # The fitted/loaded StwfsapyPredictor; None until initialize() or
    # _train() assigns it.
    _model = None

    def initialize(self):
        """Load the stored predictor from the data directory if not loaded.

        :raises NotInitializedException: if no model file exists yet.
        """
        if self._model is not None:
            return
        path = os.path.join(self.datadir, self.MODEL_FILE)
        self.debug(f'Loading STWFSAPY model from {path}.')
        if not os.path.exists(path):
            raise NotInitializedException(
                f'Model not found at {path}',
                backend_id=self.backend_id)
        self._model = StwfsapyPredictor.load(path)
        self.debug('Loaded model.')

    def _train(self, corpus, params):
        """Fit a new predictor on ``corpus`` and store it in the data dir.

        :param corpus: document corpus providing ``is_empty()`` and
            ``documents`` (each with ``text`` and ``uris``); the string
            ``'cached'`` is rejected.
        :param params: backend parameters; ``language`` selects the label
            language, and keys listed in STWFSAPY_PARAMETERS are converted
            and forwarded to StwfsapyPredictor.
        :raises NotSupportedException: for cached or empty corpora.
        """
        if corpus == 'cached':
            raise NotSupportedException(
                'Training stwfsapy project from cached data not supported.',
                backend_id=self.backend_id)
        if corpus.is_empty():
            raise NotSupportedException(
                'Cannot train stwfsapy project with no documents.',
                backend_id=self.backend_id)
        self.debug("Transforming training data.")
        # Collect texts and gold-standard URIs in a single pass, since
        # corpus.documents may only be iterable once.
        X = []
        y = []
        for doc in corpus.documents:
            X.append(doc.text)
            y.append(doc.uris)
        graph = self.project.vocab.as_graph()
        # Forward only the stwfsapy-specific options, converted to the type
        # the predictor expects.
        predictor_params = {
            key: self.STWFSAPY_PARAMETERS[key](val)
            for key, val in params.items()
            if key in self.STWFSAPY_PARAMETERS
        }
        predictor = StwfsapyPredictor(
            graph=graph,
            langs=frozenset([params['language']]),
            **predictor_params)
        predictor.fit(X, y)
        self._model = predictor
        predictor.store(os.path.join(self.datadir, self.MODEL_FILE))

    def _suggest(self, text, params):
        """Suggest subjects for ``text`` using the loaded predictor.

        Every (uri, score) pair returned by the predictor becomes a
        SubjectSuggestion. The label is looked up in the project subject
        index and left as None when the URI is unknown there.
        """
        self.debug(
            f'Suggesting subjects for text "{text[:20]}..." (len={len(text)})')
        result = self._model.suggest_proba([text])[0]
        suggestions = []
        for uri, score in result:
            subject_id = self.project.subjects.by_uri(uri)
            # by_uri() signals "unknown URI" with None. The id is presumably
            # an integer index, so the first subject has id 0. Test identity,
            # not truthiness, so that subject keeps its label.
            if subject_id is not None:
                label = self.project.subjects[subject_id][1]
            else:
                label = None
            suggestions.append(SubjectSuggestion(uri, label, None, score))
        return ListSuggestionResult(suggestions)
116