Passed
Pull Request — master (#1220)
by
unknown
03:38
created

ocrd.mets_server.MpxReq.stop()   A

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 4
dl 0
loc 4
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
"""
2
# METS server functionality
3
"""
4
import re
5
from os import _exit, chmod
6
from typing import Dict, Optional, Union, List, Tuple
7
from pathlib import Path
8
from urllib.parse import urlparse
9
import socket
10
import atexit
11
12
from fastapi import FastAPI, Request, Form, Response
13
from fastapi.responses import JSONResponse
14
from requests import Session as requests_session
15
from requests.exceptions import ConnectionError
16
from requests_unixsocket import Session as requests_unixsocket_session
17
from pydantic import BaseModel, Field, ValidationError
18
19
import uvicorn
20
21
from ocrd_models import OcrdFile, ClientSideOcrdFile, OcrdAgent, ClientSideOcrdAgent
22
from ocrd_utils import getLogger, deprecated_alias
23
24
25
#
26
# Models
27
#
28
29
30
class OcrdFileModel(BaseModel):
    # Wire representation of a single file entry as exchanged between the
    # METS server and its clients (documented with comments so the generated
    # model schema is left as it was).
    file_grp: str = Field()
    file_id: str = Field()
    mimetype: str = Field()
    page_id: Optional[str] = Field()
    url: Optional[str] = Field()
    local_filename: Optional[str] = Field()

    @staticmethod
    def create(
        file_grp: str, file_id: str, page_id: Optional[str], url: Optional[str],
        local_filename: Optional[Union[str, Path]], mimetype: str
    ):
        """
        Build an :py:class:`OcrdFileModel` from individual file attributes.

        ``local_filename`` may be given as ``str`` or :py:class:`pathlib.Path`
        and is stored as a string. A missing local filename (``None``) is kept
        as ``None`` rather than being converted to the literal string ``"None"``.
        """
        return OcrdFileModel(
            file_grp=file_grp, file_id=file_id, page_id=page_id, mimetype=mimetype, url=url,
            # str(None) would yield the truthy string "None", which is not a filename
            local_filename=None if local_filename is None else str(local_filename)
        )
47
48
49
class OcrdAgentModel(BaseModel):
    # Wire representation of a single agent entry as exchanged between the
    # METS server and its clients.
    name: str = Field()
    type: str = Field()
    role: str = Field()
    otherrole: Optional[str] = Field()
    othertype: str = Field()
    # Each note is a pair of (attribute dict, optional text)
    notes: Optional[List[Tuple[Dict[str, str], Optional[str]]]] = Field()

    @staticmethod
    def create(
        name: str, _type: str, role: str, otherrole: str, othertype: str,
        notes: List[Tuple[Dict[str, str], Optional[str]]]
    ):
        """
        Build an :py:class:`OcrdAgentModel` from individual agent attributes.

        The ``_type`` argument is stored in the model field ``type``.
        """
        attributes = {
            "name": name,
            "type": _type,
            "role": role,
            "otherrole": otherrole,
            "othertype": othertype,
            "notes": notes,
        }
        return OcrdAgentModel(**attributes)
63
64
65
class OcrdFileListModel(BaseModel):
    # Wire representation of a list of file entries.
    files: List[OcrdFileModel] = Field()

    @staticmethod
    def create(files: List[OcrdFile]):
        """
        Convert the given file objects into an :py:class:`OcrdFileListModel`.

        Reads ``fileGrp``, ``ID``, ``mimetype``, ``pageId``, ``url`` and
        ``local_filename`` from every element of ``files``.
        """
        entries = []
        for ocrd_file in files:
            entries.append(
                OcrdFileModel.create(
                    file_grp=ocrd_file.fileGrp,
                    file_id=ocrd_file.ID,
                    mimetype=ocrd_file.mimetype,
                    page_id=ocrd_file.pageId,
                    url=ocrd_file.url,
                    local_filename=ocrd_file.local_filename,
                )
            )
        return OcrdFileListModel(files=entries)
79
80
81
class OcrdFileGroupListModel(BaseModel):
    # Wire representation of a list of file group names.
    file_groups: List[str] = Field()

    @staticmethod
    def create(file_groups: List[str]):
        """
        Wrap the given file group names in an :py:class:`OcrdFileGroupListModel`.
        """
        model = OcrdFileGroupListModel(file_groups=file_groups)
        return model
87
88
89
class OcrdAgentListModel(BaseModel):
    # Wire representation of a list of agent entries.
    agents: List[OcrdAgentModel] = Field()

    @staticmethod
    def create(agents: List[OcrdAgent]):
        """
        Convert the given agent objects into an :py:class:`OcrdAgentListModel`.

        Reads ``name``, ``type``, ``role``, ``otherrole``, ``othertype`` and
        ``notes`` from every element of ``agents``.
        """
        entries = []
        for agent in agents:
            entries.append(
                OcrdAgentModel.create(
                    name=agent.name,
                    _type=agent.type,
                    role=agent.role,
                    otherrole=agent.otherrole,
                    othertype=agent.othertype,
                    notes=agent.notes,
                )
            )
        return OcrdAgentListModel(agents=entries)
101
102
103
#
104
# Client
105
#
106
107
108
class ClientSideOcrdMets:
    """
    Partial substitute for :py:class:`ocrd_models.ocrd_mets.OcrdMets` which provides for
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.find_files`,
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.find_all_files`,
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.add_agent`,
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.agents`, and
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.add_file` to query via HTTP a
    :py:class:`ocrd.mets_server.OcrdMetsServer`.

    The client works in one of two modes:

    * direct mode: every operation is sent to its own endpoint below ``url``
      (over TCP or a Unix domain socket);
    * multiplexing mode (TCP URL containing ``tcp_mets``): every operation is
      sent as a ``POST`` to ``url`` itself, with a JSON body built by
      :py:class:`MpxReq` that names the workspace and the request to forward.
    """

    def __init__(self, url, workspace_path: Optional[str] = None):
        """
        Args:
            url: ``http://`` / ``https://`` URL of the server, or the file system
                path of a Unix domain socket.
            workspace_path: workspace directory; mandatory in multiplexing mode.

        Raises:
            ValueError: if multiplexing mode is detected but no workspace path was given.
        """
        # Same scheme test as OcrdMetsServer.is_uds: anything that is not an HTTP(S) URL is a socket path
        self.protocol = "tcp" if url.startswith(("http://", "https://")) else "uds"
        self.log = getLogger(f"ocrd.mets_client[{url}]")
        # For UDS the socket path is percent-encoded into a http+unix:// URL
        self.url = url if self.protocol == "tcp" else f'http+unix://{url.replace("/", "%2F")}'
        self.ws_dir_path = workspace_path if workspace_path else None

        if self.protocol == "tcp" and "tcp_mets" in self.url:
            self.multiplexing_mode = True
            if not self.ws_dir_path:
                # Must be set since this path is the way to multiplex among multiple workspaces on the PS side
                raise ValueError("ClientSideOcrdMets runs in multiplexing mode but the workspace dir path is not set!")
        else:
            self.multiplexing_mode = False

    @property
    def session(self) -> Union[requests_session, requests_unixsocket_session]:
        """
        A new session object matching the protocol (created on every access).
        """
        return requests_session() if self.protocol == "tcp" else requests_unixsocket_session()

    def __getattr__(self, name):
        # Only reached for attributes that are not found normally, i.e. for
        # everything this partial substitute does not implement
        raise NotImplementedError(f"ClientSideOcrdMets has no access to '{name}' - try without METS server")

    def __str__(self):
        return f"<ClientSideOcrdMets[url={self.url}]>"

    def save(self):
        """
        Request writing the changes to the file system
        """
        if not self.multiplexing_mode:
            self.session.request("PUT", url=self.url)
        else:
            self.session.request(
                "POST",
                self.url,
                json=MpxReq.save(self.ws_dir_path)
            )

    def stop(self):
        """
        Request stopping the mets server
        """
        try:
            if not self.multiplexing_mode:
                self.session.request("DELETE", self.url)
                return
            else:
                self.session.request(
                    "POST",
                    self.url,
                    json=MpxReq.stop(self.ws_dir_path)
                )
        except ConnectionError:
            # Expected because we exit the process without returning
            pass

    def reload(self):
        """
        Request reloading of the mets file from the file system

        Returns the text of the server's answer.
        """
        if not self.multiplexing_mode:
            return self.session.request("POST", f"{self.url}/reload").text
        else:
            return self.session.request(
                "POST",
                self.url,
                json=MpxReq.reload(self.ws_dir_path)
            ).json()["text"]

    @property
    def unique_identifier(self):
        """
        The unique identifier as reported by the server (text).
        """
        if not self.multiplexing_mode:
            return self.session.request("GET", f"{self.url}/unique_identifier").text
        else:
            return self.session.request(
                "POST",
                self.url,
                json=MpxReq.unique_identifier(self.ws_dir_path)
            ).json()["text"]

    @property
    def workspace_path(self):
        """
        The workspace path as reported by the server (text).

        Every access queries the server and overwrites ``self.ws_dir_path``
        with the answer.
        """
        if not self.multiplexing_mode:
            self.ws_dir_path = self.session.request("GET", f"{self.url}/workspace_path").text
            return self.ws_dir_path
        else:
            self.ws_dir_path = self.session.request(
                "POST",
                self.url,
                json=MpxReq.workspace_path(self.ws_dir_path)
            ).json()["text"]
            return self.ws_dir_path

    @property
    def file_groups(self):
        """
        The ``file_groups`` entry of the server's JSON answer.
        """
        if not self.multiplexing_mode:
            return self.session.request("GET", f"{self.url}/file_groups").json()["file_groups"]
        else:
            return self.session.request(
                "POST",
                self.url,
                json=MpxReq.file_groups(self.ws_dir_path)
            ).json()["file_groups"]

    @property
    def agents(self):
        """
        The agents known to the server, as a list of :py:class:`ClientSideOcrdAgent`.
        """
        if not self.multiplexing_mode:
            agent_dicts = self.session.request("GET", f"{self.url}/agent").json()["agents"]
        else:
            agent_dicts = self.session.request(
                "POST",
                self.url,
                json=MpxReq.agents(self.ws_dir_path)
            ).json()["agents"]

        # The wire format uses the key "type", the agent constructor expects "_type"
        for agent_dict in agent_dicts:
            agent_dict["_type"] = agent_dict.pop("type")
        return [ClientSideOcrdAgent(None, **agent_dict) for agent_dict in agent_dicts]

    def add_agent(self, *args, **kwargs):
        """
        Request adding an agent described by the keyword arguments accepted by
        :py:meth:`OcrdAgentModel.create`. Positional arguments are not used.

        Returns the HTTP response object in direct mode, and an
        :py:class:`OcrdAgentModel` built from ``kwargs`` in multiplexing mode.
        """
        if not self.multiplexing_mode:
            return self.session.request("POST", f"{self.url}/agent", json=OcrdAgentModel.create(**kwargs).dict())
        else:
            self.session.request(
                "POST",
                self.url,
                json=MpxReq.add_agent(self.ws_dir_path, OcrdAgentModel.create(**kwargs).dict())
            ).json()
            return OcrdAgentModel.create(**kwargs)

    @deprecated_alias(ID="file_id")
    @deprecated_alias(pageId="page_id")
    @deprecated_alias(fileGrp="file_grp")
    def find_files(self, **kwargs):
        """
        Query the server for files matching the given keyword filters and yield
        one :py:class:`ClientSideOcrdFile` per entry of the answer.

        This is a generator: the request is only sent once iteration starts.
        The legacy keywords ``pageId``, ``ID`` and ``fileGrp`` are mapped to
        ``page_id``, ``file_id`` and ``file_grp``.
        """
        self.log.debug("find_files(%s)", kwargs)
        if "pageId" in kwargs:
            kwargs["page_id"] = kwargs.pop("pageId")
        if "ID" in kwargs:
            kwargs["file_id"] = kwargs.pop("ID")
        if "fileGrp" in kwargs:
            kwargs["file_grp"] = kwargs.pop("fileGrp")

        if not self.multiplexing_mode:
            r = self.session.request(method="GET", url=f"{self.url}/file", params={**kwargs})
        else:
            r = self.session.request(
                "POST",
                self.url,
                json=MpxReq.find_files(self.ws_dir_path, {**kwargs})
            )

        for f in r.json()["files"]:
            yield ClientSideOcrdFile(
                None, ID=f["file_id"], pageId=f["page_id"], fileGrp=f["file_grp"], url=f["url"],
                local_filename=f["local_filename"], mimetype=f["mimetype"]
            )

    def find_all_files(self, *args, **kwargs):
        """
        Like :py:meth:`find_files`, but returns the results as a list.
        """
        return list(self.find_files(*args, **kwargs))

    @deprecated_alias(pageId="page_id")
    @deprecated_alias(ID="file_id")
    def add_file(
        self, file_grp, content=None, file_id=None, url=None, local_filename=None, mimetype=None, page_id=None, **kwargs
    ):
        """
        Request adding a file entry on the server.

        ``content`` and any additional keyword arguments are accepted but not
        sent to the server.

        Returns:
            A :py:class:`ClientSideOcrdFile` built from the given arguments.

        Raises:
            RuntimeError: if the server rejects the request (direct mode) or the
                forwarded answer carries an ``error`` entry (multiplexing mode).
        """
        data = OcrdFileModel.create(
            file_id=file_id, file_grp=file_grp, page_id=page_id, mimetype=mimetype, url=url,
            local_filename=local_filename
        )

        if not self.multiplexing_mode:
            r = self.session.request("POST", f"{self.url}/file", data=data.dict())
            if not r:
                raise RuntimeError("Add file failed. Please check provided parameters")
        else:
            r = self.session.request("POST", self.url, json=MpxReq.add_file(self.ws_dir_path, data.dict()))
            # The error marker lives in the decoded JSON body, not in the response object itself
            reply = r.json()
            if isinstance(reply, dict) and "error" in reply:
                raise RuntimeError(f"Add file failed: Msg: {reply['error']}")

        return ClientSideOcrdFile(
            None, ID=file_id, fileGrp=file_grp, url=url, pageId=page_id, mimetype=mimetype,
            local_filename=local_filename
        )
301
302
303
class MpxReq:
    """Builders for the request bodies used by the TCP forwarding (multiplexing) mode.

    Every METS server call such as ``find_files`` or ``workspace_path`` needs a
    dedicated request body to go through `MetsServerProxy.forward_tcp_request`.
    Each static method below creates the body for one such call.

    The builders live in a class of their own to make them easier to test.
    """

    @staticmethod
    def __args_wrapper(
        workspace_path: str, method_type: str, response_type: str, request_url: str, request_data: dict
    ) -> Dict:
        # Common envelope shared by all forwarded requests
        return dict(
            workspace_path=workspace_path,
            method_type=method_type,
            response_type=response_type,
            request_url=request_url,
            request_data=request_data,
        )

    @staticmethod
    def save(ws_dir_path: str) -> Dict:
        # PUT on the server root, empty response
        return MpxReq.__args_wrapper(ws_dir_path, "PUT", "empty", "", {})

    @staticmethod
    def stop(ws_dir_path: str) -> Dict:
        # DELETE on the server root, empty response
        return MpxReq.__args_wrapper(ws_dir_path, "DELETE", "empty", "", {})

    @staticmethod
    def reload(ws_dir_path: str) -> Dict:
        return MpxReq.__args_wrapper(ws_dir_path, "POST", "text", "reload", {})

    @staticmethod
    def unique_identifier(ws_dir_path: str) -> Dict:
        return MpxReq.__args_wrapper(ws_dir_path, "GET", "text", "unique_identifier", {})

    @staticmethod
    def workspace_path(ws_dir_path: str) -> Dict:
        return MpxReq.__args_wrapper(ws_dir_path, "GET", "text", "workspace_path", {})

    @staticmethod
    def file_groups(ws_dir_path: str) -> Dict:
        return MpxReq.__args_wrapper(ws_dir_path, "GET", "dict", "file_groups", {})

    @staticmethod
    def agents(ws_dir_path: str) -> Dict:
        return MpxReq.__args_wrapper(ws_dir_path, "GET", "class", "agent", {})

    @staticmethod
    def add_agent(ws_dir_path: str, agent_model: Dict) -> Dict:
        # The agent description travels under the "class" key
        payload = {"class": agent_model}
        return MpxReq.__args_wrapper(ws_dir_path, "POST", "class", "agent", payload)

    @staticmethod
    def find_files(ws_dir_path: str, params: Dict) -> Dict:
        # The search filters travel under the "params" key
        payload = {"params": params}
        return MpxReq.__args_wrapper(ws_dir_path, "GET", "class", "file", payload)

    @staticmethod
    def add_file(ws_dir_path: str, data: Dict) -> Dict:
        # The file description travels under the "form" key
        payload = {"form": data}
        return MpxReq.__args_wrapper(ws_dir_path, "POST", "class", "file", payload)
376
377
#
378
# Server
379
#
380
381
382
class OcrdMetsServer:
    """
    HTTP server giving access to the METS of a single workspace, listening
    either on a TCP address or on a Unix domain socket (UDS).
    """

    def __init__(self, workspace, url):
        """
        Args:
            workspace: the workspace whose METS is served.
            url: ``http://`` / ``https://`` URL to listen on; any other value
                is treated as the file system path of a Unix domain socket.
        """
        self.workspace = workspace
        self.url = url
        self.is_uds = not (url.startswith('http://') or url.startswith('https://'))
        self.log = getLogger(f'ocrd.models.ocrd_mets.server.{self.url}')

    def shutdown(self):
        """
        Remove the socket file (UDS only) and terminate the process.
        """
        if self.is_uds:
            if Path(self.url).exists():
                self.log.debug(f'UDS socket {self.url} still exists, removing it')
                Path(self.url).unlink()
        # os._exit because uvicorn catches SystemExit raised by sys.exit
        _exit(0)

    def startup(self):
        """
        Define the HTTP endpoints and run the server with uvicorn.
        """
        self.log.info("Starting up METS server")

        workspace = self.workspace

        app = FastAPI(
            title="OCR-D METS Server",
            description="Providing simultaneous write-access to mets.xml for OCR-D",
        )

        # The following handlers turn known error types into HTTP 400 answers

        @app.exception_handler(ValidationError)
        async def exception_handler_validation_error(request: Request, exc: ValidationError):
            return JSONResponse(status_code=400, content=exc.errors())

        @app.exception_handler(FileExistsError)
        async def exception_handler_file_exists(request: Request, exc: FileExistsError):
            return JSONResponse(status_code=400, content=str(exc))

        @app.exception_handler(re.error)
        async def exception_handler_invalid_regex(request: Request, exc: re.error):
            return JSONResponse(status_code=400, content=f'invalid regex: {exc}')

        @app.put(path='/')
        def save():
            """
            Write current changes to the file system
            """
            return workspace.save_mets()

        @app.delete(path='/')
        async def stop():
            """
            Stop the mets server
            """
            getLogger('ocrd.models.ocrd_mets').info(f'Shutting down METS Server {self.url}')
            workspace.save_mets()
            # Terminates the process, so no response is sent back
            self.shutdown()

        @app.post(path='/reload')
        async def workspace_reload_mets():
            """
            Reload mets file from the file system
            """
            workspace.reload_mets()
            return Response(content=f'Reloaded from {workspace.directory}', media_type="text/plain")

        # Plain-text answer with the unique identifier of the METS
        @app.get(path='/unique_identifier', response_model=str)
        async def unique_identifier():
            return Response(content=workspace.mets.unique_identifier, media_type='text/plain')

        # Plain-text answer with the workspace directory
        @app.get(path='/workspace_path', response_model=str)
        async def workspace_path():
            return Response(content=workspace.directory, media_type="text/plain")

        # JSON answer listing the file groups of the METS
        @app.get(path='/file_groups', response_model=OcrdFileGroupListModel)
        async def file_groups():
            return {'file_groups': workspace.mets.file_groups}

        # JSON answer listing the agents of the METS
        @app.get(path='/agent', response_model=OcrdAgentListModel)
        async def agents():
            return OcrdAgentListModel.create(workspace.mets.agents)

        # Adds the posted agent to the METS and echoes it back
        @app.post(path='/agent', response_model=OcrdAgentModel)
        async def add_agent(agent: OcrdAgentModel):
            kwargs = agent.dict()
            # The model field is called "type", add_agent() expects "_type"
            kwargs['_type'] = kwargs.pop('type')
            workspace.mets.add_agent(**kwargs)
            return agent

        @app.get(path="/file", response_model=OcrdFileListModel)
        async def find_files(
            file_grp: Optional[str] = None,
            file_id: Optional[str] = None,
            page_id: Optional[str] = None,
            mimetype: Optional[str] = None,
            local_filename: Optional[str] = None,
            url: Optional[str] = None
        ):
            """
            Find files in the mets
            """
            found = workspace.mets.find_all_files(
                fileGrp=file_grp, ID=file_id, pageId=page_id, mimetype=mimetype, local_filename=local_filename, url=url
            )
            return OcrdFileListModel.create(found)

        @app.post(path='/file', response_model=OcrdFileModel)
        async def add_file(
            file_grp: str = Form(),
            file_id: str = Form(),
            # Optional like url and local_filename: clients may omit the page
            page_id: Optional[str] = Form(None),
            mimetype: str = Form(),
            url: Optional[str] = Form(None),
            local_filename: Optional[str] = Form(None)
        ):
            """
            Add a file
            """
            # Validate
            file_resource = OcrdFileModel.create(
                file_grp=file_grp, file_id=file_id, page_id=page_id, mimetype=mimetype, url=url,
                local_filename=local_filename
            )
            # Add to workspace
            kwargs = file_resource.dict()
            workspace.add_file(**kwargs)
            return file_resource

        # ------------- #

        if self.is_uds:
            # Create socket and change to world-readable and -writable to avoid permission errors
            socket_mode = 0o666
            self.log.debug(f"chmod {socket_mode:#o} {self.url}")
            # The socket is only bound to create the file; it is closed again
            # (also when bind fails) before uvicorn opens the path itself
            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as server:
                server.bind(self.url)  # creates the socket file
                atexit.register(self.shutdown)
            chmod(self.url, socket_mode)
            uvicorn_kwargs = {'uds': self.url}
        else:
            parsed = urlparse(self.url)
            uvicorn_kwargs = {'host': parsed.hostname, 'port': parsed.port}
        uvicorn_kwargs['log_config'] = None
        uvicorn_kwargs['access_log'] = False

        self.log.debug("Starting uvicorn")
        uvicorn.run(app, **uvicorn_kwargs)
524