Passed
Pull Request — master (#1220)
by
unknown
03:00
created

ocrd.mets_server.ClientSideOcrdMets.__str__()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 2
dl 0
loc 2
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
"""
2
# METS server functionality
3
"""
4
import re
5
from os import _exit, chmod
6
from typing import Dict, Optional, Union, List, Tuple
7
from pathlib import Path
8
from urllib.parse import urlparse
9
import socket
10
import atexit
11
12
from fastapi import FastAPI, Request, Form, Response
13
from fastapi.responses import JSONResponse
14
from requests import Session as requests_session
15
from requests.exceptions import ConnectionError
16
from requests_unixsocket import Session as requests_unixsocket_session
17
from pydantic import BaseModel, Field, ValidationError
18
19
import uvicorn
20
21
from ocrd_models import OcrdFile, ClientSideOcrdFile, OcrdAgent, ClientSideOcrdAgent
22
from ocrd_utils import getLogger, deprecated_alias
23
24
25
#
26
# Models
27
#
28
29
30
class OcrdFileModel(BaseModel):
    """Wire representation of a single ``mets:file`` as exchanged with the METS server."""
    file_grp: str = Field()
    file_id: str = Field()
    mimetype: str = Field()
    page_id: Optional[str] = Field()
    url: Optional[str] = Field()
    local_filename: Optional[str] = Field()

    @staticmethod
    def create(
        file_grp: str, file_id: str, page_id: Optional[str], url: Optional[str],
        local_filename: Optional[Union[str, Path]], mimetype: str
    ):
        """
        Build an :py:class:`OcrdFileModel`, converting a ``Path`` ``local_filename`` to ``str``.

        A missing ``local_filename`` (``None``) stays ``None``; stringifying it
        unconditionally would turn it into the literal text ``"None"``.
        """
        return OcrdFileModel(
            file_grp=file_grp, file_id=file_id, page_id=page_id, mimetype=mimetype, url=url,
            local_filename=str(local_filename) if local_filename is not None else None
        )
47
48
49
class OcrdAgentModel(BaseModel):
    """Wire representation of a ``mets:agent`` as exchanged with the METS server."""
    name: str = Field()
    type: str = Field()
    role: str = Field()
    otherrole: Optional[str] = Field()
    # Optional like ``otherrole``: an agent without an OTHERTYPE attribute presumably
    # reports ``None`` here (NOTE(review): confirm against OcrdAgent.othertype); a
    # required ``str`` would make ``GET /agent`` fail validation for such METS files.
    othertype: Optional[str] = Field()
    notes: Optional[List[Tuple[Dict[str, str], Optional[str]]]] = Field()

    @staticmethod
    def create(
        name: str, _type: str, role: str, otherrole: Optional[str], othertype: Optional[str],
        notes: Optional[List[Tuple[Dict[str, str], Optional[str]]]]
    ):
        """Build an :py:class:`OcrdAgentModel`; the ``_type`` argument fills the ``type`` field."""
        return OcrdAgentModel(name=name, type=_type, role=role, otherrole=otherrole, othertype=othertype, notes=notes)
63
64
65
class OcrdFileListModel(BaseModel):
    """Response schema of ``GET /file``: a list of :py:class:`OcrdFileModel` entries."""
    files: List[OcrdFileModel] = Field()

    @staticmethod
    def create(files: List[OcrdFile]):
        """Convert :py:class:`OcrdFile` objects into an :py:class:`OcrdFileListModel`."""
        file_models = [
            OcrdFileModel.create(
                file_grp=ocrd_file.fileGrp,
                file_id=ocrd_file.ID,
                mimetype=ocrd_file.mimetype,
                page_id=ocrd_file.pageId,
                url=ocrd_file.url,
                local_filename=ocrd_file.local_filename,
            )
            for ocrd_file in files
        ]
        return OcrdFileListModel(files=file_models)
79
80
81
class OcrdFileGroupListModel(BaseModel):
    """Response schema of ``GET /file_groups``: the names of the file groups."""
    file_groups: List[str] = Field()

    @staticmethod
    def create(file_groups: List[str]):
        """Wrap a list of file group names in an :py:class:`OcrdFileGroupListModel`."""
        model = OcrdFileGroupListModel(file_groups=file_groups)
        return model
87
88
89
class OcrdAgentListModel(BaseModel):
    """Response schema of ``GET /agent``: a list of :py:class:`OcrdAgentModel` entries."""
    agents: List[OcrdAgentModel] = Field()

    @staticmethod
    def create(agents: List[OcrdAgent]):
        """Convert :py:class:`OcrdAgent` objects into an :py:class:`OcrdAgentListModel`."""
        agent_models = [
            OcrdAgentModel.create(
                name=agent.name,
                _type=agent.type,
                role=agent.role,
                otherrole=agent.otherrole,
                othertype=agent.othertype,
                notes=agent.notes,
            )
            for agent in agents
        ]
        return OcrdAgentListModel(agents=agent_models)
101
102
103
#
104
# Client
105
#
106
107
108
class ClientSideOcrdMets:
    """
    Partial substitute for :py:class:`ocrd_models.ocrd_mets.OcrdMets` which provides for
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.find_files`,
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.find_all_files`,
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.add_agent`,
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.agents`, and
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.add_file` to query via HTTP a
    :py:class:`ocrd.mets_server.OcrdMetsServer`.

    ``url`` selects the transport: an ``http://`` or ``https://`` URL is used as-is
    (TCP); any other value is taken as the path of a Unix domain socket. A TCP URL
    containing ``tcp_mets`` switches to multiplexing mode. In that mode every call is
    wrapped by :py:class:`MpxReq` and POSTed to that single URL together with the
    workspace path, so ``workspace_path`` must be given.

    Accessing any other (non-dunder) attribute raises :py:class:`NotImplementedError`.
    """

    def __init__(self, url, workspace_path: Optional[str] = None):
        # Same TCP/UDS distinction as OcrdMetsServer.is_uds: only http(s) URLs are TCP.
        self.protocol = "tcp" if url.startswith(("http://", "https://")) else "uds"
        self.log = getLogger(f"ocrd.mets_client[{url}]")
        # requests_unixsocket addresses the socket by its percent-encoded path in the host part.
        self.url = url if self.protocol == "tcp" else f'http+unix://{url.replace("/", "%2F")}'
        self.ws_dir_path = workspace_path if workspace_path else None

        if self.protocol == "tcp" and "tcp_mets" in self.url:
            self.multiplexing_mode = True
            if not self.ws_dir_path:
                # Must be set since this path is the way to multiplex among multiple workspaces on the PS side
                raise ValueError("ClientSideOcrdMets runs in multiplexing mode but the workspace dir path is not set!")
        else:
            self.multiplexing_mode = False

    @property
    def session(self) -> Union[requests_session, requests_unixsocket_session]:
        """A new, unshared HTTP session matching :py:attr:`protocol` on every access."""
        return requests_session() if self.protocol == "tcp" else requests_unixsocket_session()

    def _request(self, method: str, url: str, **kwargs):
        """
        Send a single request on a fresh :py:attr:`session` and close that session again.

        Closing releases the session's connection pool right away. Without
        ``stream=True``, requests reads the whole body before returning, so
        ``.text``/``.json()`` still work on the returned response.
        """
        with self.session as session:
            return session.request(method, url, **kwargs)

    def __getattr__(self, name):
        # Protocol probes such as copy's __deepcopy__ or pickle's __setstate__ must see
        # AttributeError, otherwise hasattr()/getattr(obj, name, default) blow up.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
        raise NotImplementedError(f"ClientSideOcrdMets has no access to '{name}' - try without METS server")

    def __str__(self):
        return f"<ClientSideOcrdMets[url={self.url}]>"

    def save(self):
        """
        Request writing the changes to the file system
        """
        if not self.multiplexing_mode:
            self._request("PUT", self.url)
        else:
            self._request("POST", self.url, json=MpxReq.save(self.ws_dir_path))

    def stop(self):
        """
        Request stopping the mets server
        """
        try:
            if not self.multiplexing_mode:
                self._request("DELETE", self.url)
            else:
                self._request("POST", self.url, json=MpxReq.stop(self.ws_dir_path))
        except ConnectionError:
            # Expected because we exit the process without returning
            pass

    def reload(self):
        """
        Request reloading of the mets file from the file system.

        Returns the server's textual confirmation.
        """
        if not self.multiplexing_mode:
            return self._request("POST", f"{self.url}/reload").text
        return self._request("POST", self.url, json=MpxReq.reload(self.ws_dir_path)).json()["text"]

    @property
    def unique_identifier(self):
        """The unique identifier of the METS, as reported by the server."""
        if not self.multiplexing_mode:
            return self._request("GET", f"{self.url}/unique_identifier").text
        return self._request(
            "POST", self.url, json=MpxReq.unique_identifier(self.ws_dir_path)
        ).json()["text"]

    @property
    def workspace_path(self):
        """The server's workspace directory; the value is also cached in ``ws_dir_path``."""
        if not self.multiplexing_mode:
            self.ws_dir_path = self._request("GET", f"{self.url}/workspace_path").text
        else:
            self.ws_dir_path = self._request(
                "POST", self.url, json=MpxReq.workspace_path(self.ws_dir_path)
            ).json()["text"]
        return self.ws_dir_path

    @property
    def file_groups(self):
        """Names of the file groups in the METS."""
        if not self.multiplexing_mode:
            return self._request("GET", f"{self.url}/file_groups").json()["file_groups"]
        return self._request(
            "POST", self.url, json=MpxReq.file_groups(self.ws_dir_path)
        ).json()["file_groups"]

    @property
    def agents(self):
        """All agents of the METS as :py:class:`ClientSideOcrdAgent` objects."""
        if not self.multiplexing_mode:
            agent_dicts = self._request("GET", f"{self.url}/agent").json()["agents"]
        else:
            agent_dicts = self._request(
                "POST", self.url, json=MpxReq.agents(self.ws_dir_path)
            ).json()["agents"]

        for agent_dict in agent_dicts:
            # The wire model names the field "type"; ClientSideOcrdAgent takes "_type".
            agent_dict["_type"] = agent_dict.pop("type")
        return [ClientSideOcrdAgent(None, **agent_dict) for agent_dict in agent_dicts]

    def add_agent(self, *args, **kwargs):
        """
        Add an agent, described by keyword arguments of :py:meth:`OcrdAgentModel.create`.

        Positional arguments are ignored. Returns the raw HTTP response in direct
        mode and the submitted :py:class:`OcrdAgentModel` in multiplexing mode.
        """
        agent = OcrdAgentModel.create(**kwargs)
        if not self.multiplexing_mode:
            return self._request("POST", f"{self.url}/agent", json=agent.dict())
        self._request(
            "POST", self.url, json=MpxReq.add_agent(self.ws_dir_path, agent.dict())
        ).json()
        return agent

    @deprecated_alias(ID="file_id")
    @deprecated_alias(pageId="page_id")
    @deprecated_alias(fileGrp="file_grp")
    def find_files(self, **kwargs):
        """
        Yield the matching files as :py:class:`ClientSideOcrdFile` objects.

        Filters are passed to the server's ``/file`` endpoint (``file_grp``, ``file_id``,
        ``page_id``, ``mimetype``, ``local_filename``, ``url``). Being a generator,
        the request is only sent once iteration starts.
        """
        self.log.debug("find_files(%s)", kwargs)
        if "pageId" in kwargs:
            kwargs["page_id"] = kwargs.pop("pageId")
        if "ID" in kwargs:
            kwargs["file_id"] = kwargs.pop("ID")
        if "fileGrp" in kwargs:
            kwargs["file_grp"] = kwargs.pop("fileGrp")

        if not self.multiplexing_mode:
            r = self._request("GET", f"{self.url}/file", params={**kwargs})
        else:
            r = self._request("POST", self.url, json=MpxReq.find_files(self.ws_dir_path, {**kwargs}))

        for f in r.json()["files"]:
            yield ClientSideOcrdFile(
                None, ID=f["file_id"], pageId=f["page_id"], fileGrp=f["file_grp"], url=f["url"],
                local_filename=f["local_filename"], mimetype=f["mimetype"]
            )

    def find_all_files(self, *args, **kwargs):
        """Like :py:meth:`find_files`, but return a list instead of a generator."""
        return list(self.find_files(*args, **kwargs))

    @deprecated_alias(pageId="page_id")
    @deprecated_alias(ID="file_id")
    def add_file(
        self, file_grp, content=None, file_id=None, url=None, local_filename=None, mimetype=None, page_id=None, **kwargs
    ):
        """
        Register a file with the METS server and return it as a :py:class:`ClientSideOcrdFile`.

        ``content`` and extra keyword arguments are not sent to the server.

        :raises RuntimeError: if the server reports a failure.
        """
        data = OcrdFileModel.create(
            file_id=file_id, file_grp=file_grp, page_id=page_id, mimetype=mimetype, url=url,
            local_filename=local_filename
        )

        if not self.multiplexing_mode:
            r = self._request("POST", f"{self.url}/file", data=data.dict())
            if not r:
                raise RuntimeError("Add file failed. Please check provided parameters")
        else:
            r = self._request("POST", self.url, json=MpxReq.add_file(self.ws_dir_path, data.dict()))
            # `"error" in r` on a Response iterates its raw byte chunks and can never
            # match a str, so inspect the decoded JSON body instead.
            try:
                body = r.json()
            except ValueError:
                body = None
            if isinstance(body, dict) and "error" in body:
                raise RuntimeError(f"Add file failed: Msg: {body['error']}")
            if not r:
                raise RuntimeError("Add file failed. Please check provided parameters")

        return ClientSideOcrdFile(
            None, ID=file_id, fileGrp=file_grp, url=url, pageId=page_id, mimetype=mimetype,
            local_filename=local_filename
        )
301
302
303
class MpxReq:
    """Builders for the request bodies used in multiplexing (TCP forwarding) mode.

    Every METS server call, such as ``find_files`` or ``workspace_path``, needs its own
    request body to be passed to ``MetsServerProxy.forward_tcp_request``. The static
    methods below create these bodies.

    They live in a separate class so that they can be tested in isolation.
    """

    @staticmethod
    def _request_body(
        ws_dir_path: str,
        method_type: str,
        response_type: str,
        request_url: str,
        request_data: Optional[Dict] = None,
    ) -> Dict:
        """Assemble one forwarding body; a fresh empty ``request_data`` dict is used if none is given."""
        if request_data is None:
            request_data = {}
        return {
            "workspace_path": ws_dir_path,
            "method_type": method_type,
            "response_type": response_type,
            "request_url": request_url,
            "request_data": request_data,
        }

    @staticmethod
    def save(ws_dir_path: str) -> Dict:
        """Body for writing the METS to disk (``PUT /``)."""
        return MpxReq._request_body(ws_dir_path, "PUT", "empty", "")

    @staticmethod
    def stop(ws_dir_path: str) -> Dict:
        """Body for stopping the METS server (``DELETE /``)."""
        return MpxReq._request_body(ws_dir_path, "DELETE", "empty", "")

    @staticmethod
    def reload(ws_dir_path: str) -> Dict:
        """Body for reloading the METS from disk (``POST /reload``)."""
        return MpxReq._request_body(ws_dir_path, "POST", "text", "reload")

    @staticmethod
    def unique_identifier(ws_dir_path: str) -> Dict:
        """Body for querying the METS unique identifier (``GET /unique_identifier``)."""
        return MpxReq._request_body(ws_dir_path, "GET", "text", "unique_identifier")

    @staticmethod
    def workspace_path(ws_dir_path: str) -> Dict:
        """Body for querying the workspace directory (``GET /workspace_path``)."""
        return MpxReq._request_body(ws_dir_path, "GET", "text", "workspace_path")

    @staticmethod
    def file_groups(ws_dir_path: str) -> Dict:
        """Body for listing the file groups (``GET /file_groups``)."""
        return MpxReq._request_body(ws_dir_path, "GET", "dict", "file_groups")

    @staticmethod
    def agents(ws_dir_path: str) -> Dict:
        """Body for listing the agents (``GET /agent``)."""
        return MpxReq._request_body(ws_dir_path, "GET", "class", "agent")

    @staticmethod
    def add_agent(ws_dir_path: str, agent_model: Dict) -> Dict:
        """Body for adding an agent (``POST /agent``) given its serialized model."""
        return MpxReq._request_body(ws_dir_path, "POST", "class", "agent", {"class": agent_model})

    @staticmethod
    def find_files(ws_dir_path: str, params: Dict) -> Dict:
        """Body for searching files (``GET /file``) with the given query parameters."""
        return MpxReq._request_body(ws_dir_path, "GET", "class", "file", {"params": params})

    @staticmethod
    def add_file(ws_dir_path: str, data: Dict) -> Dict:
        """Body for adding a file (``POST /file``) with the given form data."""
        return MpxReq._request_body(ws_dir_path, "POST", "class", "file", {"form": data})
416
417
#
418
# Server
419
#
420
421
422
class OcrdMetsServer:
    """
    Serve one workspace's METS through a FastAPI app run by uvicorn.

    A ``url`` starting with ``http://`` or ``https://`` is served over TCP on its host
    and port; any other value is used as the path of a Unix domain socket.
    """

    def __init__(self, workspace, url):
        self.workspace = workspace
        self.url = url
        self.is_uds = not (url.startswith('http://') or url.startswith('https://'))
        self.log = getLogger(f'ocrd.mets_server[{self.url}]')

    def _remove_socket(self):
        """Delete the Unix domain socket file if this server uses one and it still exists."""
        if self.is_uds and Path(self.url).exists():
            self.log.warning(f'UDS socket {self.url} still exists, removing it')
            Path(self.url).unlink()

    def shutdown(self):
        """Remove a leftover socket file, then terminate the process with exit status 0."""
        self._remove_socket()
        # os._exit because uvicorn catches SystemExit raised by sys.exit
        _exit(0)

    def startup(self):
        """
        Build the FastAPI app exposing the METS operations and run it with uvicorn.

        Blocks in ``uvicorn.run``; the ``DELETE /`` route saves the METS and ends the
        process via :py:meth:`shutdown`.
        """
        self.log.info("Starting up METS server")

        workspace = self.workspace

        app = FastAPI(
            title="OCR-D METS Server",
            description="Providing simultaneous write-access to mets.xml for OCR-D",
        )

        @app.exception_handler(ValidationError)
        async def exception_handler_validation_error(request: Request, exc: ValidationError):
            return JSONResponse(status_code=400, content=exc.errors())

        @app.exception_handler(FileExistsError)
        async def exception_handler_file_exists(request: Request, exc: FileExistsError):
            return JSONResponse(status_code=400, content=str(exc))

        @app.exception_handler(re.error)
        async def exception_handler_invalid_regex(request: Request, exc: re.error):
            return JSONResponse(status_code=400, content=f'invalid regex: {exc}')

        # All handlers are declared ``async`` so that they run one at a time on the
        # event loop; FastAPI dispatches plain ``def`` handlers to a thread pool, where
        # e.g. saving could overlap with a concurrent modification of the METS.

        @app.put(path='/')
        async def save():
            """
            Write current changes to the file system
            """
            return workspace.save_mets()

        @app.delete(path='/')
        async def stop():
            """
            Stop the mets server
            """
            getLogger('ocrd.models.ocrd_mets').info(f'Shutting down METS Server {self.url}')
            workspace.save_mets()
            self.shutdown()

        @app.post(path='/reload')
        async def workspace_reload_mets():
            """
            Reload mets file from the file system
            """
            workspace.reload_mets()
            return Response(content=f'Reloaded from {workspace.directory}', media_type="text/plain")

        @app.get(path='/unique_identifier', response_model=str)
        async def unique_identifier():
            return Response(content=workspace.mets.unique_identifier, media_type='text/plain')

        @app.get(path='/workspace_path', response_model=str)
        async def workspace_path():
            return Response(content=workspace.directory, media_type="text/plain")

        @app.get(path='/file_groups', response_model=OcrdFileGroupListModel)
        async def file_groups():
            return {'file_groups': workspace.mets.file_groups}

        @app.get(path='/agent', response_model=OcrdAgentListModel)
        async def agents():
            return OcrdAgentListModel.create(workspace.mets.agents)

        @app.post(path='/agent', response_model=OcrdAgentModel)
        async def add_agent(agent: OcrdAgentModel):
            kwargs = agent.dict()
            # OcrdMets.add_agent takes the type as "_type"
            kwargs['_type'] = kwargs.pop('type')
            workspace.mets.add_agent(**kwargs)
            return agent

        @app.get(path="/file", response_model=OcrdFileListModel)
        async def find_files(
            file_grp: Optional[str] = None,
            file_id: Optional[str] = None,
            page_id: Optional[str] = None,
            mimetype: Optional[str] = None,
            local_filename: Optional[str] = None,
            url: Optional[str] = None
        ):
            """
            Find files in the mets
            """
            found = workspace.mets.find_all_files(
                fileGrp=file_grp, ID=file_id, pageId=page_id, mimetype=mimetype, local_filename=local_filename, url=url
            )
            return OcrdFileListModel.create(found)

        @app.post(path='/file', response_model=OcrdFileModel)
        async def add_file(
            file_grp: str = Form(),
            file_id: str = Form(),
            # Optional like url/local_filename (and OcrdFileModel.page_id): clients drop
            # None-valued form fields, which a required field would reject.
            page_id: Optional[str] = Form(None),
            mimetype: str = Form(),
            url: Optional[str] = Form(None),
            local_filename: Optional[str] = Form(None)
        ):
            """
            Add a file
            """
            # Validate
            file_resource = OcrdFileModel.create(
                file_grp=file_grp, file_id=file_id, page_id=page_id, mimetype=mimetype, url=url,
                local_filename=local_filename
            )
            # Add to workspace
            kwargs = file_resource.dict()
            workspace.add_file(**kwargs)
            return file_resource

        # ------------- #

        if self.is_uds:
            # Create socket and change to world-readable and -writable to avoid permission errors
            self.log.debug(f"chmod 0o666 {self.url}")
            server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            server.bind(self.url)  # creates the socket file
            # Only remove the socket file at interpreter exit: registering shutdown() here
            # would call os._exit(0) from inside atexit, skipping handlers that run after
            # it and replacing any non-zero exit status with 0.
            atexit.register(self._remove_socket)
            server.close()
            chmod(self.url, 0o666)
            uvicorn_kwargs = {'uds': self.url}
        else:
            parsed = urlparse(self.url)
            uvicorn_kwargs = {'host': parsed.hostname, 'port': parsed.port}

        self.log.debug("Starting uvicorn")
        uvicorn.run(app, **uvicorn_kwargs)
562