Completed
Push — master ( 267cc8...034a4a )
by Osma
15s queued 12s
created

annif.backend.stwfsa   A

Complexity

Total Complexity 12

Size/Duplication

Total Lines 132
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 115
dl 0
loc 132
rs 10
c 0
b 0
f 0
wmc 12

4 Methods

Rating   Name   Duplication   Size   Complexity  
A StwfsaBackend.initialize() 0 11 3
A StwfsaBackend._train() 0 20 2
A StwfsaBackend._suggest() 0 18 3
A StwfsaBackend._load_data() 0 14 4
1
import os
2
from stwfsapy.predictor import StwfsapyPredictor
3
from annif.exception import NotInitializedException, NotSupportedException
4
from annif.suggestion import ListSuggestionResult, SubjectSuggestion
5
from . import backend
6
from annif.util import atomic_save, boolean
7
8
9
_KEY_CONCEPT_TYPE_URI = 'concept_type_uri'
10
_KEY_SUBTHESAURUS_TYPE_URI = 'sub_thesaurus_type_uri'
11
_KEY_THESAURUS_RELATION_TYPE_URI = 'thesaurus_relation_type_uri'
12
_KEY_THESAURUS_RELATION_IS_SPECIALISATION = (
13
    'thesaurus_relation_is_specialisation')
14
_KEY_REMOVE_DEPRECATED = 'remove_deprecated'
15
_KEY_HANDLE_TITLE_CASE = 'handle_title_case'
16
_KEY_EXTRACT_UPPER_CASE_FROM_BRACES = 'extract_upper_case_from_braces'
17
_KEY_EXTRACT_ANY_CASE_FROM_BRACES = 'extract_any_case_from_braces'
18
_KEY_EXPAND_AMPERSAND_WITH_SPACES = 'expand_ampersand_with_spaces'
19
_KEY_EXPAND_ABBREVIATION_WITH_PUNCTUATION = (
20
    'expand_abbreviation_with_punctuation')
