Passed
Pull Request — master (#1240)
by
unknown
02:42
created

ocrd.mets_server.ClientSideOcrdMets.__getattr__()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 2
dl 0
loc 2
rs 10
c 0
b 0
f 0
cc 1
nop 2
1
"""
2
# METS server functionality
3
"""
4
import re
5
from os import _exit, chmod
6
from typing import Dict, Optional, Union, List, Tuple
7
from time import sleep
8
from pathlib import Path
9
from subprocess import Popen, run as subprocess_run
10
from urllib.parse import urlparse
11
import socket
12
import atexit
13
14
from fastapi import FastAPI, Request, Form, Response
15
from fastapi.responses import JSONResponse
16
from requests import Session as requests_session
17
from requests.exceptions import ConnectionError
18
from requests_unixsocket import Session as requests_unixsocket_session
19
from pydantic import BaseModel, Field, ValidationError
20
21
import uvicorn
22
23
from ocrd_models import OcrdFile, ClientSideOcrdFile, OcrdAgent, ClientSideOcrdAgent
24
from ocrd_utils import getLogger
25
26
27
#
28
# Models
29
#
30
31
32
class OcrdFileModel(BaseModel):
    # Serialisable description of one METS file entry as exchanged between
    # ClientSideOcrdMets and OcrdMetsServer.
    file_grp: str = Field()
    file_id: str = Field()
    mimetype: str = Field()
    page_id: Optional[str] = Field()
    url: Optional[str] = Field()
    local_filename: Optional[str] = Field()

    @staticmethod
    def create(
        file_grp: str, file_id: str, page_id: Optional[str], url: Optional[str],
        local_filename: Optional[Union[str, Path]], mimetype: str
    ):
        """
        Build an :py:class:`OcrdFileModel` from the given values.

        ``local_filename`` may be a ``str`` or a ``Path`` and is converted to ``str`` for
        the model. ``None`` is passed through unchanged, so that an absent filename does
        not turn into the literal string ``"None"``.
        """
        return OcrdFileModel(
            file_grp=file_grp, file_id=file_id, page_id=page_id, mimetype=mimetype, url=url,
            local_filename=None if local_filename is None else str(local_filename)
        )
49
50
51
class OcrdAgentModel(BaseModel):
    # Serialisable description of one METS agent as exchanged between
    # ClientSideOcrdMets and OcrdMetsServer.
    name: str = Field()
    type: str = Field()
    role: str = Field()
    otherrole: Optional[str] = Field()
    othertype: str = Field()
    notes: Optional[List[Tuple[Dict[str, str], Optional[str]]]] = Field()

    @staticmethod
    def create(
        name: str, _type: str, role: str, otherrole: str, othertype: str,
        notes: List[Tuple[Dict[str, str], Optional[str]]]
    ):
        """
        Build an :py:class:`OcrdAgentModel` from the given values.

        The ``_type`` argument is stored in the model field named ``type``.
        """
        field_values = {
            "name": name,
            "type": _type,
            "role": role,
            "otherrole": otherrole,
            "othertype": othertype,
            "notes": notes,
        }
        return OcrdAgentModel(**field_values)
65
66
67
class OcrdFileListModel(BaseModel):
    # Serialisable list of OcrdFileModel entries.
    files: List[OcrdFileModel] = Field()

    @staticmethod
    def create(files: List[OcrdFile]):
        """
        Convert each given file object into an :py:class:`OcrdFileModel` and wrap the
        results in an :py:class:`OcrdFileListModel`.

        Reads ``fileGrp``, ``ID``, ``mimetype``, ``pageId``, ``url`` and
        ``local_filename`` from every element of ``files``.
        """
        file_models = []
        for ocrd_file in files:
            file_models.append(
                OcrdFileModel.create(
                    file_grp=ocrd_file.fileGrp,
                    file_id=ocrd_file.ID,
                    mimetype=ocrd_file.mimetype,
                    page_id=ocrd_file.pageId,
                    url=ocrd_file.url,
                    local_filename=ocrd_file.local_filename,
                )
            )
        return OcrdFileListModel(files=file_models)
81
82
83
class OcrdFileGroupListModel(BaseModel):
    # Serialisable list of file group names.
    file_groups: List[str] = Field()

    @staticmethod
    def create(file_groups: List[str]):
        """Wrap the given file group names in an :py:class:`OcrdFileGroupListModel`."""
        model = OcrdFileGroupListModel(file_groups=file_groups)
        return model
89
90
91
class OcrdPageListModel(BaseModel):
    # Serialisable list of physical page identifiers.
    physical_pages: List[str] = Field()

    @staticmethod
    def create(physical_pages: List[str]):
        """Wrap the given physical page identifiers in an :py:class:`OcrdPageListModel`."""
        model = OcrdPageListModel(physical_pages=physical_pages)
        return model
97
98
99
class OcrdAgentListModel(BaseModel):
    # Serialisable list of OcrdAgentModel entries.
    agents: List[OcrdAgentModel] = Field()

    @staticmethod
    def create(agents: List[OcrdAgent]):
        """
        Convert each given agent object into an :py:class:`OcrdAgentModel` and wrap the
        results in an :py:class:`OcrdAgentListModel`.

        Reads ``name``, ``type``, ``role``, ``otherrole``, ``othertype`` and ``notes``
        from every element of ``agents``.
        """
        agent_models = []
        for agent in agents:
            agent_models.append(
                OcrdAgentModel.create(
                    name=agent.name,
                    _type=agent.type,
                    role=agent.role,
                    otherrole=agent.otherrole,
                    othertype=agent.othertype,
                    notes=agent.notes,
                )
            )
        return OcrdAgentListModel(agents=agent_models)
111
112
113
#
114
# Client
115
#
116
117
118
class ClientSideOcrdMets:
    """
    Partial substitute for :py:class:`ocrd_models.ocrd_mets.OcrdMets` which provides for
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.find_files`,
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.find_all_files`,
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.add_agent`,
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.agents`, and
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.add_file` to query via HTTP a
    :py:class:`ocrd.mets_server.OcrdMetsServer`.

    Every operation has two transport variants. In the direct variant the request goes
    to the matching REST endpoint below ``self.url``. In multiplexing mode every request
    is a ``POST`` to ``self.url`` whose JSON body is built by :py:class:`MpxReq`.
    """

    def __init__(self, url, workspace_path: Optional[str] = None):
        """
        Args:
            url: ``http://`` or ``https://`` URL for TCP, anything else is taken as the
                path of a Unix domain socket.
            workspace_path: workspace directory; mandatory in multiplexing mode.

        Raises:
            ValueError: if multiplexing mode is detected but ``workspace_path`` is not set.
        """
        # Same scheme test as OcrdMetsServer.__init__, so client and server agree on
        # what counts as a TCP URL.
        self.protocol = "tcp" if url.startswith(("http://", "https://")) else "uds"
        self.log = getLogger(f"ocrd.models.ocrd_mets.client.{url}")
        self.url = url if self.protocol == "tcp" else f'http+unix://{url.replace("/", "%2F")}'
        self.ws_dir_path = workspace_path if workspace_path else None

        if self.protocol == "tcp" and "tcp_mets" in self.url:
            self.multiplexing_mode = True
            if not self.ws_dir_path:
                # Must be set since this path is the way to multiplex among multiple workspaces on the PS side
                raise ValueError("ClientSideOcrdMets runs in multiplexing mode but the workspace dir path is not set!")
        else:
            self.multiplexing_mode = False

    @property
    def session(self) -> Union[requests_session, requests_unixsocket_session]:
        """A new session object matching the protocol; created afresh on every access."""
        return requests_session() if self.protocol == "tcp" else requests_unixsocket_session()

    def __getattr__(self, name):
        # Only reached for attributes that normal lookup did not find, i.e. for the
        # parts of the OcrdMets API this client does not provide.
        raise NotImplementedError(f"ClientSideOcrdMets has no access to '{name}' - try without METS server")

    def __str__(self):
        return f"<ClientSideOcrdMets[url={self.url}]>"

    def save(self):
        """
        Request writing the changes to the file system
        """
        if not self.multiplexing_mode:
            self.session.request("PUT", url=self.url)
        else:
            self.session.request(
                "POST",
                self.url,
                json=MpxReq.save(self.ws_dir_path)
            )

    def stop(self):
        """
        Request stopping the mets server
        """
        try:
            if not self.multiplexing_mode:
                self.session.request("DELETE", self.url)
            else:
                self.session.request(
                    "POST",
                    self.url,
                    json=MpxReq.stop(self.ws_dir_path)
                )
        except ConnectionError:
            # Expected because we exit the process without returning
            pass

    def reload(self):
        """
        Request reloading of the mets file from the file system

        Returns:
            the text reported by the server.
        """
        if not self.multiplexing_mode:
            return self.session.request("POST", f"{self.url}/reload").text
        else:
            return self.session.request(
                "POST",
                self.url,
                json=MpxReq.reload(self.ws_dir_path)
            ).json()["text"]

    @property
    def unique_identifier(self):
        """The unique identifier as reported by the server."""
        if not self.multiplexing_mode:
            return self.session.request("GET", f"{self.url}/unique_identifier").text
        else:
            return self.session.request(
                "POST",
                self.url,
                json=MpxReq.unique_identifier(self.ws_dir_path)
            ).json()["text"]

    @property
    def workspace_path(self):
        """
        The workspace path as reported by the server.

        Side effect: the answer is also stored in ``self.ws_dir_path``.
        """
        if not self.multiplexing_mode:
            self.ws_dir_path = self.session.request("GET", f"{self.url}/workspace_path").text
            return self.ws_dir_path
        else:
            self.ws_dir_path = self.session.request(
                "POST",
                self.url,
                json=MpxReq.workspace_path(self.ws_dir_path)
            ).json()["text"]
            return self.ws_dir_path

    @property
    def physical_pages(self) -> List[str]:
        """The ``physical_pages`` list from the server's JSON answer."""
        if not self.multiplexing_mode:
            return self.session.request("GET", f"{self.url}/physical_pages").json()["physical_pages"]
        else:
            return self.session.request(
                "POST",
                self.url,
                json=MpxReq.physical_pages(self.ws_dir_path)
            ).json()["physical_pages"]

    @property
    def file_groups(self):
        """The ``file_groups`` list from the server's JSON answer."""
        if not self.multiplexing_mode:
            return self.session.request("GET", f"{self.url}/file_groups").json()["file_groups"]
        else:
            return self.session.request(
                "POST",
                self.url,
                json=MpxReq.file_groups(self.ws_dir_path)
            ).json()["file_groups"]

    @property
    def agents(self):
        """
        The agents reported by the server, as a list of ``ClientSideOcrdAgent``.
        """
        if not self.multiplexing_mode:
            agent_dicts = self.session.request("GET", f"{self.url}/agent").json()["agents"]
        else:
            agent_dicts = self.session.request(
                "POST",
                self.url,
                json=MpxReq.agents(self.ws_dir_path)
            ).json()["agents"]

        # The wire format uses the key "type"; it is passed on as keyword "_type".
        for agent_dict in agent_dicts:
            agent_dict["_type"] = agent_dict.pop("type")
        return [ClientSideOcrdAgent(None, **agent_dict) for agent_dict in agent_dicts]

    def add_agent(self, **kwargs):
        """
        Request adding an agent; ``kwargs`` are those of :py:meth:`OcrdAgentModel.create`.

        Returns:
            the HTTP response object in direct mode, an :py:class:`OcrdAgentModel` built
            from ``kwargs`` in multiplexing mode.
        """
        if not self.multiplexing_mode:
            return self.session.request("POST", f"{self.url}/agent", json=OcrdAgentModel.create(**kwargs).dict())
        else:
            self.session.request(
                "POST",
                self.url,
                json=MpxReq.add_agent(self.ws_dir_path, OcrdAgentModel.create(**kwargs).dict())
            ).json()
            return OcrdAgentModel.create(**kwargs)

    def find_files(self, **kwargs):
        """
        Query the server for files and yield one ``ClientSideOcrdFile`` per result.

        The keywords ``pageId``, ``ID`` and ``fileGrp`` are renamed to ``page_id``,
        ``file_id`` and ``file_grp`` before sending. This is a generator, so nothing is
        logged or sent until iteration starts.
        """
        self.log.debug("find_files(%s)", kwargs)
        # translate from native OcrdMets kwargs to OcrdMetsServer REST params
        if "pageId" in kwargs:
            kwargs["page_id"] = kwargs.pop("pageId")
        if "ID" in kwargs:
            kwargs["file_id"] = kwargs.pop("ID")
        if "fileGrp" in kwargs:
            kwargs["file_grp"] = kwargs.pop("fileGrp")

        if not self.multiplexing_mode:
            r = self.session.request(method="GET", url=f"{self.url}/file", params={**kwargs})
        else:
            r = self.session.request(
                "POST",
                self.url,
                json=MpxReq.find_files(self.ws_dir_path, {**kwargs})
            )

        for f in r.json()["files"]:
            yield ClientSideOcrdFile(
                None, ID=f["file_id"], pageId=f["page_id"], fileGrp=f["file_grp"], url=f["url"],
                local_filename=f["local_filename"], mimetype=f["mimetype"]
            )

    def find_all_files(self, *args, **kwargs):
        """Like :py:meth:`find_files`, but returns the results as a list."""
        return list(self.find_files(*args, **kwargs))

    def add_file(
        self, file_grp, content=None, ID=None, url=None, local_filename=None, mimetype=None, pageId=None, **kwargs
    ):
        """
        Request adding a file and return a ``ClientSideOcrdFile`` describing it.

        ``content`` is accepted but not used by this method. Any additional ``kwargs``
        are sent along with the file description.

        Raises:
            RuntimeError: if the server answers with a non-OK status.
        """
        data = OcrdFileModel.create(
            file_grp=file_grp,
            # translate from native OcrdMets kwargs to OcrdMetsServer REST params
            file_id=ID, page_id=pageId,
            mimetype=mimetype, url=url, local_filename=local_filename
        )
        # add force+ignore
        kwargs = {**kwargs, **data.dict()}

        if not self.multiplexing_mode:
            r = self.session.request("POST", f"{self.url}/file", data=kwargs)
            if not r.ok:
                raise RuntimeError(f"Failed to add file ({str(data)}): {r.json()}")
        else:
            r = self.session.request("POST", self.url, json=MpxReq.add_file(self.ws_dir_path, kwargs))
            if not r.ok:
                # "errors" must be a string key; the bare name was undefined and
                # raised NameError instead of this RuntimeError.
                raise RuntimeError(f"Failed to add file ({str(data)}): {r.json()['errors']}")

        return ClientSideOcrdFile(
            None, fileGrp=file_grp,
            ID=ID, pageId=pageId,
            url=url, mimetype=mimetype, local_filename=local_filename
        )
323
324
325
class MpxReq:
    """This class wraps the request bodies needed for the tcp forwarding

    For every mets-server-call like find_files or workspace_path a special request_body is
    needed to call `MetsServerProxy.forward_tcp_request`. These are created by these functions.

    Reason to put this to a separate class is to allow easier testing
    """

    @staticmethod
    def __args_wrapper(
        workspace_path: str, method_type: str, response_type: str, request_url: str, request_data: Dict
    ) -> Dict:
        """Assemble the request body shared by all builders of this class."""
        return {
            "workspace_path": workspace_path,
            "method_type": method_type,
            "response_type": response_type,
            "request_url": request_url,
            "request_data": request_data
        }

    @staticmethod
    def save(ws_dir_path: str) -> Dict:
        """Request body for saving: ``PUT`` on the root URL, empty response."""
        return MpxReq.__args_wrapper(
            ws_dir_path, method_type="PUT", response_type="empty", request_url="", request_data={})

    @staticmethod
    def stop(ws_dir_path: str) -> Dict:
        """Request body for stopping: ``DELETE`` on the root URL, empty response."""
        return MpxReq.__args_wrapper(
            ws_dir_path, method_type="DELETE", response_type="empty", request_url="", request_data={})

    @staticmethod
    def reload(ws_dir_path: str) -> Dict:
        """Request body for reloading: ``POST`` on ``reload``, text response."""
        return MpxReq.__args_wrapper(
            ws_dir_path, method_type="POST", response_type="text", request_url="reload", request_data={})

    @staticmethod
    def unique_identifier(ws_dir_path: str) -> Dict:
        """Request body for ``GET`` on ``unique_identifier``, text response."""
        return MpxReq.__args_wrapper(
            ws_dir_path, method_type="GET", response_type="text", request_url="unique_identifier", request_data={})

    @staticmethod
    def workspace_path(ws_dir_path: str) -> Dict:
        """Request body for ``GET`` on ``workspace_path``, text response."""
        return MpxReq.__args_wrapper(
            ws_dir_path, method_type="GET", response_type="text", request_url="workspace_path", request_data={})

    @staticmethod
    def physical_pages(ws_dir_path: str) -> Dict:
        """Request body for ``GET`` on ``physical_pages``, dict response."""
        return MpxReq.__args_wrapper(
            ws_dir_path, method_type="GET", response_type="dict", request_url="physical_pages", request_data={})

    @staticmethod
    def file_groups(ws_dir_path: str) -> Dict:
        """Request body for ``GET`` on ``file_groups``, dict response."""
        return MpxReq.__args_wrapper(
            ws_dir_path, method_type="GET", response_type="dict", request_url="file_groups", request_data={})

    @staticmethod
    def agents(ws_dir_path: str) -> Dict:
        """Request body for ``GET`` on ``agent``, class response."""
        return MpxReq.__args_wrapper(
            ws_dir_path, method_type="GET", response_type="class", request_url="agent", request_data={})

    @staticmethod
    def add_agent(ws_dir_path: str, agent_model: Dict) -> Dict:
        """Request body for ``POST`` on ``agent``; the agent is sent under the key ``class``."""
        request_data = {"class": agent_model}
        return MpxReq.__args_wrapper(
            ws_dir_path, method_type="POST", response_type="class", request_url="agent", request_data=request_data)

    @staticmethod
    def find_files(ws_dir_path: str, params: Dict) -> Dict:
        """Request body for ``GET`` on ``file``; the query is sent under the key ``params``."""
        request_data = {"params": params}
        return MpxReq.__args_wrapper(
            ws_dir_path, method_type="GET", response_type="class", request_url="file", request_data=request_data)

    @staticmethod
    def add_file(ws_dir_path: str, data: Dict) -> Dict:
        """Request body for ``POST`` on ``file``; the file data is sent under the key ``form``."""
        request_data = {"form": data}
        return MpxReq.__args_wrapper(
            ws_dir_path, method_type="POST", response_type="class", request_url="file", request_data=request_data)
403
404
#
405
# Server
406
#
407
408
409
class OcrdMetsServer:
    """
    HTTP server giving access to the METS of one workspace, served with FastAPI and
    uvicorn either on a TCP host/port or on a Unix domain socket.
    """

    def __init__(self, workspace, url):
        """
        Args:
            workspace: workspace object whose METS is served.
            url: ``http://`` or ``https://`` URL for TCP; anything else is taken as the
                path of a Unix domain socket.
        """
        self.workspace = workspace
        self.url = url
        self.is_uds = not (url.startswith('http://') or url.startswith('https://'))
        self.log = getLogger(f'ocrd.models.ocrd_mets.server.{self.url}')

    @staticmethod
    def create_process(mets_server_url: str, ws_dir_path: str, log_file: str) -> int:
        """
        Start ``ocrd workspace ... server start`` as a detached subprocess.

        Args:
            mets_server_url: value passed to ``-U``.
            ws_dir_path: value passed to ``-d``; also the working directory of the process.
            log_file: file receiving stdout and stderr of the process (truncated first).

        Returns:
            the process id of the started subprocess.

        Raises:
            RuntimeError: if the subprocess has already terminated after the grace period.
        """
        # One handle for both streams: they share a file offset, so neither stream
        # overwrites the other. The child gets its own copies of the descriptor, so
        # the parent's handle can be closed right after spawning.
        with open(file=log_file, mode="w") as log_handle:
            sub_process = Popen(
                args=["ocrd", "workspace", "-U", f"{mets_server_url}", "-d", f"{ws_dir_path}", "server", "start"],
                stdout=log_handle, stderr=log_handle, cwd=ws_dir_path,
                shell=False, universal_newlines=True, start_new_session=True
            )
        # Wait for the mets server to start
        sleep(2)
        # poll() is None while the process runs; any exit code, including 0, means
        # the server is gone.
        if sub_process.poll() is not None:
            raise RuntimeError(f"Mets server starting failed. See {log_file} for errors")
        return sub_process.pid

    @staticmethod
    def kill_process(mets_server_pid: int):
        """Send SIGINT to the given process id via the ``kill`` command."""
        subprocess_run(args=["kill", "-s", "SIGINT", f"{mets_server_pid}"], shell=False, universal_newlines=True)

    def shutdown(self):
        """Remove the socket file (UDS only) and terminate the current process."""
        if self.is_uds:
            if Path(self.url).exists():
                self.log.debug(f'UDS socket {self.url} still exists, removing it')
                Path(self.url).unlink()
        # os._exit because uvicorn catches SystemExit raised by sys.exit
        _exit(0)

    def startup(self):
        """
        Define the REST endpoints and run uvicorn; blocks while the server runs.

        For a Unix domain socket the socket file is created beforehand, made readable
        and writable for everyone, and :py:meth:`shutdown` is registered with ``atexit``.
        """
        self.log.info("Starting up METS server")

        workspace = self.workspace

        app = FastAPI(
            title="OCR-D METS Server",
            description="Providing simultaneous write-access to mets.xml for OCR-D",
        )

        @app.exception_handler(ValidationError)
        async def exception_handler_validation_error(request: Request, exc: ValidationError):
            return JSONResponse(status_code=400, content=exc.errors())

        @app.exception_handler(FileExistsError)
        async def exception_handler_file_exists(request: Request, exc: FileExistsError):
            return JSONResponse(status_code=400, content=str(exc))

        @app.exception_handler(re.error)
        async def exception_handler_invalid_regex(request: Request, exc: re.error):
            return JSONResponse(status_code=400, content=f'invalid regex: {exc}')

        @app.put(path='/')
        def save():
            """
            Write current changes to the file system
            """
            return workspace.save_mets()

        @app.delete(path='/')
        async def stop():
            """
            Stop the mets server
            """
            getLogger('ocrd.models.ocrd_mets').info(f'Shutting down METS Server {self.url}')
            workspace.save_mets()
            self.shutdown()

        @app.post(path='/reload')
        async def workspace_reload_mets():
            """
            Reload mets file from the file system
            """
            workspace.reload_mets()
            return Response(content=f'Reloaded from {workspace.directory}', media_type="text/plain")

        @app.get(path='/unique_identifier', response_model=str)
        async def unique_identifier():
            return Response(content=workspace.mets.unique_identifier, media_type='text/plain')

        @app.get(path='/workspace_path', response_model=str)
        async def workspace_path():
            return Response(content=workspace.directory, media_type="text/plain")

        @app.get(path='/physical_pages', response_model=OcrdPageListModel)
        async def physical_pages():
            return {'physical_pages': workspace.mets.physical_pages}

        @app.get(path='/file_groups', response_model=OcrdFileGroupListModel)
        async def file_groups():
            return {'file_groups': workspace.mets.file_groups}

        @app.get(path='/agent', response_model=OcrdAgentListModel)
        async def agents():
            return OcrdAgentListModel.create(workspace.mets.agents)

        @app.post(path='/agent', response_model=OcrdAgentModel)
        async def add_agent(agent: OcrdAgentModel):
            kwargs = agent.dict()
            # the model field "type" is passed on as keyword "_type"
            kwargs['_type'] = kwargs.pop('type')
            workspace.mets.add_agent(**kwargs)
            return agent

        @app.get(path="/file", response_model=OcrdFileListModel)
        async def find_files(
            file_grp: Optional[str] = None,
            file_id: Optional[str] = None,
            page_id: Optional[str] = None,
            mimetype: Optional[str] = None,
            local_filename: Optional[str] = None,
            url: Optional[str] = None
        ):
            """
            Find files in the mets
            """
            found = workspace.mets.find_all_files(
                fileGrp=file_grp, ID=file_id, pageId=page_id, mimetype=mimetype, local_filename=local_filename, url=url
            )
            return OcrdFileListModel.create(found)

        @app.post(path='/file', response_model=OcrdFileModel)
        async def add_file(
            file_grp: str = Form(),
            file_id: str = Form(),
            page_id: Optional[str] = Form(),
            mimetype: str = Form(),
            url: Optional[str] = Form(None),
            local_filename: Optional[str] = Form(None),
            force: bool = Form(False),
        ):
            """
            Add a file
            """
            # Validate
            file_resource = OcrdFileModel.create(
                file_grp=file_grp, file_id=file_id, page_id=page_id, mimetype=mimetype, url=url,
                local_filename=local_filename
            )
            # Add to workspace
            kwargs = file_resource.dict()
            workspace.add_file(**kwargs, force=force)
            return file_resource

        # ------------- #

        if self.is_uds:
            # Create socket and change to world-readable and -writable to avoid permission errors
            self.log.debug(f"chmod 0o666 {self.url}")
            server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            if Path(self.url).exists() and not is_socket_in_use(self.url):
                # remove leftover unused socket which blocks startup
                Path(self.url).unlink()
            server.bind(self.url)  # creates the socket file
            atexit.register(self.shutdown)
            server.close()
            chmod(self.url, 0o666)
            uvicorn_kwargs = {'uds': self.url}
        else:
            parsed = urlparse(self.url)
            uvicorn_kwargs = {'host': parsed.hostname, 'port': parsed.port}
        uvicorn_kwargs['log_config'] = None
        uvicorn_kwargs['access_log'] = False

        self.log.debug("Starting uvicorn")
        uvicorn.run(app, **uvicorn_kwargs)
576
577
578
def is_socket_in_use(socket_path):
    """
    Tell whether something accepts connections on the Unix domain socket at ``socket_path``.

    Returns:
        ``True`` if a connection attempt succeeds, ``False`` if the path does not exist
        or the connection attempt fails with ``OSError``.
    """
    if not Path(socket_path).exists():
        return False
    # The context manager closes the probing socket on both outcomes.
    with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as client:
        try:
            client.connect(socket_path)
        except OSError:
            return False
    return True
587