Passed
Pull Request — master (#1220)
by
unknown
03:39
created

ocrd.mets_server.MpxReq.agents()   A

Complexity

Conditions 1

Size

Total Lines 8
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 8
dl 0
loc 8
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
"""
2
# METS server functionality
3
"""
4
import re
5
from os import _exit, chmod
6
from typing import Dict, Optional, Union, List, Tuple
7
from pathlib import Path
8
from urllib.parse import urlparse
9
import socket
10
import atexit
11
12
from fastapi import FastAPI, Request, Form, Response
13
from fastapi.responses import JSONResponse
14
from requests import Session as requests_session
15
from requests.exceptions import ConnectionError
16
from requests_unixsocket import Session as requests_unixsocket_session
17
from pydantic import BaseModel, Field, ValidationError
18
19
import uvicorn
20
21
from ocrd_models import OcrdFile, ClientSideOcrdFile, OcrdAgent, ClientSideOcrdAgent
22
from ocrd_utils import getLogger, deprecated_alias
23
24
25
#
26
# Models
27
#
28
29
30
# Wire format for a single METS file entry, as exchanged between
# ClientSideOcrdMets and OcrdMetsServer (request form data / JSON responses).
class OcrdFileModel(BaseModel):
    file_grp: str = Field()
    file_id: str = Field()
    mimetype: str = Field()
    page_id: Optional[str] = Field()
    url: Optional[str] = Field()
    local_filename: Optional[str] = Field()

    @staticmethod
    def create(
        file_grp: str, file_id: str, page_id: Optional[str], url: Optional[str],
        local_filename: Optional[Union[str, Path]], mimetype: str
    ):
        """
        Build an :py:class:`OcrdFileModel` from the individual file attributes.

        ``local_filename`` may be given as a :py:class:`pathlib.Path` and is
        converted to ``str``. A missing value (``None``) stays ``None``; it is
        deliberately not passed through ``str()``, which would turn it into the
        literal string ``"None"``.
        """
        return OcrdFileModel(
            file_grp=file_grp, file_id=file_id, page_id=page_id, mimetype=mimetype, url=url,
            local_filename=None if local_filename is None else str(local_filename)
        )
47
48
49
# Wire format for a single METS agent entry.
class OcrdAgentModel(BaseModel):
    name: str = Field()
    type: str = Field()
    role: str = Field()
    otherrole: Optional[str] = Field()
    othertype: str = Field()
    notes: Optional[List[Tuple[Dict[str, str], Optional[str]]]] = Field()

    @staticmethod
    def create(
        name: str, _type: str, role: str, otherrole: str, othertype: str,
        notes: List[Tuple[Dict[str, str], Optional[str]]]
    ):
        """
        Build an :py:class:`OcrdAgentModel` from the individual agent attributes.

        The ``_type`` argument is stored in the model field named ``type``.
        """
        attributes = {
            "name": name,
            "type": _type,
            "role": role,
            "otherrole": otherrole,
            "othertype": othertype,
            "notes": notes,
        }
        return OcrdAgentModel(**attributes)
63
64
65
# Wire format for a list of METS file entries.
class OcrdFileListModel(BaseModel):
    files: List[OcrdFileModel] = Field()

    @staticmethod
    def create(files: List[OcrdFile]):
        """
        Convert each given file object into an :py:class:`OcrdFileModel`
        and wrap the results in an :py:class:`OcrdFileListModel`.
        """
        entries = []
        for ocrd_file in files:
            entry = OcrdFileModel.create(
                file_grp=ocrd_file.fileGrp,
                file_id=ocrd_file.ID,
                mimetype=ocrd_file.mimetype,
                page_id=ocrd_file.pageId,
                url=ocrd_file.url,
                local_filename=ocrd_file.local_filename,
            )
            entries.append(entry)
        return OcrdFileListModel(files=entries)
79
80
81
# Wire format for a list of file group names.
class OcrdFileGroupListModel(BaseModel):
    file_groups: List[str] = Field()

    @staticmethod
    def create(file_groups: List[str]):
        """Wrap the given file group names in an :py:class:`OcrdFileGroupListModel`."""
        model = OcrdFileGroupListModel(file_groups=file_groups)
        return model
87
88
89
# Wire format for a list of METS agent entries.
class OcrdAgentListModel(BaseModel):
    agents: List[OcrdAgentModel] = Field()

    @staticmethod
    def create(agents: List[OcrdAgent]):
        """
        Convert each given agent object into an :py:class:`OcrdAgentModel`
        and wrap the results in an :py:class:`OcrdAgentListModel`.
        """
        entries = []
        for agent in agents:
            entry = OcrdAgentModel.create(
                name=agent.name,
                _type=agent.type,
                role=agent.role,
                otherrole=agent.otherrole,
                othertype=agent.othertype,
                notes=agent.notes,
            )
            entries.append(entry)
        return OcrdAgentListModel(agents=entries)
101
102
103
#
104
# Client
105
#
106
107
108
class ClientSideOcrdMets:
    """
    Partial substitute for :py:class:`ocrd_models.ocrd_mets.OcrdMets` which provides
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.find_files`,
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.find_all_files`,
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.add_agent`,
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.agents` and
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.add_file` by querying a
    :py:class:`ocrd.mets_server.OcrdMetsServer` via HTTP.

    Two modes are supported. In direct mode every operation is sent to its own
    endpoint below ``url``. In multiplexing mode (a TCP URL containing ``tcp_mets``)
    every operation is a ``POST`` to ``url`` itself, carrying a request body built
    by :py:class:`MpxReq` that names the workspace and the operation to forward.
    """

    def __init__(self, url, workspace_path: Optional[str] = None):
        """
        Args:
            url: either an ``http://`` / ``https://`` URL, or (anything else) the
                path of a Unix domain socket.
            workspace_path: directory of the workspace; mandatory in multiplexing mode.

        Raises:
            ValueError: if multiplexing mode is detected but ``workspace_path`` is not set.
        """
        # Both HTTP schemes are TCP endpoints; only treating "http://" as TCP would
        # turn an https URL into a bogus Unix socket address below.
        self.protocol = "tcp" if url.startswith(("http://", "https://")) else "uds"
        self.log = getLogger(f"ocrd.mets_client[{url}]")
        # For Unix domain sockets the socket path is percent-encoded into the host part
        self.url = url if self.protocol == "tcp" else f'http+unix://{url.replace("/", "%2F")}'
        self.ws_dir_path = workspace_path if workspace_path else None

        if self.protocol == "tcp" and "tcp_mets" in self.url:
            self.multiplexing_mode = True
            if not self.ws_dir_path:
                # Must be set since this path is the way to multiplex among multiple workspaces on the PS side
                raise ValueError("ClientSideOcrdMets runs in multiplexing mode but the workspace dir path is not set!")
        else:
            self.multiplexing_mode = False

    @property
    def session(self) -> Union[requests_session, requests_unixsocket_session]:
        """
        A session matching the protocol of :py:attr:`url`.

        Note that a new session object is created on every access.
        """
        return requests_session() if self.protocol == "tcp" else requests_unixsocket_session()

    def __getattr__(self, name):
        # Only reached for attributes that normal lookup did not find, i.e. for
        # everything of the full METS API that this client does not implement.
        raise NotImplementedError(f"ClientSideOcrdMets has no access to '{name}' - try without METS server")

    def __str__(self):
        return f"<ClientSideOcrdMets[url={self.url}]>"

    def save(self):
        """
        Request writing the changes to the file system
        """
        if not self.multiplexing_mode:
            self.session.request("PUT", url=self.url)
        else:
            self.session.request(
                "POST",
                self.url,
                json=MpxReq.save(self.ws_dir_path)
            )

    def stop(self):
        """
        Request stopping the mets server
        """
        try:
            if not self.multiplexing_mode:
                self.session.request("DELETE", self.url)
                return
            else:
                self.session.request(
                    "POST",
                    self.url,
                    json=MpxReq.stop(self.ws_dir_path)
                )
        except ConnectionError:
            # Expected because we exit the process without returning
            pass

    def reload(self):
        """
        Request reloading of the mets file from the file system

        Returns the text of the server's answer.
        """
        if not self.multiplexing_mode:
            return self.session.request("POST", f"{self.url}/reload").text
        else:
            return self.session.request(
                "POST",
                self.url,
                json=MpxReq.reload(self.ws_dir_path)
            ).json()["text"]

    @property
    def unique_identifier(self):
        """
        The unique identifier of the METS, as reported by the server (text).
        """
        if not self.multiplexing_mode:
            return self.session.request("GET", f"{self.url}/unique_identifier").text
        else:
            return self.session.request(
                "POST",
                self.url,
                json=MpxReq.unique_identifier(self.ws_dir_path)
            ).json()["text"]

    @property
    def workspace_path(self):
        """
        The workspace directory as reported by the server.

        Every access queries the server and overwrites ``self.ws_dir_path``
        with the answer before returning it.
        """
        if not self.multiplexing_mode:
            self.ws_dir_path = self.session.request("GET", f"{self.url}/workspace_path").text
            return self.ws_dir_path
        else:
            self.ws_dir_path = self.session.request(
                "POST",
                self.url,
                json=MpxReq.workspace_path(self.ws_dir_path)
            ).json()["text"]
            return self.ws_dir_path

    @property
    def file_groups(self):
        """
        The ``file_groups`` entry of the server's JSON answer.
        """
        if not self.multiplexing_mode:
            return self.session.request("GET", f"{self.url}/file_groups").json()["file_groups"]
        else:
            return self.session.request(
                "POST",
                self.url,
                json=MpxReq.file_groups(self.ws_dir_path)
            ).json()["file_groups"]

    @property
    def agents(self):
        """
        The agents known to the server, as a list of :py:class:`ClientSideOcrdAgent`.
        """
        if not self.multiplexing_mode:
            agent_dicts = self.session.request("GET", f"{self.url}/agent").json()["agents"]
        else:
            agent_dicts = self.session.request(
                "POST",
                self.url,
                json=MpxReq.agents(self.ws_dir_path)
            ).json()["agents"]

        # The wire format calls the field "type", the agent constructor expects "_type"
        for agent_dict in agent_dicts:
            agent_dict["_type"] = agent_dict.pop("type")
        return [ClientSideOcrdAgent(None, **agent_dict) for agent_dict in agent_dicts]

    def add_agent(self, *args, **kwargs):
        """
        Request adding an agent described by the keyword arguments
        (see :py:meth:`OcrdAgentModel.create`). Positional arguments are ignored.

        NOTE(review): the return value differs by mode - the HTTP response object
        in direct mode, an :py:class:`OcrdAgentModel` in multiplexing mode.
        """
        if not self.multiplexing_mode:
            return self.session.request("POST", f"{self.url}/agent", json=OcrdAgentModel.create(**kwargs).dict())
        else:
            self.session.request(
                "POST",
                self.url,
                json=MpxReq.add_agent(self.ws_dir_path, OcrdAgentModel.create(**kwargs).dict())
            ).json()
            return OcrdAgentModel.create(**kwargs)

    @deprecated_alias(ID="file_id")
    @deprecated_alias(pageId="page_id")
    @deprecated_alias(fileGrp="file_grp")
    def find_files(self, **kwargs):
        """
        Query the server for files matching the keyword filters and yield one
        :py:class:`ClientSideOcrdFile` per entry of the answer.

        This is a generator: nothing is logged or requested before iteration starts.
        """
        self.log.debug("find_files(%s)", kwargs)
        # Translate the legacy keyword names to the names used on the wire
        if "pageId" in kwargs:
            kwargs["page_id"] = kwargs.pop("pageId")
        if "ID" in kwargs:
            kwargs["file_id"] = kwargs.pop("ID")
        if "fileGrp" in kwargs:
            kwargs["file_grp"] = kwargs.pop("fileGrp")

        if not self.multiplexing_mode:
            r = self.session.request(method="GET", url=f"{self.url}/file", params={**kwargs})
        else:
            r = self.session.request(
                "POST",
                self.url,
                json=MpxReq.find_files(self.ws_dir_path, {**kwargs})
            )

        for f in r.json()["files"]:
            yield ClientSideOcrdFile(
                None, ID=f["file_id"], pageId=f["page_id"], fileGrp=f["file_grp"], url=f["url"],
                local_filename=f["local_filename"], mimetype=f["mimetype"]
            )

    def find_all_files(self, *args, **kwargs):
        """
        Like :py:meth:`find_files`, but returns the complete result as a list.
        """
        return list(self.find_files(*args, **kwargs))

    @deprecated_alias(pageId="page_id")
    @deprecated_alias(ID="file_id")
    def add_file(
        self, file_grp, content=None, file_id=None, url=None, local_filename=None, mimetype=None, page_id=None, **kwargs
    ):
        """
        Request adding a file to the METS and return a :py:class:`ClientSideOcrdFile`
        built from the given arguments.

        ``content`` and any extra keyword arguments are accepted but not sent to the
        server. The server's answer is not inspected.
        """
        data = OcrdFileModel.create(
            file_id=file_id, file_grp=file_grp, page_id=page_id, mimetype=mimetype, url=url,
            local_filename=local_filename
        )

        if not self.multiplexing_mode:
            # Direct mode sends the fields as form data, matching the server's Form() parameters
            self.session.request("POST", f"{self.url}/file", data=data.dict())
        else:
            self.session.request("POST", self.url, json=MpxReq.add_file(self.ws_dir_path, data.dict()))

        return ClientSideOcrdFile(
            None, ID=file_id, fileGrp=file_grp, url=url, pageId=page_id, mimetype=mimetype,
            local_filename=local_filename
        )
297
298
299
class MpxReq:
    """This class wraps the request bodies needed for the TCP forwarding.

    For every METS server call like find_files or workspace_path a special request body
    is needed to call `MetsServerProxy.forward_tcp_request`. These bodies are created by
    the static methods below.

    The reason to put them into a separate class is to allow easier testing.
    """

    @staticmethod
    def _body(ws_dir_path, method_type, response_type, request_url, request_data) -> Dict:
        # Single place that fixes the key names and their order for all request bodies
        return {
            "workspace_path": ws_dir_path,
            "method_type": method_type,
            "response_type": response_type,
            "request_url": request_url,
            "request_data": request_data
        }

    @staticmethod
    def save(ws_dir_path: str) -> Dict:
        """Request body for saving: ``PUT`` on the root URL, empty response."""
        return MpxReq._body(ws_dir_path, "PUT", "empty", "", {})

    @staticmethod
    def stop(ws_dir_path: str) -> Dict:
        """Request body for stopping: ``DELETE`` on the root URL, empty response."""
        return MpxReq._body(ws_dir_path, "DELETE", "empty", "", {})

    @staticmethod
    def reload(ws_dir_path: str) -> Dict:
        """Request body for reloading: ``POST`` on ``reload``, text response."""
        return MpxReq._body(ws_dir_path, "POST", "text", "reload", {})

    @staticmethod
    def unique_identifier(ws_dir_path: str) -> Dict:
        """Request body for ``GET`` on ``unique_identifier``, text response."""
        return MpxReq._body(ws_dir_path, "GET", "text", "unique_identifier", {})

    @staticmethod
    def workspace_path(ws_dir_path: str) -> Dict:
        """Request body for ``GET`` on ``workspace_path``, text response."""
        return MpxReq._body(ws_dir_path, "GET", "text", "workspace_path", {})

    @staticmethod
    def file_groups(ws_dir_path: str) -> Dict:
        """Request body for ``GET`` on ``file_groups``, dict response."""
        return MpxReq._body(ws_dir_path, "GET", "dict", "file_groups", {})

    @staticmethod
    def agents(ws_dir_path: str) -> Dict:
        """Request body for ``GET`` on ``agent``, class response."""
        return MpxReq._body(ws_dir_path, "GET", "class", "agent", {})

    @staticmethod
    def add_agent(ws_dir_path: str, agent_model: Dict) -> Dict:
        """Request body for ``POST`` on ``agent``; the agent is passed under ``class``."""
        payload = {"class": agent_model}
        return MpxReq._body(ws_dir_path, "POST", "class", "agent", payload)

    @staticmethod
    def find_files(ws_dir_path: str, params: Dict) -> Dict:
        """Request body for ``GET`` on ``file``; the filters are passed under ``params``."""
        payload = {"params": params}
        return MpxReq._body(ws_dir_path, "GET", "class", "file", payload)

    @staticmethod
    def add_file(ws_dir_path: str, data: Dict) -> Dict:
        """Request body for ``POST`` on ``file``; the file fields are passed under ``form``."""
        payload = {"form": data}
        return MpxReq._body(ws_dir_path, "POST", "class", "file", payload)
412
413
#
414
# Server
415
#
416
417
418
class OcrdMetsServer:
    """
    HTTP server exposing the METS of one workspace.

    The server listens either on a TCP address (``url`` is an ``http://`` or
    ``https://`` URL) or on a Unix domain socket (``url`` is the socket path).
    """

    def __init__(self, workspace, url):
        """
        Args:
            workspace: the workspace whose METS is served.
            url: ``http://`` / ``https://`` URL, or the path of a Unix domain socket.
        """
        self.workspace = workspace
        self.url = url
        self.is_uds = not (url.startswith('http://') or url.startswith('https://'))
        self.log = getLogger(f'ocrd.mets_server[{self.url}]')

    def shutdown(self):
        """
        Remove a leftover Unix domain socket file (if any) and terminate the process.
        """
        if self.is_uds:
            if Path(self.url).exists():
                self.log.warning(f'UDS socket {self.url} still exists, removing it')
                Path(self.url).unlink()
        # os._exit because uvicorn catches SystemExit raised by sys.exit
        _exit(0)

    def startup(self):
        """
        Define the HTTP endpoints and run the server with uvicorn.
        """
        self.log.info("Starting up METS server")

        workspace = self.workspace

        app = FastAPI(
            title="OCR-D METS Server",
            description="Providing simultaneous write-access to mets.xml for OCR-D",
        )

        # The following handlers translate expected failures into HTTP 400 answers

        @app.exception_handler(ValidationError)
        async def exception_handler_validation_error(request: Request, exc: ValidationError):
            return JSONResponse(status_code=400, content=exc.errors())

        @app.exception_handler(FileExistsError)
        async def exception_handler_file_exists(request: Request, exc: FileExistsError):
            return JSONResponse(status_code=400, content=str(exc))

        @app.exception_handler(re.error)
        async def exception_handler_invalid_regex(request: Request, exc: re.error):
            return JSONResponse(status_code=400, content=f'invalid regex: {exc}')

        @app.put(path='/')
        def save():
            """
            Write current changes to the file system
            """
            return workspace.save_mets()

        @app.delete(path='/')
        async def stop():
            """
            Stop the mets server
            """
            getLogger('ocrd.models.ocrd_mets').info(f'Shutting down METS Server {self.url}')
            workspace.save_mets()
            # Terminates the process, so no HTTP answer is sent for this request
            self.shutdown()

        @app.post(path='/reload')
        async def workspace_reload_mets():
            """
            Reload mets file from the file system
            """
            workspace.reload_mets()
            return Response(content=f'Reloaded from {workspace.directory}', media_type="text/plain")

        # Plain-text answer: unique identifier of the METS
        @app.get(path='/unique_identifier', response_model=str)
        async def unique_identifier():
            return Response(content=workspace.mets.unique_identifier, media_type='text/plain')

        # Plain-text answer: directory of the workspace
        @app.get(path='/workspace_path', response_model=str)
        async def workspace_path():
            return Response(content=workspace.directory, media_type="text/plain")

        @app.get(path='/file_groups', response_model=OcrdFileGroupListModel)
        async def file_groups():
            return {'file_groups': workspace.mets.file_groups}

        @app.get(path='/agent', response_model=OcrdAgentListModel)
        async def agents():
            return OcrdAgentListModel.create(workspace.mets.agents)

        @app.post(path='/agent', response_model=OcrdAgentModel)
        async def add_agent(agent: OcrdAgentModel):
            kwargs = agent.dict()
            # The model field is called "type", add_agent() expects "_type"
            kwargs['_type'] = kwargs.pop('type')
            workspace.mets.add_agent(**kwargs)
            return agent

        @app.get(path="/file", response_model=OcrdFileListModel)
        async def find_files(
            file_grp: Optional[str] = None,
            file_id: Optional[str] = None,
            page_id: Optional[str] = None,
            mimetype: Optional[str] = None,
            local_filename: Optional[str] = None,
            url: Optional[str] = None
        ):
            """
            Find files in the mets
            """
            found = workspace.mets.find_all_files(
                fileGrp=file_grp, ID=file_id, pageId=page_id, mimetype=mimetype, local_filename=local_filename, url=url
            )
            return OcrdFileListModel.create(found)

        @app.post(path='/file', response_model=OcrdFileModel)
        async def add_file(
            file_grp: str = Form(),
            file_id: str = Form(),
            # Optional like url/local_filename: a client that has no page id omits the
            # field, which must not be rejected as a missing required form value
            page_id: Optional[str] = Form(None),
            mimetype: str = Form(),
            url: Optional[str] = Form(None),
            local_filename: Optional[str] = Form(None)
        ):
            """
            Add a file
            """
            # Validate
            file_resource = OcrdFileModel.create(
                file_grp=file_grp, file_id=file_id, page_id=page_id, mimetype=mimetype, url=url,
                local_filename=local_filename
            )
            # Add to workspace
            kwargs = file_resource.dict()
            workspace.add_file(**kwargs)
            return file_resource

        # ------------- #

        if self.is_uds:
            # Create socket and change to world-readable and -writable to avoid permission errors
            self.log.debug(f"chmod 0o666 {self.url}")
            server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            server.bind(self.url)  # creates the socket file
            atexit.register(self.shutdown)
            server.close()
            chmod(self.url, 0o666)
            uvicorn_kwargs = {'uds': self.url}
        else:
            parsed = urlparse(self.url)
            uvicorn_kwargs = {'host': parsed.hostname, 'port': parsed.port}

        self.log.debug("Starting uvicorn")
        uvicorn.run(app, **uvicorn_kwargs)
558