21
_KEY_SIMPLE_ENGLISH_PLURAL_RULES = 'simple_english_plural_rules'
22
_KEY_INPUT_LIMIT = 'input_limit'
23
24
25
class StwfsaBackend(backend.AnnifBackend):
    """Annif backend built around a stwfsapy ``StwfsapyPredictor``.

    The backend trains a predictor on a document corpus, stores it in the
    backend's data directory, loads it again on initialization and turns
    its scored URIs into Annif suggestion objects.
    """

    name = "stwfsa"
    needs_subject_index = True

    # Maps every recognised parameter name to the callable that converts
    # its (possibly string-typed) configured value; see _train().
    STWFSA_PARAMETERS = {
        _KEY_CONCEPT_TYPE_URI: str,
        _KEY_SUBTHESAURUS_TYPE_URI: str,
        _KEY_THESAURUS_RELATION_TYPE_URI: str,
        _KEY_THESAURUS_RELATION_IS_SPECIALISATION: boolean,
        _KEY_REMOVE_DEPRECATED: boolean,
        _KEY_HANDLE_TITLE_CASE: boolean,
        _KEY_EXTRACT_UPPER_CASE_FROM_BRACES: boolean,
        _KEY_EXTRACT_ANY_CASE_FROM_BRACES: boolean,
        _KEY_EXPAND_AMPERSAND_WITH_SPACES: boolean,
        _KEY_EXPAND_ABBREVIATION_WITH_PUNCTUATION: boolean,
        _KEY_SIMPLE_ENGLISH_PLURAL_RULES: boolean,
        _KEY_INPUT_LIMIT: int,
    }

    # Default value for each parameter listed in STWFSA_PARAMETERS.
    # NOTE(review): presumably merged with the project configuration by the
    # base class -- that code is not part of this module; confirm there.
    DEFAULT_PARAMETERS = {
        _KEY_CONCEPT_TYPE_URI: 'http://www.w3.org/2004/02/skos/core#Concept',
        _KEY_SUBTHESAURUS_TYPE_URI:
            'http://www.w3.org/2004/02/skos/core#Collection',
        _KEY_THESAURUS_RELATION_TYPE_URI:
            'http://www.w3.org/2004/02/skos/core#member',
        _KEY_THESAURUS_RELATION_IS_SPECIALISATION: True,
        _KEY_REMOVE_DEPRECATED: True,
        _KEY_HANDLE_TITLE_CASE: True,
        _KEY_EXTRACT_UPPER_CASE_FROM_BRACES: True,
        _KEY_EXTRACT_ANY_CASE_FROM_BRACES: False,
        _KEY_EXPAND_AMPERSAND_WITH_SPACES: True,
        _KEY_EXPAND_ABBREVIATION_WITH_PUNCTUATION: True,
        _KEY_SIMPLE_ENGLISH_PLURAL_RULES: False,
        _KEY_INPUT_LIMIT: 0,
    }

    # File name of the stored predictor, relative to ``self.datadir``.
    MODEL_FILE = 'stwfsa_predictor.zip'

    # The loaded or freshly trained predictor; None until initialize() or
    # _train() has set it.
    _model = None

    def initialize(self):
        """Load the stored predictor unless a model is already in memory.

        Raises:
            NotInitializedException: if no model file exists at
                ``<datadir>/<MODEL_FILE>``.
        """
        if self._model is not None:
            return
        path = os.path.join(self.datadir, self.MODEL_FILE)
        self.debug(f'Loading STWFSA model from {path}.')
        if not os.path.exists(path):
            raise NotInitializedException(
                f'Model not found at {path}',
                backend_id=self.backend_id)
        self._model = StwfsapyPredictor.load(path)
        self.debug('Loaded model.')

    def _load_data(self, corpus):
        """Turn a document corpus into parallel training lists.

        Args:
            corpus: a document corpus exposing ``is_empty()`` and
                ``documents``, or the string ``'cached'``.

        Returns:
            A pair ``(X, y)`` where ``X`` holds each document's ``text`` and
            ``y`` the matching document's ``uris``, in corpus order.

        Raises:
            NotSupportedException: if ``corpus`` is ``'cached'`` or empty.
        """
        if corpus == 'cached':
            raise NotSupportedException(
                'Training stwfsa project from cached data not supported.')
        if corpus.is_empty():
            raise NotSupportedException(
                'Cannot train stwfsa project with no documents.')
        self.debug("Transforming training data.")
        X = []
        y = []
        # Single pass over the documents: collect both lists together.
        for doc in corpus.documents:
            X.append(doc.text)
            y.append(doc.uris)
        return X, y

    def _train(self, corpus, params):
        """Fit a new predictor on ``corpus`` and store it on disk.

        Args:
            corpus: training corpus, see _load_data().
            params: mapping of parameter names to values. It must contain
                ``'language'``; keys not listed in STWFSA_PARAMETERS are
                ignored.

        Raises:
            NotSupportedException: propagated from _load_data().
        """
        X, y = self._load_data(corpus)
        # Keep only the recognised parameters and convert each value with
        # its registered converter.
        new_params = {
            key: self.STWFSA_PARAMETERS[key](val)
            for key, val in params.items()
            if key in self.STWFSA_PARAMETERS
        }
        # input_limit is not forwarded to the predictor (presumably it is
        # consumed elsewhere in Annif -- confirm). The default avoids a
        # KeyError when ``params`` does not contain the key at all.
        new_params.pop(_KEY_INPUT_LIMIT, None)
        p = StwfsapyPredictor(
            graph=self.project.vocab.as_graph(),
            langs=frozenset([params['language']]),
            **new_params)
        p.fit(X, y)
        self._model = p
        atomic_save(
            p,
            self.datadir,
            self.MODEL_FILE,
            lambda model, store_path: model.store(store_path))

    def _suggest(self, text, params):
        """Suggest subjects for a single text.

        Args:
            text: the document text to analyse.
            params: unused by this method.

        Returns:
            A ``ListSuggestionResult`` with one ``SubjectSuggestion`` per
            ``(uri, score)`` pair reported by the model for ``text``.
        """
        self.debug(
            f'Suggesting subjects for text "{text[:20]}..." (len={len(text)})')
        # suggest_proba() takes a batch; only one text is sent, so only the
        # first result is used.
        result = self._model.suggest_proba([text])[0]
        suggestions = []
        for uri, score in result:
            subject_id = self.project.subjects.by_uri(uri)
            # Compare against None rather than relying on truthiness, so a
            # valid but falsy id such as 0 still gets its label.
            # NOTE(review): assumes by_uri() signals an unknown URI with
            # None -- confirm against the subject index implementation.
            if subject_id is not None:
                label = self.project.subjects[subject_id][1]
            else:
                label = None
            suggestions.append(SubjectSuggestion(
                uri,
                label,
                None,
                score))
        return ListSuggestionResult(suggestions)
132