Passed
Pull Request — master (#438)
by
unknown
08:46
created

annif.backend.stwfsapy   A

Complexity

Total Complexity 9

Size/Duplication

Total Lines 120
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 9
eloc 104
dl 0
loc 120
rs 10
c 0
b 0
f 0

3 Methods

Rating   Name   Duplication   Size   Complexity  
A StwfsapyBackend.initialize() 0 11 3
A StwfsapyBackend._suggest() 0 18 3
A StwfsapyBackend._train() 0 27 3
1
import os
2
from rdflib import Graph
3
from rdflib.util import guess_format
4
from stwfsapy.predictor import StwfsapyPredictor
5
from annif.exception import NotInitializedException, NotSupportedException
6
from annif.suggestion import ListSuggestionResult, SubjectSuggestion
7
from . import backend
8
from annif.util import boolean
9
10
11
_KEY_GRAPH_PATH = 'graph_path'
12
_KEY_CONCEPT_TYPE_URI = 'concept_type_uri'
13
_KEY_SUBTHESAURUS_TYPE_URI = 'sub_thesaurus_type_uri'
14
_KEY_THESAURUS_RELATION_TYPE_URI = 'thesaurus_relation_type_uri'
15
_KEY_THESAURUS_RELATION_IS_SPECIALISATION = (
16
    'thesaurus_relation_is_specialisation')
17
_KEY_REMOVE_DEPRECATED = 'remove_deprecated'
18
_KEY_HANDLE_TITLE_CASE = 'handle_title_case'
19
_KEY_EXTRACT_UPPER_CASE_FROM_BRACES = 'extract_upper_case_from_braces'
20
_KEY_EXTRACT_ANY_CASE_FROM_BRACES = 'extract_any_case_from_braces'
21
_KEY_EXPAND_AMPERSAND_WITH_SPACES = 'expand_ampersand_with_spaces'
22
_KEY_EXPAND_ABBREVIATION_WITH_PUNCTUATION = (
23
    'expand_abbreviation_with_punctuation')
24
_KEY_SIMPLE_ENGLISH_PLURAL_RULES = 'simple_english_plural_rules'
25
26
27
class StwfsapyBackend(backend.AnnifBackend):
    """Annif backend wrapping a stwfsapy ``StwfsapyPredictor``.

    The predictor is trained from a document corpus plus an RDF vocabulary
    graph and persisted as ``MODEL_FILE`` inside the backend's data
    directory.
    """

    name = "stwfsapy"
    needs_subject_index = False

    # Configuration keys recognised by this backend, mapped to the converter
    # applied to the raw configuration value. Every key except the graph
    # path is forwarded as a keyword argument to StwfsapyPredictor.
    STWFSAPY_PARAMETERS = {
        _KEY_GRAPH_PATH: str,
        _KEY_CONCEPT_TYPE_URI: str,
        _KEY_SUBTHESAURUS_TYPE_URI: str,
        _KEY_THESAURUS_RELATION_TYPE_URI: str,
        _KEY_THESAURUS_RELATION_IS_SPECIALISATION: boolean,
        _KEY_REMOVE_DEPRECATED: boolean,
        _KEY_HANDLE_TITLE_CASE: boolean,
        _KEY_EXTRACT_UPPER_CASE_FROM_BRACES: boolean,
        _KEY_EXTRACT_ANY_CASE_FROM_BRACES: boolean,
        _KEY_EXPAND_AMPERSAND_WITH_SPACES: boolean,
        _KEY_EXPAND_ABBREVIATION_WITH_PUNCTUATION: boolean,
        _KEY_SIMPLE_ENGLISH_PLURAL_RULES: boolean,
    }

    # Default values for the boolean options. NOTE(review): presumably
    # merged into ``params`` by the AnnifBackend base class -- not visible
    # here.
    DEFAULT_PARAMETERS = {
        _KEY_THESAURUS_RELATION_IS_SPECIALISATION: False,
        _KEY_REMOVE_DEPRECATED: True,
        _KEY_HANDLE_TITLE_CASE: True,
        _KEY_EXTRACT_UPPER_CASE_FROM_BRACES: True,
        _KEY_EXTRACT_ANY_CASE_FROM_BRACES: False,
        _KEY_EXPAND_AMPERSAND_WITH_SPACES: True,
        _KEY_EXPAND_ABBREVIATION_WITH_PUNCTUATION: True,
        _KEY_SIMPLE_ENGLISH_PLURAL_RULES: False,
    }

    MODEL_FILE = 'stwfsapy_predictor.zip'

    # The StwfsapyPredictor in use; None until initialize() loads it or
    # _train() creates it.
    _model = None

    def initialize(self):
        """Load the stored predictor from the data directory, if needed.

        Does nothing when a model is already held in memory.

        Raises:
            NotInitializedException: if no model file exists at
                ``datadir/MODEL_FILE``.
        """
        if self._model is None:
            path = os.path.join(self.datadir, self.MODEL_FILE)
            self.debug(f'Loading STWFSAPY model from {path}.')
            if os.path.exists(path):
                self._model = StwfsapyPredictor.load(path)
                self.debug('Loaded model.')
            else:
                raise NotInitializedException(
                    f'Model not found at {path}',
                    backend_id=self.backend_id)

    def _train(self, corpus, params):
        """Train a StwfsapyPredictor on *corpus* and store it in datadir.

        ``params`` must contain the graph path key and ``'language'``. The
        RDF graph is loaded from the graph path, with its serialization
        format guessed from the file name. Every other key listed in
        STWFSAPY_PARAMETERS is converted and passed to the predictor.

        Raises:
            NotSupportedException: when *corpus* is ``'cached'`` or has no
                documents.
        """
        if corpus == 'cached':
            raise NotSupportedException(
                'Training stwfsapy project from cached data not supported.',
                backend_id=self.backend_id)
        if corpus.is_empty():
            raise NotSupportedException(
                'Cannot train stwfsapy project with no documents.',
                backend_id=self.backend_id)
        self.debug("Transforming training data.")
        # Collect texts and subject URIs in a single pass, so the corpus is
        # read only once (and still works if it is a one-shot iterator).
        X = []
        y = []
        for doc in corpus.documents:
            X.append(doc.text)
            y.append(doc.uris)
        graph_path = params[_KEY_GRAPH_PATH]
        graph = Graph()
        # Graph.parse is the replacement for Graph.load, which was
        # deprecated and removed in rdflib 6.
        graph.parse(graph_path, format=guess_format(graph_path))
        # The graph path has been consumed above; the remaining recognised
        # keys become StwfsapyPredictor keyword arguments.
        predictor_params = {
            key: self.STWFSAPY_PARAMETERS[key](val)
            for key, val in params.items()
            if key in self.STWFSAPY_PARAMETERS and key != _KEY_GRAPH_PATH
        }
        predictor = StwfsapyPredictor(
            graph=graph,
            langs=frozenset([params['language']]),
            **predictor_params)
        predictor.fit(X, y)
        self._model = predictor
        predictor.store(os.path.join(self.datadir, self.MODEL_FILE))

    def _suggest(self, text, params):
        """Return subject suggestions for *text* as a ListSuggestionResult.

        Each (uri, score) pair from the predictor becomes a
        SubjectSuggestion. The label is looked up in the project's subject
        index and is None for URIs not found there. The third
        SubjectSuggestion field (presumably the notation) is always None.
        """
        self.debug(
            f'Suggesting subjects for text "{text[:20]}..." (len={len(text)})')
        result = self._model.suggest_proba([text])[0]
        suggestions = []
        for uri, score in result:
            subject_id = self.project.subjects.by_uri(uri)
            # Compare against None explicitly: a valid subject id may be
            # falsy (e.g. 0 if ids are positional indices) and must still
            # get its label.
            if subject_id is not None:
                label = self.project.subjects[subject_id][1]
            else:
                label = None
            suggestions.append(SubjectSuggestion(uri, label, None, score))
        return ListSuggestionResult(suggestions)
120