Completed
Pull Request — master (#344)
by Osma
06:44
created

annif.backend.maui   A

Complexity

Total Complexity 28

Size/Duplication

Total Lines 149
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 28
eloc 123
dl 0
loc 149
rs 10
c 0
b 0
f 0

12 Methods

Rating   Name   Duplication   Size   Complexity  
A MauiBackend._create_train_file() 0 8 3
A MauiBackend.train() 0 9 2
A MauiBackend._wait_for_train() 0 14 4
A MauiBackend.tagger_url() 0 3 1
A MauiBackend.tagger() 0 8 2
A MauiBackend.endpoint() 0 8 2
A MauiBackend._upload_vocabulary() 0 8 2
A MauiBackend._upload_train_file() 0 10 3
A MauiBackend._initialize_tagger() 0 17 2
A MauiBackend._suggest() 0 6 2
A MauiBackend._response_to_result() 0 11 2
A MauiBackend._suggest_request() 0 15 3
1
"""Maui backend that makes calls to a Maui Server instance using its API"""
2
3
4
import time
5
import os.path
6
import json
7
import requests
8
import requests.exceptions
9
from annif.exception import ConfigurationException
10
from annif.exception import NotSupportedException
11
from annif.exception import OperationFailedException
12
from annif.suggestion import SubjectSuggestion, ListSuggestionResult
13
from . import backend
14
15
16
class MauiBackend(backend.AnnifBackend):
    """Backend that delegates training and suggestion to a Maui Server.

    All work is done over HTTP using ``requests``: the tagger is deleted
    and re-created, the vocabulary and the training documents are uploaded,
    and suggestions are fetched from the tagger's ``/suggest`` resource.
    """

    name = "maui"

    # Name of the JSON Lines training file written under self.datadir.
    TRAIN_FILE = 'maui-train.jsonl'

    @property
    def endpoint(self):
        """Return the value of the ``endpoint`` backend parameter.

        Raises:
            ConfigurationException: if ``endpoint`` is not configured.
        """
        try:
            return self.params['endpoint']
        except KeyError:
            raise ConfigurationException(
                "endpoint must be set in project configuration",
                backend_id=self.backend_id)

    @property
    def tagger(self):
        """Return the value of the ``tagger`` backend parameter.

        Raises:
            ConfigurationException: if ``tagger`` is not configured.
        """
        try:
            return self.params['tagger']
        except KeyError:
            raise ConfigurationException(
                "tagger must be set in project configuration",
                backend_id=self.backend_id)

    @property
    def tagger_url(self):
        """Return the URL of the configured tagger.

        The endpoint and the tagger id are concatenated without any
        separator, so the configured endpoint has to end with a slash.
        """
        return self.endpoint + self.tagger

    def _initialize_tagger(self):
        """Delete any existing tagger and create a fresh one.

        Raises:
            OperationFailedException: if either HTTP request fails, or if
                creating the tagger returns an error status.
        """
        self.info("Initializing Maui Service tagger '{}'".format(self.tagger))

        # try to delete the tagger in case it already exists; the status
        # code is only logged, since a missing tagger is not an error here
        try:
            resp = requests.delete(self.tagger_url)
        except requests.exceptions.RequestException as err:
            raise OperationFailedException(err) from err
        self.debug("Trying to delete tagger {} returned status code {}"
                   .format(self.tagger, resp.status_code))

        # create a new tagger
        data = {'id': self.tagger, 'lang': self.params['language']}
        try:
            resp = requests.post(self.endpoint, data=data)
            self.debug("Trying to create tagger {} returned status code {}"
                       .format(self.tagger, resp.status_code))
            resp.raise_for_status()
        except requests.exceptions.RequestException as err:
            raise OperationFailedException(err) from err

    def _upload_vocabulary(self, project):
        """Upload the project vocabulary (SKOS) to the tagger.

        Raises:
            OperationFailedException: if the HTTP request fails.
        """
        self.info("Uploading vocabulary")
        try:
            resp = requests.put(self.tagger_url + '/vocab',
                                data=project.vocab.as_skos())
            resp.raise_for_status()
        except requests.exceptions.RequestException as err:
            raise OperationFailedException(err) from err

    def _create_train_file(self, corpus):
        """Write the corpus documents to the training file as JSON Lines.

        Each line is a JSON object with the document text under ``content``
        and its labels under ``topics``.
        """
        self.info("Creating train file")
        train_path = os.path.join(self.datadir, self.TRAIN_FILE)
        with open(train_path, 'w') as train_file:
            for doc in corpus.documents:
                doc_obj = {'content': doc.text, 'topics': list(doc.labels)}
                json_doc = json.dumps(doc_obj)
                print(json_doc, file=train_file)

    def _upload_train_file(self):
        """Post the previously created training file to the tagger.

        Raises:
            OperationFailedException: if the HTTP request fails.
        """
        self.info("Uploading training documents")
        train_path = os.path.join(self.datadir, self.TRAIN_FILE)
        # the open file object is handed to requests as the request body
        with open(train_path, 'rb') as train_file:
            try:
                resp = requests.post(self.tagger_url + '/train',
                                     data=train_file)
                resp.raise_for_status()
            except requests.exceptions.RequestException as err:
                raise OperationFailedException(err) from err

    def _wait_for_train(self):
        """Poll the training status once per second until it is completed.

        Raises:
            OperationFailedException: if a status request fails or the
                status response cannot be interpreted.
        """
        self.info("Waiting for training to be completed...")
        while True:
            try:
                resp = requests.get(self.tagger_url + "/train")
                resp.raise_for_status()
                completed = resp.json()['completed']
            except requests.exceptions.RequestException as err:
                raise OperationFailedException(err) from err
            except (KeyError, TypeError, ValueError) as err:
                # body was not JSON, or lacked the 'completed' field
                raise OperationFailedException(
                    "unexpected training status response: {}"
                    .format(err)) from err

            if completed:
                self.info("Training completed.")
                return
            time.sleep(1)

    def train(self, corpus, project):
        """Train the Maui tagger on the given corpus.

        Raises:
            NotSupportedException: if the corpus is empty.
            OperationFailedException: if any step against the server fails.
        """
        if corpus.is_empty():
            raise NotSupportedException('training backend {} with no documents'
                                        .format(self.backend_id))
        self._initialize_tagger()
        self._upload_vocabulary(project)
        self._create_train_file(corpus)
        self._upload_train_file()
        self._wait_for_train()

    def _suggest_request(self, text):
        """Post the text to the tagger and return the decoded JSON response.

        Failures are logged as warnings; None is returned if the request
        fails or the response body is not valid JSON.
        """
        data = {'text': text}

        try:
            resp = requests.post(self.tagger_url + '/suggest', data=data)
            resp.raise_for_status()
        except requests.exceptions.RequestException as err:
            self.warning("HTTP request failed: {}".format(err))
            return None

        try:
            return resp.json()
        except ValueError as err:
            self.warning("JSON decode failed: {}".format(err))
            return None

    def _response_to_result(self, response, project):
        """Convert a decoded suggest response into a ListSuggestionResult.

        Topics whose probability is not above zero are dropped. If the
        response does not have the expected structure, a warning is logged
        and an empty result is returned.
        """
        try:
            return ListSuggestionResult(
                [SubjectSuggestion(uri=h['id'],
                                   label=h['label'],
                                   score=h['probability'])
                 for h in response['topics']
                 if h['probability'] > 0.0], project.subjects)
        except (KeyError, TypeError, ValueError) as err:
            # KeyError covers responses missing 'topics' or a topic field
            self.warning("Problem interpreting JSON data: {}".format(err))
            return ListSuggestionResult([], project.subjects)

    def _suggest(self, text, project, params):
        """Return suggestions for the text, or an empty result on failure.

        The ``params`` argument is accepted but not used by this backend.
        """
        response = self._suggest_request(text)
        if response:
            return self._response_to_result(response, project)
        else:
            return ListSuggestionResult([], project.subjects)
149