Passed
Pull Request — master (#1206)
by Konstantin
04:17 queued 01:14
created

ocrd.mets_server   A

Complexity

Total Complexity 35

Size/Duplication

Total Lines 325
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 35
eloc 226
dl 0
loc 325
rs 9.6
c 0
b 0
f 0

23 Methods

Rating   Name   Duplication   Size   Complexity  
A OcrdFileModel.create() 0 3 1
A OcrdFileGroupListModel.create() 0 3 1
A OcrdAgentModel.create() 0 3 1
A OcrdAgentListModel.create() 0 4 1
A OcrdFileListModel.create() 0 12 1
A ClientSideOcrdMets.unique_identifier() 0 3 1
A ClientSideOcrdMets.workspace_path() 0 3 1
A ClientSideOcrdMets.session() 0 3 2
A OcrdMetsServer.shutdown() 0 7 3
A ClientSideOcrdMets.add_agent() 0 2 1
A ClientSideOcrdMets.reload() 0 2 1
A ClientSideOcrdMets.__getattr__() 0 2 1
A ClientSideOcrdMets.find_files() 0 14 5
A ClientSideOcrdMets.stop() 0 6 2
A ClientSideOcrdMets.save() 0 2 1
A ClientSideOcrdMets.__init__() 0 4 3
A ClientSideOcrdMets.agents() 0 6 2
A ClientSideOcrdMets.find_all_files() 0 2 1
A ClientSideOcrdMets.file_groups() 0 3 1
A ClientSideOcrdMets.add_file() 0 19 1
A ClientSideOcrdMets.__str__() 0 2 1
A OcrdMetsServer.__init__() 0 5 1
B OcrdMetsServer.startup() 0 115 2
1
"""
2
# METS server functionality
3
"""
4
import re
5
from os import _exit, chmod
6
from typing import Dict, Optional, Union, List, Tuple
7
from pathlib import Path
8
from urllib.parse import urlparse
9
import socket
10
import atexit
11
12
from fastapi import FastAPI, Request, Form, Response, requests
13
from fastapi.responses import JSONResponse
14
from requests import Session as requests_session
15
from requests.exceptions import ConnectionError
16
from requests_unixsocket import Session as requests_unixsocket_session
17
from pydantic import BaseModel, Field, ValidationError
18
19
import uvicorn
20
21
from ocrd_models import OcrdFile, ClientSideOcrdFile, OcrdAgent, ClientSideOcrdAgent
22
from ocrd_utils import getLogger, deprecated_alias
23
24
#
25
# Models
26
#
27
28
class OcrdFileModel(BaseModel):
    """
    Wire representation of a single file entry as exchanged between
    :py:class:`ClientSideOcrdMets` and :py:class:`OcrdMetsServer`.
    """
    file_grp : str = Field()
    file_id : str = Field()
    mimetype : str = Field()
    page_id : Optional[str] = Field()
    url : Optional[str] = Field()
    local_filename : Optional[str] = Field()

    @staticmethod
    def create(file_grp : str, file_id : str, page_id : Optional[str], url : Optional[str], local_filename : Optional[Union[str, Path]], mimetype : str):
        """
        Build an :py:class:`OcrdFileModel` from individual values.

        Arguments:
            file_grp (str): value for the ``file_grp`` field
            file_id (str): value for the ``file_id`` field
            page_id (str, optional): value for the ``page_id`` field
            url (str, optional): value for the ``url`` field
            local_filename (str or Path, optional): value for the ``local_filename``
                field; a :py:class:`~pathlib.Path` is converted to ``str``
            mimetype (str): value for the ``mimetype`` field

        Returns:
            the populated :py:class:`OcrdFileModel`
        """
        # Only stringify real values: str(None) would yield the literal
        # string 'None' instead of leaving the optional field unset.
        if local_filename is not None:
            local_filename = str(local_filename)
        return OcrdFileModel(file_grp=file_grp, file_id=file_id, page_id=page_id, mimetype=mimetype, url=url, local_filename=local_filename)
39
40
class OcrdAgentModel(BaseModel):
    """
    Wire representation of a single agent entry as exchanged between
    :py:class:`ClientSideOcrdMets` and :py:class:`OcrdMetsServer`.
    """
    name : str = Field()
    type : str = Field()
    role : str = Field()
    otherrole : Optional[str] = Field()
    othertype : str = Field()
    notes : Optional[List[Tuple[Dict[str, str], Optional[str]]]] = Field()

    @staticmethod
    def create(name : str, _type : str, role : str, otherrole : str, othertype : str, notes : List[Tuple[Dict[str, str], Optional[str]]]):
        """
        Build an :py:class:`OcrdAgentModel` from individual values.

        The ``_type`` argument is stored in the model field named ``type``;
        all other arguments map to the field of the same name.
        """
        agent_fields = {
            'name': name,
            'type': _type,
            'role': role,
            'otherrole': otherrole,
            'othertype': othertype,
            'notes': notes,
        }
        return OcrdAgentModel(**agent_fields)
51
52
53
class OcrdFileListModel(BaseModel):
    """
    Model wrapping a list of :py:class:`OcrdFileModel` entries under the key ``files``.
    """
    files : List[OcrdFileModel] = Field()

    @staticmethod
    def create(files : List[OcrdFile]):
        """
        Convert every given file object into an :py:class:`OcrdFileModel`
        (reading its ``fileGrp``, ``ID``, ``mimetype``, ``pageId``, ``url`` and
        ``local_filename`` attributes) and wrap the results in a list model.
        """
        file_models = []
        for ocrd_file in files:
            file_models.append(
                OcrdFileModel.create(
                    file_grp=ocrd_file.fileGrp,
                    file_id=ocrd_file.ID,
                    mimetype=ocrd_file.mimetype,
                    page_id=ocrd_file.pageId,
                    url=ocrd_file.url,
                    local_filename=ocrd_file.local_filename,
                )
            )
        return OcrdFileListModel(files=file_models)
68
69
class OcrdFileGroupListModel(BaseModel):
    """
    Model wrapping a list of file group names under the key ``file_groups``.
    """
    file_groups : List[str] = Field()

    @staticmethod
    def create(file_groups : List[str]):
        """
        Wrap the given list of file group names in an
        :py:class:`OcrdFileGroupListModel`.
        """
        model = OcrdFileGroupListModel(file_groups=file_groups)
        return model
75
76
class OcrdAgentListModel(BaseModel):
    """
    Model wrapping a list of :py:class:`OcrdAgentModel` entries under the key ``agents``.
    """
    agents : List[OcrdAgentModel] = Field()

    @staticmethod
    def create(agents : List[OcrdAgent]):
        """
        Convert every given agent object into an :py:class:`OcrdAgentModel`
        (reading its ``name``, ``type``, ``role``, ``otherrole``, ``othertype``
        and ``notes`` attributes) and wrap the results in a list model.
        """
        agent_models = []
        for agent in agents:
            agent_models.append(
                OcrdAgentModel.create(
                    name=agent.name,
                    _type=agent.type,
                    role=agent.role,
                    otherrole=agent.otherrole,
                    othertype=agent.othertype,
                    notes=agent.notes,
                )
            )
        return OcrdAgentListModel(agents=agent_models)
84
85
#
86
# Client
87
#
88
89
90
class ClientSideOcrdMets():
    """
    Partial substitute for :py:class:`ocrd_models.ocrd_mets.OcrdMets` which provides
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.find_files`,
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.find_all_files`,
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.add_agent`,
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.agents` and
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.add_file` by querying a
    :py:class:`ocrd.mets_server.OcrdMetsServer` via HTTP.
    """

    def __init__(self, url):
        """
        Arguments:
            url (str): either an ``http://`` / ``https://`` URL (TCP) or the
                path of a Unix domain socket (anything else)
        """
        # Same scheme test as OcrdMetsServer.__init__, so that client and
        # server agree on what counts as a TCP URL.
        self.protocol = 'tcp' if url.startswith(('http://', 'https://')) else 'uds'
        self.log = getLogger(f'ocrd.mets_client[{url}]')
        # For UDS, the socket path is embedded into a http+unix:// URL with
        # its slashes percent-encoded.
        self.url = url if self.protocol == 'tcp' else f'http+unix://{url.replace("/", "%2F")}'

    @property
    def session(self) -> Union[requests_session, requests_unixsocket_session]:
        """
        A new session object matching the protocol (created on every access).
        """
        return requests_session() if self.protocol == 'tcp' else requests_unixsocket_session()

    def __getattr__(self, name):
        # Only invoked for attributes not found by normal lookup, i.e.
        # anything this partial substitute does not implement.
        raise NotImplementedError(f"ClientSideOcrdMets has no access to '{name}' - try without METS server")

    def __str__(self):
        return f'<ClientSideOcrdMets[url={self.url}]>'

    @property
    def workspace_path(self):
        """
        Text body of the server's ``GET /workspace_path`` response.
        """
        return self.session.request('GET', f'{self.url}/workspace_path').text

    def reload(self):
        """
        Issue ``POST /reload`` and return the text body of the response.
        """
        return self.session.request('POST', f'{self.url}/reload').text

    @deprecated_alias(ID="file_id")
    @deprecated_alias(pageId="page_id")
    @deprecated_alias(fileGrp="file_grp")
    def find_files(self, **kwargs):
        """
        Query ``GET /file`` with the given keyword arguments as query
        parameters and yield a :py:class:`ClientSideOcrdFile` per result.

        This is a generator: the request is only sent once iteration starts.
        """
        self.log.debug('find_files(%s)', kwargs)
        # Map legacy camelCase keywords to the snake_case names the server expects
        if 'pageId' in kwargs:
            kwargs['page_id'] = kwargs.pop('pageId')
        if 'ID' in kwargs:
            kwargs['file_id'] = kwargs.pop('ID')
        if 'fileGrp' in kwargs:
            kwargs['file_grp'] = kwargs.pop('fileGrp')
        r = self.session.request('GET', f'{self.url}/file', params={**kwargs})
        for f in r.json()['files']:
            yield ClientSideOcrdFile(None, ID=f['file_id'], pageId=f['page_id'], fileGrp=f['file_grp'], url=f['url'], local_filename=f['local_filename'], mimetype=f['mimetype'])

    def find_all_files(self, *args, **kwargs):
        """
        Like :py:meth:`find_files`, but return a list instead of a generator.
        """
        return list(self.find_files(*args, **kwargs))

    def add_agent(self, *args, **kwargs):
        """
        Issue ``POST /agent`` with a JSON body built from the keyword
        arguments (see :py:meth:`OcrdAgentModel.create`) and return the
        response object. Positional arguments are accepted but not used.
        """
        return self.session.request('POST', f'{self.url}/agent', json=OcrdAgentModel.create(**kwargs).dict())

    @property
    def agents(self):
        """
        List of :py:class:`ClientSideOcrdAgent` built from ``GET /agent``.
        """
        agent_dicts = self.session.request('GET', f'{self.url}/agent').json()['agents']
        for agent_dict in agent_dicts:
            # The wire format uses 'type', the constructor keyword is '_type'
            agent_dict['_type'] = agent_dict.pop('type')
        return [ClientSideOcrdAgent(None, **agent_dict) for agent_dict in agent_dicts]

    @property
    def unique_identifier(self):
        """
        Text body of the server's ``GET /unique_identifier`` response.
        """
        return self.session.request('GET', f'{self.url}/unique_identifier').text

    @property
    def file_groups(self):
        """
        The ``file_groups`` list from the server's ``GET /file_groups`` response.
        """
        return self.session.request('GET', f'{self.url}/file_groups').json()['file_groups']

    @deprecated_alias(pageId="page_id")
    @deprecated_alias(ID="file_id")
    def add_file(self, file_grp, content=None, file_id=None, url=None, local_filename=None, mimetype=None, page_id=None, **kwargs):
        """
        Issue ``POST /file`` with the file's metadata as form data and return
        a :py:class:`ClientSideOcrdFile` built from the same arguments.

        ``content`` and any additional keyword arguments are accepted but not
        sent to the server. A non-OK response is logged as an error; the
        return value does not depend on the response.
        """
        data = OcrdFileModel.create(
            file_id=file_id,
            file_grp=file_grp,
            page_id=page_id,
            mimetype=mimetype,
            url=url,
            local_filename=local_filename)
        r = self.session.request('POST', f'{self.url}/file', data=data.dict())
        if not r.ok:
            # Surface server-side rejection instead of ignoring it silently
            self.log.error('add_file(%s) failed with HTTP status %s: %s', file_id, r.status_code, r.text)
        return ClientSideOcrdFile(
                None,
                ID=file_id,
                fileGrp=file_grp,
                url=url,
                pageId=page_id,
                mimetype=mimetype,
                local_filename=local_filename)

    def save(self):
        """
        Issue ``PUT`` on the base URL.
        """
        self.session.request('PUT', self.url)

    def stop(self):
        """
        Issue ``DELETE`` on the base URL, ignoring the connection error that
        results from the server process exiting without sending a response.
        """
        try:
            self.session.request('DELETE', self.url)
        except ConnectionError:
            # Expected because we exit the process without returning
            pass
189
190
#
191
# Server
192
#
193
194
class OcrdMetsServer():
    """
    HTTP server (FastAPI application run by uvicorn) exposing operations on
    the METS of a single workspace, either via TCP or via a Unix domain socket.
    """

    def __init__(self, workspace, url):
        """
        Arguments:
            workspace: the workspace object whose METS is served
            url (str): an ``http://`` / ``https://`` URL to listen on, or
                (anything else) the path of a Unix domain socket to create
        """
        self.workspace = workspace
        self.url = url
        self.is_uds = not url.startswith(('http://', 'https://'))
        self.log = getLogger(f'ocrd.mets_server[{self.url}]')

    def shutdown(self):
        """
        Remove the socket file (UDS only, if it still exists) and terminate
        the process immediately with exit status 0. Does not return.
        """
        if self.is_uds:
            if Path(self.url).exists():
                self.log.warning(f'UDS socket {self.url} still exists, removing it')
                Path(self.url).unlink()
        # os._exit because uvicorn catches SystemExit raised by sys.exit
        _exit(0)

    def startup(self):
        """
        Define the FastAPI application with all its routes and exception
        handlers, prepare the listening address (creating the socket file for
        UDS) and run uvicorn with it.
        """
        self.log.info("Starting up METS server")

        workspace = self.workspace

        app = FastAPI(
            title="OCR-D METS Server",
            description="Providing simultaneous write-access to mets.xml for OCR-D",
        )

        # Map these exceptions to HTTP 400 responses

        @app.exception_handler(ValidationError)
        async def exception_handler_validation_error(request: Request, exc: ValidationError):
            return JSONResponse(status_code=400, content=exc.errors())

        @app.exception_handler(FileExistsError)
        async def exception_handler_file_exists(request: Request, exc: FileExistsError):
            return JSONResponse(status_code=400, content=str(exc))

        @app.exception_handler(re.error)
        async def exception_handler_invalid_regex(request: Request, exc: re.error):
            return JSONResponse(status_code=400, content=f'invalid regex: {exc}')

        @app.get("/file", response_model=OcrdFileListModel)
        async def find_files(
            file_grp : Optional[str] = None,
            file_id : Optional[str] = None,
            page_id : Optional[str] = None,
            mimetype : Optional[str] = None,
            local_filename : Optional[str] = None,
            url : Optional[str] = None,
        ):
            """
            Find files in the mets
            """
            found = workspace.mets.find_all_files(fileGrp=file_grp, ID=file_id, pageId=page_id, mimetype=mimetype, local_filename=local_filename, url=url)
            return OcrdFileListModel.create(found)

        @app.put('/')
        def save():
            return workspace.save_mets()

        @app.post('/file', response_model=OcrdFileModel)
        async def add_file(
            file_grp : str = Form(),
            file_id : str = Form(),
            page_id : Optional[str] = Form(),
            mimetype : str = Form(),
            url : Optional[str] = Form(None),
            local_filename : Optional[str] = Form(None),
        ):
            """
            Add a file
            """
            # Validate
            file_resource = OcrdFileModel.create(file_grp=file_grp, file_id=file_id, page_id=page_id, mimetype=mimetype, url=url, local_filename=local_filename)
            # Add to workspace
            kwargs = file_resource.dict()
            workspace.add_file(**kwargs)
            return file_resource

        @app.get('/file_groups', response_model=OcrdFileGroupListModel)
        async def file_groups():
            return {'file_groups': workspace.mets.file_groups}

        @app.post('/agent', response_model=OcrdAgentModel)
        async def add_agent(agent : OcrdAgentModel):
            kwargs = agent.dict()
            # The model field is called 'type', add_agent is called with '_type'
            kwargs['_type'] = kwargs.pop('type')
            workspace.mets.add_agent(**kwargs)
            return agent

        @app.get('/agent', response_model=OcrdAgentListModel)
        async def agents():
            return OcrdAgentListModel.create(workspace.mets.agents)

        @app.get('/unique_identifier', response_model=str)
        async def unique_identifier():
            return Response(content=workspace.mets.unique_identifier, media_type='text/plain')

        @app.get('/workspace_path', response_model=str)
        async def workspace_path():
            return Response(content=workspace.directory, media_type="text/plain")

        @app.post('/reload')
        async def workspace_reload_mets():
            workspace.reload_mets()
            return Response(content=f'Reloaded from {workspace.directory}', media_type="text/plain")

        @app.delete('/')
        async def stop():
            """
            Stop the server
            """
            getLogger('ocrd.models.ocrd_mets').info(f'Shutting down METS Server {self.url}')
            workspace.save_mets()
            # Terminates the process, so no response is sent for this request
            self.shutdown()

        # ------------- #

        if self.is_uds:
            # Create socket and change to world-readable and -writable to avoid
            # permission errors
            self.log.debug(f"chmod 0o666 {self.url}")
            # The context manager closes the helper socket even if bind() fails
            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as server:
                server.bind(self.url)  # creates the socket file
                atexit.register(self.shutdown)
            chmod(self.url, 0o666)
            uvicorn_kwargs = {'uds': self.url}
        else:
            parsed = urlparse(self.url)
            uvicorn_kwargs = {'host': parsed.hostname, 'port': parsed.port}

        self.log.debug("Starting uvicorn")
        uvicorn.run(app, **uvicorn_kwargs)
325