Completed
Push — master ( 5372d7...4d0462 )
by Osma
18s queued 13s
created

annif.backend.maui.MauiBackend.train()   A

Complexity

Conditions 2

Size

Total Lines 9
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 9
dl 0
loc 9
rs 9.95
c 0
b 0
f 0
cc 2
nop 2
1
"""Maui backend that makes calls to a Maui Server instance using its API"""
2
3
4
import time
5
import os.path
6
import json
7
import requests
8
import requests.exceptions
9
from annif.exception import ConfigurationException
10
from annif.exception import NotSupportedException
11
from annif.exception import OperationFailedException
12
from annif.suggestion import SubjectSuggestion, ListSuggestionResult
13
from . import backend
14
15
16
class MauiBackend(backend.AnnifBackend):
    """Backend that delegates training and suggestion to a Maui Server
    instance through its HTTP API.

    The ``endpoint`` and ``tagger`` project settings are concatenated to
    form the URL of the tagger resource on the server.
    """

    name = "maui"

    # Name of the JSON Lines training file written into the backend datadir.
    TRAIN_FILE = 'maui-train.jsonl'

    def endpoint(self, params):
        """Return the Maui Server endpoint URL configured in *params*.

        :raises ConfigurationException: if ``endpoint`` is not set.
        """
        try:
            return params['endpoint']
        except KeyError:
            raise ConfigurationException(
                "endpoint must be set in project configuration",
                backend_id=self.backend_id) from None

    def tagger(self, params):
        """Return the Maui tagger identifier configured in *params*.

        :raises ConfigurationException: if ``tagger`` is not set.
        """
        try:
            return params['tagger']
        except KeyError:
            raise ConfigurationException(
                "tagger must be set in project configuration",
                backend_id=self.backend_id) from None

    def tagger_url(self, params):
        """Return the tagger resource URL (endpoint + tagger id)."""
        return self.endpoint(params) + self.tagger(params)

    def _initialize_tagger(self, params):
        """(Re)create the tagger on the Maui Server.

        Any existing tagger with the same id is deleted first. The status of
        the delete is only logged, since the tagger may not exist yet.

        :raises OperationFailedException: if the server cannot be reached or
            creating the tagger returns an error status.
        """
        self.info(
            "Initializing Maui Service tagger '{}'".format(
                self.tagger(params)))

        # try to delete the tagger in case it already exists; a non-success
        # status is tolerated, but a transport failure is reported like the
        # other HTTP calls in this backend instead of leaking a raw
        # requests exception
        try:
            resp = requests.delete(self.tagger_url(params))
        except requests.exceptions.RequestException as err:
            raise OperationFailedException(err) from err
        self.debug("Trying to delete tagger {} returned status code {}"
                   .format(self.tagger(params), resp.status_code))

        # create a new tagger
        data = {'id': self.tagger(params), 'lang': params['language']}
        try:
            resp = requests.post(self.endpoint(params), data=data)
            self.debug("Trying to create tagger {} returned status code {}"
                       .format(self.tagger(params), resp.status_code))
            resp.raise_for_status()
        except requests.exceptions.RequestException as err:
            raise OperationFailedException(err) from err

    def _upload_vocabulary(self, params):
        """PUT the project vocabulary, serialized by ``as_skos()``, to the
        tagger's ``/vocab`` resource.

        :raises OperationFailedException: on any HTTP/transport failure.
        """
        self.info("Uploading vocabulary")
        try:
            resp = requests.put(self.tagger_url(params) + '/vocab',
                                data=self.project.vocab.as_skos())
            resp.raise_for_status()
        except requests.exceptions.RequestException as err:
            raise OperationFailedException(err) from err

    def _create_train_file(self, corpus):
        """Write *corpus* to ``TRAIN_FILE`` in the datadir as JSON Lines.

        Each line is an object with ``content`` (document text) and
        ``topics`` (list of document labels).
        """
        self.info("Creating train file")
        train_path = os.path.join(self.datadir, self.TRAIN_FILE)
        # json.dumps escapes non-ASCII by default; the explicit encoding
        # just keeps the file independent of the platform locale
        with open(train_path, 'w', encoding='utf-8') as train_file:
            for doc in corpus.documents:
                doc_obj = {'content': doc.text, 'topics': list(doc.labels)}
                json_doc = json.dumps(doc_obj)
                print(json_doc, file=train_file)

    def _upload_train_file(self, params):
        """POST the previously written training file to ``/train``.

        :raises OperationFailedException: on any HTTP/transport failure.
        """
        self.info("Uploading training documents")
        train_path = os.path.join(self.datadir, self.TRAIN_FILE)
        with open(train_path, 'rb') as train_file:
            try:
                resp = requests.post(self.tagger_url(params) + '/train',
                                     data=train_file)
                resp.raise_for_status()
            except requests.exceptions.RequestException as err:
                raise OperationFailedException(err) from err

    def _wait_for_train(self, params):
        """Poll ``/train`` once per second until the server reports
        ``completed``.

        NOTE(review): there is no overall timeout; this blocks for as long
        as the server keeps reporting training as incomplete.

        :raises OperationFailedException: on HTTP/transport failure or if the
            status response is not a JSON object with a ``completed`` key.
        """
        self.info("Waiting for training to be completed...")
        while True:
            try:
                resp = requests.get(self.tagger_url(params) + "/train")
                resp.raise_for_status()
            except requests.exceptions.RequestException as err:
                raise OperationFailedException(err) from err

            # a malformed status reply would otherwise surface as a bare
            # ValueError/KeyError/TypeError from deep inside training
            try:
                completed = resp.json()['completed']
            except (ValueError, KeyError, TypeError) as err:
                raise OperationFailedException(
                    "Unexpected training status response: {}".format(err)
                ) from err

            if completed:
                self.info("Training completed.")
                return
            time.sleep(1)

    def _train(self, corpus, params):
        """Train the remote tagger: recreate it, upload the vocabulary and
        the training documents, then wait for training to finish.

        :raises NotSupportedException: if *corpus* has no documents.
        """
        if corpus.is_empty():
            raise NotSupportedException('training backend {} with no documents'
                                        .format(self.backend_id))
        self._initialize_tagger(params)
        self._upload_vocabulary(params)
        self._create_train_file(corpus)
        self._upload_train_file(params)
        self._wait_for_train(params)

    def _suggest_request(self, text, params):
        """POST *text* to the tagger's ``/suggest`` resource.

        :returns: the decoded JSON response, or ``None`` (after logging a
            warning) if the HTTP request or JSON decoding fails.
        """
        data = {'text': text}
        headers = {"Content-Type":
                   "application/x-www-form-urlencoded; charset=UTF-8"}

        try:
            resp = requests.post(self.tagger_url(params) + '/suggest',
                                 data=data,
                                 headers=headers)
            resp.raise_for_status()
        except requests.exceptions.RequestException as err:
            self.warning("HTTP request failed: {}".format(err))
            return None

        try:
            return resp.json()
        except ValueError as err:
            self.warning("JSON decode failed: {}".format(err))
            return None

    def _response_to_result(self, response):
        """Convert a decoded ``/suggest`` response into a suggestion result.

        Only topics with a probability greater than zero are kept. If the
        response does not have the expected structure (missing keys, wrong
        types), a warning is logged and an empty result is returned.
        """
        try:
            return ListSuggestionResult(
                [SubjectSuggestion(uri=h['id'],
                                   label=h['label'],
                                   score=h['probability'])
                 for h in response['topics']
                 if h['probability'] > 0.0], self.project.subjects)
        # KeyError covers replies lacking 'topics' or per-topic fields,
        # which previously escaped this best-effort handler
        except (KeyError, TypeError, ValueError) as err:
            self.warning("Problem interpreting JSON data: {}".format(err))
            return ListSuggestionResult([], self.project.subjects)

    def _suggest(self, text, params):
        """Return subject suggestions for *text*; empty on any failure or
        an empty response."""
        response = self._suggest_request(text, params)
        if response:
            return self._response_to_result(response)
        else:
            return ListSuggestionResult([], self.project.subjects)
152