Passed
Pull Request — master (#663)
by Juho
02:48
created

annif.rest.suggest_batch()   A

Complexity

Conditions 3

Size

Total Lines 13
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 8
nop 2
dl 0
loc 13
rs 10
c 0
b 0
f 0
1
"""Definitions for REST API operations. These are wired via Connexion to
2
methods defined in the Swagger specification."""
3
4
import importlib
5
6
import connexion
7
8
import annif.registry
9
from annif.corpus import Document, DocumentList, SubjectSet
10
from annif.exception import AnnifException
11
from annif.project import Access
12
from annif.suggestion import SuggestionFilter
13
14
15
def project_not_found_error(project_id):
    """Build the Connexion problem object (HTTP 404) used when the requested
    project does not exist; the project identifier is quoted in the detail."""

    message = "Project '{}' not found".format(project_id)
    return connexion.problem(status=404, title="Project not found", detail=message)
23
24
25
def server_error(err):
    """Build the Connexion problem object (HTTP 503) for a project or backend
    failure. The detail text is whatever ``err.format_message()`` returns."""

    message = err.format_message()
    return connexion.problem(
        status=503,
        title="Service unavailable",
        detail=message,
    )
32
33
34
def language_not_supported_error(lang):
    """Build the Connexion problem object (HTTP 400) returned when the
    requested language is not supported by the vocabulary"""

    message = f'language "{lang}" not supported by vocabulary'
    return connexion.problem(status=400, title="Bad Request", detail=message)
42
43
44
def show_info():
    """return version of annif and a title for the api according to Swagger spec

    The result is a dict with the keys "title" (a fixed string) and "version"
    (looked up from the installed "annif" distribution metadata). The lookup
    raises importlib.metadata.PackageNotFoundError if no such distribution
    is installed."""

    # A bare ``import importlib`` does not guarantee that the ``metadata``
    # submodule has been loaded, so ``importlib.metadata`` may raise
    # AttributeError depending on what else happened to be imported first.
    # Importing the function explicitly removes that dependence.
    from importlib.metadata import version

    return {"title": "Annif REST API", "version": version("annif")}
48
49
50
def list_projects():
    """Return a dict whose "projects" key holds the dump of every project
    obtained from the registry with public access, formatted according to
    the Swagger spec"""

    public_projects = annif.registry.get_projects(min_access=Access.public)
    dumped = [project.dump() for project in public_projects.values()]
    return {"projects": dumped}
59
60
61
def show_project(project_id):
    """Return the dump of one project formatted according to the Swagger spec,
    or a 404 problem object when the registry lookup raises ValueError"""

    try:
        found = annif.registry.get_project(project_id, min_access=Access.hidden)
    except ValueError:
        return project_not_found_error(project_id)
    else:
        return found.dump()
69
70
71
def _suggestion_to_dict(suggestion, subject_index, language):
72
    subject = subject_index[suggestion.subject_id]
73
    return {
74
        "uri": subject.uri,
75
        "label": subject.labels[language],
76
        "notation": subject.notation,
77
        "score": suggestion.score,
78
    }
79
80
81
def suggest(project_id, body):
    """Suggest subjects for a single text and return a dict with the results
    formatted according to the Swagger spec.

    Only the "language", "limit" and "threshold" entries of body are passed
    on as parameters; the text itself is read from body["text"]."""

    parameters = {
        name: body[name] for name in ("language", "limit", "threshold") if name in body
    }
    outcome = _suggest(project_id, [{"text": body["text"]}], parameters)

    if not isinstance(outcome, list):
        return outcome  # connexion problem
    return outcome[0]  # successful operation: results of the only document
95
96
97
def suggest_batch(project_id, body):
    """Suggest subjects for several documents and return a list of dicts with
    results formatted according to the Swagger spec.

    Each result dict also gets an "id" entry copied from the document at the
    same position in body["documents"] (None when that document has no id)."""

    params = body.get("parameters", {})
    documents = body["documents"]
    outcome = _suggest(project_id, documents, params)

    if not isinstance(outcome, list):
        return outcome  # connexion problem

    for position, entry in enumerate(outcome):
        entry["id"] = documents[position].get("id")
    return outcome
110
111
112
def _suggest(project_id, documents, parameters):
    """Run batch suggestion for the given documents on one project.

    documents is an iterable of mappings that each have a "text" entry.
    parameters may hold "language", "limit" (default 10) and "threshold"
    (default 0.0); without a language the project's vocab_lang is used.

    Returns a list with one {"results": [...]} dict per hit set, or a
    Connexion problem object when the project is unknown, the language is
    not among the vocabulary languages, or an AnnifException is raised."""

    texts = [Document(text=entry["text"], subject_set=None) for entry in documents]
    corpus = DocumentList(texts)

    try:
        project = annif.registry.get_project(project_id, min_access=Access.hidden)
    except ValueError:
        return project_not_found_error(project_id)

    try:
        lang = parameters.get("language") or project.vocab_lang
    except AnnifException as err:
        return server_error(err)

    if lang not in project.vocab.languages:
        return language_not_supported_error(lang)

    limit = parameters.get("limit", 10)
    threshold = parameters.get("threshold", 0.0)

    try:
        hit_filter = SuggestionFilter(project.subjects, limit, threshold)
        hit_sets = project.suggest_batch(corpus)
    except AnnifException as err:
        return server_error(err)

    response = []
    for hits in hit_sets:
        kept = hit_filter(hits).as_list()
        response.append(
            {
                "results": [
                    _suggestion_to_dict(hit, project.subjects, lang) for hit in kept
                ]
            }
        )
    return response
154
155
156
def _documents_to_corpus(documents, subject_index):
    """Build a DocumentList from the given document mappings.

    Entries lacking either a "text" or a "subjects" key are skipped. For the
    rest, every subject's "uri" is resolved through subject_index.by_uri()
    and the results are wrapped in a SubjectSet."""

    usable = []
    for doc in documents:
        if "text" not in doc or "subjects" not in doc:
            continue
        resolved = [subject_index.by_uri(subj["uri"]) for subj in doc["subjects"]]
        usable.append(Document(text=doc["text"], subject_set=SubjectSet(resolved)))
    return DocumentList(usable)
168
169
170
def learn(project_id, body):
    """Let the project learn from the documents in body and return an empty
    204 response if successful.

    A problem object is returned instead when the project is unknown or when
    building the corpus or learning raises an AnnifException."""

    try:
        project = annif.registry.get_project(project_id, min_access=Access.hidden)
    except ValueError:
        return project_not_found_error(project_id)

    try:
        project.learn(_documents_to_corpus(body, project.subjects))
    except AnnifException as err:
        return server_error(err)

    return None, 204
185