Passed
Pull Request — master (#1220)
by
unknown
03:46
created

ocrd.mets_server.OcrdFileModel.create()   A

Complexity

Conditions 1

Size

Total Lines 8
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 7
dl 0
loc 8
rs 10
c 0
b 0
f 0
cc 1
nop 6
1
"""
2
# METS server functionality
3
"""
4
import re
5
from os import _exit, chmod
6
from typing import Dict, Optional, Union, List, Tuple
7
from time import sleep
8
from pathlib import Path
9
from subprocess import Popen, run as subprocess_run
10
from urllib.parse import urlparse
11
import socket
12
import atexit
13
14
from fastapi import FastAPI, Request, Form, Response
15
from fastapi.responses import JSONResponse
16
from requests import Session as requests_session
17
from requests.exceptions import ConnectionError
18
from requests_unixsocket import Session as requests_unixsocket_session
19
from pydantic import BaseModel, Field, ValidationError
20
21
import uvicorn
22
23
from ocrd_models import OcrdFile, ClientSideOcrdFile, OcrdAgent, ClientSideOcrdAgent
24
from ocrd_utils import getLogger, deprecated_alias
25
26
27
#
28
# Models
29
#
30
31
32
class OcrdFileModel(BaseModel):
    # Wire representation of a single file entry exchanged between the METS
    # server and its client (used both as request payload and response model).
    file_grp: str = Field()
    file_id: str = Field()
    mimetype: str = Field()
    page_id: Optional[str] = Field()
    url: Optional[str] = Field()
    local_filename: Optional[str] = Field()

    @staticmethod
    def create(
        file_grp: str, file_id: str, page_id: Optional[str], url: Optional[str],
        local_filename: Optional[Union[str, Path]], mimetype: str
    ):
        """
        Build an :py:class:`OcrdFileModel` from plain values.

        ``local_filename`` may be given as a :py:class:`pathlib.Path`, in which
        case it is converted to ``str``. A missing filename (``None``) is passed
        through as ``None`` rather than being stringified to ``"None"``.
        """
        # str(None) would yield the literal "None", which downstream code
        # cannot distinguish from a real file name.
        filename = None if local_filename is None else str(local_filename)
        return OcrdFileModel(
            file_grp=file_grp, file_id=file_id, page_id=page_id, mimetype=mimetype, url=url,
            local_filename=filename
        )
49
50
51
class OcrdAgentModel(BaseModel):
    # Wire representation of a single agent entry exchanged between the METS
    # server and its client.
    name: str = Field()
    type: str = Field()
    role: str = Field()
    otherrole: Optional[str] = Field()
    othertype: str = Field()
    notes: Optional[List[Tuple[Dict[str, str], Optional[str]]]] = Field()

    @staticmethod
    def create(
        name: str, _type: str, role: str, otherrole: str, othertype: str,
        notes: List[Tuple[Dict[str, str], Optional[str]]]
    ):
        """
        Build an :py:class:`OcrdAgentModel` from plain values.

        The parameter is called ``_type`` while the model field is ``type``;
        this method performs the renaming.
        """
        field_values = {
            "name": name,
            "type": _type,
            "role": role,
            "otherrole": otherrole,
            "othertype": othertype,
            "notes": notes,
        }
        return OcrdAgentModel(**field_values)
65
66
67
class OcrdFileListModel(BaseModel):
    # Container model holding a list of file entries.
    files: List[OcrdFileModel] = Field()

    @staticmethod
    def create(files: List[OcrdFile]):
        """
        Convert a list of :py:class:`OcrdFile` objects into an
        :py:class:`OcrdFileListModel`, one :py:class:`OcrdFileModel` per file.
        """
        file_models = []
        for ocrd_file in files:
            file_models.append(
                OcrdFileModel.create(
                    file_grp=ocrd_file.fileGrp,
                    file_id=ocrd_file.ID,
                    mimetype=ocrd_file.mimetype,
                    page_id=ocrd_file.pageId,
                    url=ocrd_file.url,
                    local_filename=ocrd_file.local_filename,
                )
            )
        return OcrdFileListModel(files=file_models)
81
82
83
class OcrdFileGroupListModel(BaseModel):
    # Container model holding a list of file group names.
    file_groups: List[str] = Field()

    @staticmethod
    def create(file_groups: List[str]):
        """Wrap the given file group names in an :py:class:`OcrdFileGroupListModel`."""
        group_list_model = OcrdFileGroupListModel(file_groups=file_groups)
        return group_list_model
89
90
91
class OcrdAgentListModel(BaseModel):
    # Container model holding a list of agent entries.
    agents: List[OcrdAgentModel] = Field()

    @staticmethod
    def create(agents: List[OcrdAgent]):
        """
        Convert a list of :py:class:`OcrdAgent` objects into an
        :py:class:`OcrdAgentListModel`, one :py:class:`OcrdAgentModel` per agent.
        """
        agent_models = []
        for agent in agents:
            agent_models.append(
                OcrdAgentModel.create(
                    name=agent.name,
                    _type=agent.type,
                    role=agent.role,
                    otherrole=agent.otherrole,
                    othertype=agent.othertype,
                    notes=agent.notes,
                )
            )
        return OcrdAgentListModel(agents=agent_models)
103
104
105
#
106
# Client
107
#
108
109
110
class ClientSideOcrdMets:
    """
    Partial substitute for :py:class:`ocrd_models.ocrd_mets.OcrdMets` which provides for
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.find_files`,
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.find_all_files`,
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.add_agent`,
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.agents`, and
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.add_file` to query via HTTP a
    :py:class:`ocrd.mets_server.OcrdMetsServer`.

    The client works in one of two modes:

    * direct mode: every operation has its own endpoint below ``self.url``;
    * multiplexing mode: every operation is sent as a ``POST`` to ``self.url``
      with a JSON body built by :py:class:`MpxReq`. This mode is selected when
      the URL is an HTTP(S) URL containing ``tcp_mets``.
    """

    def __init__(self, url, workspace_path: Optional[str] = None):
        """
        Args:
            url: ``http://`` / ``https://`` URL of the server; anything else is
                treated as the path of a Unix domain socket.
            workspace_path: workspace directory; mandatory in multiplexing mode.

        Raises:
            ValueError: in multiplexing mode when no workspace path is given.
        """
        # Accept https:// as TCP too, in line with OcrdMetsServer's own URL check;
        # previously an https URL was mistaken for a socket path.
        self.protocol = "tcp" if url.startswith(("http://", "https://")) else "uds"
        self.log = getLogger(f"ocrd.mets_client[{url}]")
        self.url = url if self.protocol == "tcp" else f'http+unix://{url.replace("/", "%2F")}'
        self.ws_dir_path = workspace_path if workspace_path else None

        if self.protocol == "tcp" and "tcp_mets" in self.url:
            self.multiplexing_mode = True
            if not self.ws_dir_path:
                # Must be set since this path is the way to multiplex among multiple workspaces on the PS side
                raise ValueError("ClientSideOcrdMets runs in multiplexing mode but the workspace dir path is not set!")
        else:
            self.multiplexing_mode = False

    @property
    def session(self) -> Union[requests_session, requests_unixsocket_session]:
        """Return a new session object matching the protocol (created on every access)."""
        return requests_session() if self.protocol == "tcp" else requests_unixsocket_session()

    def __getattr__(self, name):
        # Only invoked for attributes not found by normal lookup, i.e. for the
        # parts of the OcrdMets API this client does not implement.
        raise NotImplementedError(f"ClientSideOcrdMets has no access to '{name}' - try without METS server")

    def __str__(self):
        return f"<ClientSideOcrdMets[url={self.url}]>"

    def save(self):
        """
        Request writing the changes to the file system
        """
        if not self.multiplexing_mode:
            self.session.request("PUT", url=self.url)
        else:
            self.session.request(
                "POST",
                self.url,
                json=MpxReq.save(self.ws_dir_path)
            )

    def stop(self):
        """
        Request stopping the mets server
        """
        try:
            if not self.multiplexing_mode:
                self.session.request("DELETE", self.url)
            else:
                self.session.request(
                    "POST",
                    self.url,
                    json=MpxReq.stop(self.ws_dir_path)
                )
        except ConnectionError:
            # Expected because we exit the process without returning
            pass

    def reload(self):
        """
        Request reloading of the mets file from the file system

        Returns the text message sent back by the server.
        """
        if not self.multiplexing_mode:
            return self.session.request("POST", f"{self.url}/reload").text
        else:
            return self.session.request(
                "POST",
                self.url,
                json=MpxReq.reload(self.ws_dir_path)
            ).json()["text"]

    @property
    def unique_identifier(self):
        """Unique identifier as reported by the server (fetched on every access)."""
        if not self.multiplexing_mode:
            return self.session.request("GET", f"{self.url}/unique_identifier").text
        else:
            return self.session.request(
                "POST",
                self.url,
                json=MpxReq.unique_identifier(self.ws_dir_path)
            ).json()["text"]

    @property
    def workspace_path(self):
        """
        Workspace directory as reported by the server.

        Every access queries the server and overwrites ``self.ws_dir_path``
        with the answer.
        """
        if not self.multiplexing_mode:
            self.ws_dir_path = self.session.request("GET", f"{self.url}/workspace_path").text
            return self.ws_dir_path
        else:
            self.ws_dir_path = self.session.request(
                "POST",
                self.url,
                json=MpxReq.workspace_path(self.ws_dir_path)
            ).json()["text"]
            return self.ws_dir_path

    @property
    def file_groups(self):
        """File groups as reported by the server (fetched on every access)."""
        if not self.multiplexing_mode:
            return self.session.request("GET", f"{self.url}/file_groups").json()["file_groups"]
        else:
            return self.session.request(
                "POST",
                self.url,
                json=MpxReq.file_groups(self.ws_dir_path)
            ).json()["file_groups"]

    @property
    def agents(self):
        """List of :py:class:`ClientSideOcrdAgent` built from the server's agent list."""
        if not self.multiplexing_mode:
            agent_dicts = self.session.request("GET", f"{self.url}/agent").json()["agents"]
        else:
            agent_dicts = self.session.request(
                "POST",
                self.url,
                json=MpxReq.agents(self.ws_dir_path)
            ).json()["agents"]

        # The wire format uses "type", the agent constructor is fed "_type"
        for agent_dict in agent_dicts:
            agent_dict["_type"] = agent_dict.pop("type")
        return [ClientSideOcrdAgent(None, **agent_dict) for agent_dict in agent_dicts]

    def add_agent(self, *args, **kwargs):
        """
        Add an agent on the server.

        Only keyword arguments are used (they are passed to
        :py:meth:`OcrdAgentModel.create`); positional arguments are ignored.

        Returns the HTTP response in direct mode and the
        :py:class:`OcrdAgentModel` in multiplexing mode.
        """
        if not self.multiplexing_mode:
            return self.session.request("POST", f"{self.url}/agent", json=OcrdAgentModel.create(**kwargs).dict())
        else:
            self.session.request(
                "POST",
                self.url,
                json=MpxReq.add_agent(self.ws_dir_path, OcrdAgentModel.create(**kwargs).dict())
            ).json()
            return OcrdAgentModel.create(**kwargs)

    @deprecated_alias(ID="file_id")
    @deprecated_alias(pageId="page_id")
    @deprecated_alias(fileGrp="file_grp")
    def find_files(self, **kwargs):
        """
        Query the server for files matching ``kwargs`` and yield them as
        :py:class:`ClientSideOcrdFile` objects.

        This is a generator: the request is only sent once iteration starts.
        The legacy keyword names ``pageId``, ``ID`` and ``fileGrp`` are mapped
        to ``page_id``, ``file_id`` and ``file_grp``.
        """
        self.log.debug("find_files(%s)", kwargs)
        if "pageId" in kwargs:
            kwargs["page_id"] = kwargs.pop("pageId")
        if "ID" in kwargs:
            kwargs["file_id"] = kwargs.pop("ID")
        if "fileGrp" in kwargs:
            kwargs["file_grp"] = kwargs.pop("fileGrp")

        if not self.multiplexing_mode:
            r = self.session.request(method="GET", url=f"{self.url}/file", params={**kwargs})
        else:
            r = self.session.request(
                "POST",
                self.url,
                json=MpxReq.find_files(self.ws_dir_path, {**kwargs})
            )

        for f in r.json()["files"]:
            yield ClientSideOcrdFile(
                None, ID=f["file_id"], pageId=f["page_id"], fileGrp=f["file_grp"], url=f["url"],
                local_filename=f["local_filename"], mimetype=f["mimetype"]
            )

    def find_all_files(self, *args, **kwargs):
        """Same as :py:meth:`find_files`, but returns a list instead of a generator."""
        return list(self.find_files(*args, **kwargs))

    @deprecated_alias(pageId="page_id")
    @deprecated_alias(ID="file_id")
    def add_file(
        self, file_grp, content=None, file_id=None, url=None, local_filename=None, mimetype=None, page_id=None, **kwargs
    ):
        """
        Add a file on the server and return a matching
        :py:class:`ClientSideOcrdFile`.

        ``content`` and any additional keyword arguments are accepted but not
        sent to the server.

        Raises:
            RuntimeError: if the server answers with a failure status (direct
                mode) or with a JSON object containing ``error``
                (multiplexing mode).
        """
        data = OcrdFileModel.create(
            file_id=file_id, file_grp=file_grp, page_id=page_id, mimetype=mimetype, url=url,
            local_filename=local_filename
        )

        if not self.multiplexing_mode:
            r = self.session.request("POST", f"{self.url}/file", data=data.dict())
            if not r:
                raise RuntimeError("Add file failed. Please check provided parameters")
        else:
            r = self.session.request("POST", self.url, json=MpxReq.add_file(self.ws_dir_path, data.dict()))
            # The error marker lives in the decoded JSON body, not on the
            # Response object itself (which is neither a mapping nor subscriptable).
            try:
                payload = r.json()
            except ValueError:
                # Body is not JSON: nothing to inspect, keep the previous lenient behaviour
                payload = None
            if isinstance(payload, dict) and "error" in payload:
                raise RuntimeError(f"Add file failed: Msg: {payload['error']}")

        return ClientSideOcrdFile(
            None, ID=file_id, fileGrp=file_grp, url=url, pageId=page_id, mimetype=mimetype,
            local_filename=local_filename
        )
303
304
305
class MpxReq:
    """Builders for the request bodies needed for the TCP forwarding.

    Every METS server call such as ``find_files`` or ``workspace_path`` needs
    its own request body when it is sent through
    `MetsServerProxy.forward_tcp_request`. The static methods of this class
    create these bodies.

    They live in a separate class so that they are easier to test.
    """

    @staticmethod
    def __args_wrapper(
        workspace_path: str, method_type: str, response_type: str, request_url: str, request_data: dict
    ) -> Dict:
        # Common envelope shared by all request bodies
        envelope = dict(
            workspace_path=workspace_path,
            method_type=method_type,
            response_type=response_type,
            request_url=request_url,
            request_data=request_data,
        )
        return envelope

    @staticmethod
    def save(ws_dir_path: str) -> Dict:
        """Body for saving the METS: ``PUT`` on the root URL, empty response."""
        return MpxReq.__args_wrapper(ws_dir_path, "PUT", "empty", "", {})

    @staticmethod
    def stop(ws_dir_path: str) -> Dict:
        """Body for stopping the server: ``DELETE`` on the root URL, empty response."""
        return MpxReq.__args_wrapper(ws_dir_path, "DELETE", "empty", "", {})

    @staticmethod
    def reload(ws_dir_path: str) -> Dict:
        """Body for reloading the METS: ``POST`` on ``reload``, text response."""
        return MpxReq.__args_wrapper(ws_dir_path, "POST", "text", "reload", {})

    @staticmethod
    def unique_identifier(ws_dir_path: str) -> Dict:
        """Body for ``GET`` on ``unique_identifier``, text response."""
        return MpxReq.__args_wrapper(ws_dir_path, "GET", "text", "unique_identifier", {})

    @staticmethod
    def workspace_path(ws_dir_path: str) -> Dict:
        """Body for ``GET`` on ``workspace_path``, text response."""
        return MpxReq.__args_wrapper(ws_dir_path, "GET", "text", "workspace_path", {})

    @staticmethod
    def file_groups(ws_dir_path: str) -> Dict:
        """Body for ``GET`` on ``file_groups``, dict response."""
        return MpxReq.__args_wrapper(ws_dir_path, "GET", "dict", "file_groups", {})

    @staticmethod
    def agents(ws_dir_path: str) -> Dict:
        """Body for ``GET`` on ``agent``, class response."""
        return MpxReq.__args_wrapper(ws_dir_path, "GET", "class", "agent", {})

    @staticmethod
    def add_agent(ws_dir_path: str, agent_model: Dict) -> Dict:
        """Body for ``POST`` on ``agent``; the agent dict is sent under ``class``."""
        return MpxReq.__args_wrapper(ws_dir_path, "POST", "class", "agent", {"class": agent_model})

    @staticmethod
    def find_files(ws_dir_path: str, params: Dict) -> Dict:
        """Body for ``GET`` on ``file``; the search parameters are sent under ``params``."""
        return MpxReq.__args_wrapper(ws_dir_path, "GET", "class", "file", {"params": params})

    @staticmethod
    def add_file(ws_dir_path: str, data: Dict) -> Dict:
        """Body for ``POST`` on ``file``; the file data is sent under ``form``."""
        return MpxReq.__args_wrapper(ws_dir_path, "POST", "class", "file", {"form": data})
378
379
#
380
# Server
381
#
382
383
384
class OcrdMetsServer:
    """
    HTTP server giving access to the METS of one workspace, listening either
    on a TCP address or on a Unix domain socket (UDS).
    """

    def __init__(self, workspace, url):
        """
        Args:
            workspace: the workspace whose METS is served.
            url: ``http://`` / ``https://`` URL to listen on; anything else is
                treated as the path of a Unix domain socket.
        """
        self.workspace = workspace
        self.url = url
        self.is_uds = not (url.startswith('http://') or url.startswith('https://'))
        self.log = getLogger(f'ocrd.models.ocrd_mets.server.{self.url}')

    @staticmethod
    def create_process(mets_server_url: str, ws_dir_path: str, log_file: str) -> int:
        """
        Start a METS server for ``ws_dir_path`` as a detached subprocess and
        return its PID. The subprocess' stdout and stderr go to ``log_file``.

        Raises:
            RuntimeError: if, two seconds after the start, the subprocess has
                already exited with a non-zero return code.
        """
        # Use one handle for both streams so that they share a file offset and
        # cannot overwrite each other. The child gets its own copies of the
        # descriptor, so the parent's handle is closed right after spawning.
        with open(file=log_file, mode="w") as log_handle:
            sub_process = Popen(
                args=["ocrd", "workspace", "-U", f"{mets_server_url}", "-d", f"{ws_dir_path}", "server", "start"],
                stdout=log_handle, stderr=log_handle, cwd=ws_dir_path,
                shell=False, universal_newlines=True, start_new_session=True
            )
        # Wait for the mets server to start
        sleep(2)
        # NOTE(review): poll() returns 0 for a process that already exited
        # cleanly; that case is not reported as a failure here - confirm intended.
        if sub_process.poll():
            raise RuntimeError(f"Mets server starting failed. See {log_file} for errors")
        return sub_process.pid

    @staticmethod
    def kill_process(mets_server_pid: int):
        """Send SIGINT to the process with the given PID via the ``kill`` command."""
        subprocess_run(args=["kill", "-s", "SIGINT", f"{mets_server_pid}"], shell=False, universal_newlines=True)

    def shutdown(self):
        """Remove the socket file (UDS only) and terminate the current process."""
        if self.is_uds:
            if Path(self.url).exists():
                self.log.debug(f'UDS socket {self.url} still exists, removing it')
                Path(self.url).unlink()
        # os._exit because uvicorn catches SystemExit raised by sys.exit
        _exit(0)

    def startup(self):
        """
        Define the HTTP API and run it with uvicorn.

        For a UDS URL the socket file is created beforehand and made readable
        and writable for everyone.
        """
        self.log.info("Starting up METS server")

        workspace = self.workspace

        app = FastAPI(
            title="OCR-D METS Server",
            description="Providing simultaneous write-access to mets.xml for OCR-D",
        )

        # The following exceptions are all reported to the client as HTTP 400

        @app.exception_handler(ValidationError)
        async def exception_handler_validation_error(request: Request, exc: ValidationError):
            return JSONResponse(status_code=400, content=exc.errors())

        @app.exception_handler(FileExistsError)
        async def exception_handler_file_exists(request: Request, exc: FileExistsError):
            return JSONResponse(status_code=400, content=str(exc))

        @app.exception_handler(re.error)
        async def exception_handler_invalid_regex(request: Request, exc: re.error):
            return JSONResponse(status_code=400, content=f'invalid regex: {exc}')

        @app.put(path='/')
        def save():
            """
            Write current changes to the file system
            """
            return workspace.save_mets()

        @app.delete(path='/')
        async def stop():
            """
            Stop the mets server
            """
            getLogger('ocrd.models.ocrd_mets').info(f'Shutting down METS Server {self.url}')
            workspace.save_mets()
            self.shutdown()

        @app.post(path='/reload')
        async def workspace_reload_mets():
            """
            Reload mets file from the file system
            """
            workspace.reload_mets()
            return Response(content=f'Reloaded from {workspace.directory}', media_type="text/plain")

        @app.get(path='/unique_identifier', response_model=str)
        async def unique_identifier():
            return Response(content=workspace.mets.unique_identifier, media_type='text/plain')

        @app.get(path='/workspace_path', response_model=str)
        async def workspace_path():
            return Response(content=workspace.directory, media_type="text/plain")

        @app.get(path='/file_groups', response_model=OcrdFileGroupListModel)
        async def file_groups():
            return {'file_groups': workspace.mets.file_groups}

        @app.get(path='/agent', response_model=OcrdAgentListModel)
        async def agents():
            return OcrdAgentListModel.create(workspace.mets.agents)

        @app.post(path='/agent', response_model=OcrdAgentModel)
        async def add_agent(agent: OcrdAgentModel):
            kwargs = agent.dict()
            # The model field is "type", add_agent() is called with "_type"
            kwargs['_type'] = kwargs.pop('type')
            workspace.mets.add_agent(**kwargs)
            return agent

        @app.get(path="/file", response_model=OcrdFileListModel)
        async def find_files(
            file_grp: Optional[str] = None,
            file_id: Optional[str] = None,
            page_id: Optional[str] = None,
            mimetype: Optional[str] = None,
            local_filename: Optional[str] = None,
            url: Optional[str] = None
        ):
            """
            Find files in the mets
            """
            found = workspace.mets.find_all_files(
                fileGrp=file_grp, ID=file_id, pageId=page_id, mimetype=mimetype, local_filename=local_filename, url=url
            )
            return OcrdFileListModel.create(found)

        @app.post(path='/file', response_model=OcrdFileModel)
        async def add_file(
            file_grp: str = Form(),
            file_id: str = Form(),
            # Optional like url and local_filename: an unset page_id cannot be
            # expressed as a form value, so the field must be allowed to be absent
            page_id: Optional[str] = Form(None),
            mimetype: str = Form(),
            url: Optional[str] = Form(None),
            local_filename: Optional[str] = Form(None)
        ):
            """
            Add a file
            """
            # Validate
            file_resource = OcrdFileModel.create(
                file_grp=file_grp, file_id=file_id, page_id=page_id, mimetype=mimetype, url=url,
                local_filename=local_filename
            )
            # Add to workspace
            kwargs = file_resource.dict()
            workspace.add_file(**kwargs)
            return file_resource

        # ------------- #

        if self.is_uds:
            # Create socket and change to world-readable and -writable to avoid permission errors
            self.log.debug(f"chmod 0o666 {self.url}")
            server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            server.bind(self.url)  # creates the socket file
            atexit.register(self.shutdown)
            server.close()
            chmod(self.url, 0o666)
            uvicorn_kwargs = {'uds': self.url}
        else:
            parsed = urlparse(self.url)
            uvicorn_kwargs = {'host': parsed.hostname, 'port': parsed.port}
        uvicorn_kwargs['log_config'] = None
        uvicorn_kwargs['access_log'] = False

        self.log.debug("Starting uvicorn")
        uvicorn.run(app, **uvicorn_kwargs)
544