Passed
Pull Request — master (#1220)
by
unknown
02:53
created

ocrd.mets_server.ClientSideOcrdMets.find_files()   B

Complexity

Conditions 6

Size

Total Lines 29
Code Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 24
dl 0
loc 29
rs 8.3706
c 0
b 0
f 0
cc 6
nop 2
1
"""
2
# METS server functionality
3
"""
4
import re
5
from os import _exit, chmod
6
from typing import Dict, Optional, Union, List, Tuple
7
from pathlib import Path
8
from urllib.parse import urlparse
9
import socket
10
import atexit
11
12
from fastapi import FastAPI, Request, Form, Response, requests
13
from fastapi.responses import JSONResponse
14
from requests import Session as requests_session
15
from requests.exceptions import ConnectionError
16
from requests_unixsocket import Session as requests_unixsocket_session
17
from pydantic import BaseModel, Field, ValidationError
18
19
import uvicorn
20
21
from ocrd_models import OcrdFile, ClientSideOcrdFile, OcrdAgent, ClientSideOcrdAgent
22
from ocrd_utils import getLogger, deprecated_alias
23
24
25
#
26
# Models
27
#
28
29
30
class OcrdFileModel(BaseModel):
    """
    Wire representation of a single METS file entry, exchanged between
    :py:class:`ClientSideOcrdMets` and :py:class:`OcrdMetsServer`.
    """
    file_grp: str = Field()
    file_id: str = Field()
    mimetype: str = Field()
    page_id: Optional[str] = Field()
    url: Optional[str] = Field()
    local_filename: Optional[str] = Field()

    @staticmethod
    def create(
        file_grp: str, file_id: str, page_id: Optional[str], url: Optional[str],
        local_filename: Optional[Union[str, Path]], mimetype: str
    ):
        """
        Build a validated :py:class:`OcrdFileModel`.

        :param local_filename: a ``str`` or :py:class:`pathlib.Path` (converted to ``str``),
            or ``None`` if the file has no local copy (kept as ``None``)
        :returns: the new model instance
        """
        return OcrdFileModel(
            file_grp=file_grp, file_id=file_id, page_id=page_id, mimetype=mimetype, url=url,
            # str(None) would produce the bogus filename 'None', so only convert real values
            local_filename=None if local_filename is None else str(local_filename)
        )
47
48
49
class OcrdAgentModel(BaseModel):
    """
    Wire representation of a ``mets:agent`` entry, exchanged between
    :py:class:`ClientSideOcrdMets` and :py:class:`OcrdMetsServer`.
    """
    name: str = Field()
    type: str = Field()
    role: str = Field()
    otherrole: Optional[str] = Field()
    # Optional like ``otherrole``: per the METS schema OTHERTYPE only qualifies TYPE="OTHER",
    # so agents of other types may lack it and must still serialize in ``GET /agent``
    othertype: Optional[str] = Field()
    notes: Optional[List[Tuple[Dict[str, str], Optional[str]]]] = Field()

    @staticmethod
    def create(
        name: str, _type: str, role: str, otherrole: Optional[str], othertype: Optional[str],
        notes: Optional[List[Tuple[Dict[str, str], Optional[str]]]]
    ):
        """
        Build a validated :py:class:`OcrdAgentModel`.

        :param _type: value for the ``type`` field (underscored to mirror the
            ``OcrdMets.add_agent`` keyword used elsewhere in this module)
        :returns: the new model instance
        """
        return OcrdAgentModel(name=name, type=_type, role=role, otherrole=otherrole, othertype=othertype, notes=notes)
63
64
65
class OcrdFileListModel(BaseModel):
    """List of :py:class:`OcrdFileModel` entries, as returned by ``GET /file``."""
    files: List[OcrdFileModel] = Field()

    @staticmethod
    def create(files: List[OcrdFile]):
        """Convert :py:class:`OcrdFile` objects into their wire representation."""
        file_models = [
            OcrdFileModel.create(
                file_grp=ocrd_file.fileGrp,
                file_id=ocrd_file.ID,
                mimetype=ocrd_file.mimetype,
                page_id=ocrd_file.pageId,
                url=ocrd_file.url,
                local_filename=ocrd_file.local_filename,
            )
            for ocrd_file in files
        ]
        return OcrdFileListModel(files=file_models)
79
80
81
class OcrdFileGroupListModel(BaseModel):
    """List of file group names, as returned by ``GET /file_groups``."""
    file_groups: List[str] = Field()

    @staticmethod
    def create(file_groups: List[str]):
        """Wrap a list of file group names in the response model."""
        model = OcrdFileGroupListModel(file_groups=file_groups)
        return model
87
88
89
class OcrdAgentListModel(BaseModel):
    """List of :py:class:`OcrdAgentModel` entries, as returned by ``GET /agent``."""
    agents: List[OcrdAgentModel] = Field()

    @staticmethod
    def create(agents: List[OcrdAgent]):
        """Convert :py:class:`OcrdAgent` objects into their wire representation."""
        agent_models = [
            OcrdAgentModel.create(
                name=agent.name,
                _type=agent.type,
                role=agent.role,
                otherrole=agent.otherrole,
                othertype=agent.othertype,
                notes=agent.notes,
            )
            for agent in agents
        ]
        return OcrdAgentListModel(agents=agent_models)
101
102
103
#
104
# Client
105
#
106
107
108
class ClientSideOcrdMets:
    """
    Partial substitute for :py:class:`ocrd_models.ocrd_mets.OcrdMets` which provides for
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.find_files`,
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.find_all_files`,
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.add_agent`,
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.agents` and
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.add_file` to query via HTTP a
    :py:class:`ocrd.mets_server.OcrdMetsServer`.

    Requests go either directly to the METS server (over TCP or a Unix domain socket),
    or, in *multiplexing mode*, as a ``POST`` to the root URL whose JSON body names the
    workspace, the HTTP method, the endpoint path and the request payload.
    """

    def __init__(self, url, workspace_path: Optional[str] = None):
        """
        :param url: ``http://host:port`` for TCP; anything else is treated as the
            filesystem path of a Unix domain socket
        :param workspace_path: workspace directory if already known; otherwise it is
            queried lazily by :py:attr:`workspace_path`
        """
        self.protocol = 'tcp' if url.startswith('http://') else 'uds'
        self.log = getLogger(f'ocrd.mets_client[{url}]')
        # requests_unixsocket expects the socket path percent-encoded in the host part
        self.url = url if self.protocol == 'tcp' else f'http+unix://{url.replace("/", "%2F")}'
        self.ws_dir_path = workspace_path if workspace_path else None

        # Set if communication with the OcrdMetsServer happens over the ProcessingServer
        # The received root URL must be in the form: http://PS_host:PS_port/tcp_mets
        self.multiplexing_mode = self.protocol == 'tcp' and 'tcp_mets' in self.url

    @property
    def session(self) -> Union[requests_session, requests_unixsocket_session]:
        """A new HTTP session matching :py:attr:`protocol` (a fresh one on every access)."""
        return requests_session() if self.protocol == 'tcp' else requests_unixsocket_session()

    def __getattr__(self, name):
        # Only reached for attributes not found normally, i.e. OcrdMets API this client lacks.
        # Dunder probes (e.g. copy.deepcopy's getattr(x, '__deepcopy__', None), or hasattr)
        # only tolerate AttributeError, so those must not raise NotImplementedError.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(f'{type(self).__name__!r} object has no attribute {name!r}')
        raise NotImplementedError(f"ClientSideOcrdMets has no access to '{name}' - try without METS server")

    def __str__(self):
        return f'<ClientSideOcrdMets[url={self.url}]>'

    def _request(self, method: str, path: str = '', *, params: Optional[dict] = None,
                 data: Optional[dict] = None, json: Optional[dict] = None):
        """
        Send one request to the METS server and return the ``requests`` response.

        :param method: HTTP method of the METS server endpoint (``GET``, ``POST``, ...)
        :param path: endpoint path relative to the server root (``''`` for the root itself)
        :param params: query parameters (wrapped as ``{"params": ...}`` when multiplexing)
        :param data: form fields (wrapped as ``{"data": ...}`` when multiplexing)
        :param json: JSON body (used as ``request_data`` verbatim when multiplexing)

        The session is closed before returning; the response body has already been read
        because the request is not streamed.
        """
        with self.session as session:
            if not self.multiplexing_mode:
                url = f'{self.url}/{path}' if path else self.url
                return session.request(method=method, url=url, params=params, data=data, json=json)
            if params is not None:
                request_data = {"params": params}
            elif data is not None:
                request_data = {"data": data}
            elif json is not None:
                request_data = json
            else:
                request_data = {}
            request_body = {
                "workspace_path": self.ws_dir_path,
                "method_type": method,
                "request_url": path,
                "request_data": request_data
            }
            return session.request(method="POST", url=self.url, json=request_body)

    def save(self):
        """
        Request writing the changes to the file system
        """
        self._request('PUT')

    def stop(self):
        """
        Request stopping the mets server
        """
        try:
            self._request('DELETE')
        except ConnectionError:
            # Expected because the server exits the process without returning a response
            pass

    def reload(self):
        """
        Request reloading of the mets file from the file system

        :returns: the server's plain-text confirmation message
        """
        return self._request('POST', 'reload').text

    @property
    def unique_identifier(self):
        """The METS unique identifier as reported by the server."""
        return self._request('GET', 'unique_identifier').text

    @property
    def workspace_path(self):
        """The workspace directory, fetched from the server once and then cached."""
        if not self.ws_dir_path:
            self.ws_dir_path = self._request('GET', 'workspace_path').text
        return self.ws_dir_path

    @property
    def file_groups(self):
        """List of file group names in the METS."""
        return self._request('GET', 'file_groups').json()['file_groups']

    @property
    def agents(self):
        """All agents of the METS as :py:class:`ClientSideOcrdAgent` objects."""
        agent_dicts = self._request('GET', 'agent').json()['agents']
        for agent_dict in agent_dicts:
            # ClientSideOcrdAgent takes the agent type as keyword '_type'
            agent_dict['_type'] = agent_dict.pop('type')
        return [ClientSideOcrdAgent(None, **agent_dict) for agent_dict in agent_dicts]

    def add_agent(self, *args, **kwargs):
        """
        Add an agent to the METS.

        :param kwargs: keyword arguments of :py:meth:`OcrdAgentModel.create`
            (positional ``args`` are accepted but ignored)
        :returns: the raw HTTP response
        """
        return self._request('POST', 'agent', json=OcrdAgentModel.create(**kwargs).dict())

    @deprecated_alias(ID="file_id")
    @deprecated_alias(pageId="page_id")
    @deprecated_alias(fileGrp="file_grp")
    def find_files(self, **kwargs):
        """
        Yield :py:class:`ClientSideOcrdFile` objects for all files matching ``kwargs``.

        The server's ``GET /file`` endpoint understands ``file_grp``, ``file_id``,
        ``page_id``, ``mimetype``, ``local_filename`` and ``url``. The deprecated names
        ``fileGrp``, ``ID`` and ``pageId`` are renamed by the ``deprecated_alias``
        decorators before this body runs. Being a generator, the request is only sent
        once iteration starts.
        """
        self.log.debug('find_files(%s)', kwargs)
        response = self._request('GET', 'file', params=dict(kwargs))
        for f in response.json()['files']:
            yield ClientSideOcrdFile(
                None, ID=f['file_id'], pageId=f['page_id'], fileGrp=f['file_grp'], url=f['url'],
                local_filename=f['local_filename'], mimetype=f['mimetype']
            )

    def find_all_files(self, *args, **kwargs):
        """Like :py:meth:`find_files`, but return a list."""
        return list(self.find_files(*args, **kwargs))

    @deprecated_alias(pageId="page_id")
    @deprecated_alias(ID="file_id")
    def add_file(
        self, file_grp, content=None, file_id=None, url=None, local_filename=None, mimetype=None, page_id=None, **kwargs
    ):
        """
        Register a file in the METS via the server.

        ``content`` and any extra keyword arguments are accepted but not sent to the server.

        :returns: a :py:class:`ClientSideOcrdFile` built from the arguments (not from the
            server response)
        """
        data = OcrdFileModel.create(
            file_id=file_id, file_grp=file_grp, page_id=page_id, mimetype=mimetype, url=url,
            local_filename=local_filename
        )
        # NOTE(review): the response status is not checked, so a server-side rejection
        # (e.g. the HTTP 400 returned for FileExistsError) goes unnoticed here.
        self._request('POST', 'file', data=data.dict())
        return ClientSideOcrdFile(
            None, ID=file_id, fileGrp=file_grp, url=url, pageId=page_id, mimetype=mimetype,
            local_filename=local_filename
        )
315
316
317
#
318
# Server
319
#
320
321
322
class OcrdMetsServer:
    """
    FastAPI/uvicorn server giving concurrent clients (e.g. :py:class:`ClientSideOcrdMets`)
    access to the METS of a single workspace.

    Listens on TCP when ``url`` starts with ``http://`` or ``https://``, otherwise on a
    Unix domain socket created at the filesystem path ``url``.
    """

    def __init__(self, workspace, url):
        self.workspace = workspace
        self.url = url
        self.is_uds = not url.startswith(('http://', 'https://'))
        self.log = getLogger(f'ocrd.mets_server[{self.url}]')

    def _remove_socket(self):
        """Delete the UDS socket file if it still exists (no-op for TCP servers)."""
        if self.is_uds:
            socket_path = Path(self.url)
            if socket_path.exists():
                self.log.warning(f'UDS socket {self.url} still exists, removing it')
                socket_path.unlink()

    def shutdown(self):
        """Remove the UDS socket file (if any) and terminate the process immediately."""
        self._remove_socket()
        # os._exit because uvicorn catches SystemExit raised by sys.exit
        _exit(0)

    def startup(self):
        """
        Build the HTTP API around :py:attr:`workspace` and serve it with uvicorn.

        Blocks until uvicorn returns.
        """
        self.log.info("Starting up METS server")

        workspace = self.workspace

        app = FastAPI(
            title="OCR-D METS Server",
            description="Providing simultaneous write-access to mets.xml for OCR-D",
        )

        @app.exception_handler(ValidationError)
        async def exception_handler_validation_error(request: Request, exc: ValidationError):
            return JSONResponse(status_code=400, content=exc.errors())

        @app.exception_handler(FileExistsError)
        async def exception_handler_file_exists(request: Request, exc: FileExistsError):
            return JSONResponse(status_code=400, content=str(exc))

        @app.exception_handler(re.error)
        async def exception_handler_invalid_regex(request: Request, exc: re.error):
            return JSONResponse(status_code=400, content=f'invalid regex: {exc}')

        @app.put(path='/')
        def save():
            """
            Write current changes to the file system
            """
            return workspace.save_mets()

        @app.delete(path='/')
        async def stop():
            """
            Stop the mets server
            """
            getLogger('ocrd.models.ocrd_mets').info(f'Shutting down METS Server {self.url}')
            workspace.save_mets()
            self.shutdown()

        @app.post(path='/reload')
        async def workspace_reload_mets():
            """
            Reload mets file from the file system
            """
            workspace.reload_mets()
            return Response(content=f'Reloaded from {workspace.directory}', media_type="text/plain")

        @app.get(path='/unique_identifier', response_model=str)
        async def unique_identifier():
            return Response(content=workspace.mets.unique_identifier, media_type='text/plain')

        @app.get(path='/workspace_path', response_model=str)
        async def workspace_path():
            return Response(content=workspace.directory, media_type="text/plain")

        @app.get(path='/file_groups', response_model=OcrdFileGroupListModel)
        async def file_groups():
            return {'file_groups': workspace.mets.file_groups}

        @app.get(path='/agent', response_model=OcrdAgentListModel)
        async def agents():
            return OcrdAgentListModel.create(workspace.mets.agents)

        @app.post(path='/agent', response_model=OcrdAgentModel)
        async def add_agent(agent: OcrdAgentModel):
            kwargs = agent.dict()
            # OcrdMets.add_agent takes the agent type as keyword '_type'
            kwargs['_type'] = kwargs.pop('type')
            workspace.mets.add_agent(**kwargs)
            return agent

        @app.get(path="/file", response_model=OcrdFileListModel)
        async def find_files(
            file_grp: Optional[str] = None,
            file_id: Optional[str] = None,
            page_id: Optional[str] = None,
            mimetype: Optional[str] = None,
            local_filename: Optional[str] = None,
            url: Optional[str] = None
        ):
            """
            Find files in the mets
            """
            found = workspace.mets.find_all_files(
                fileGrp=file_grp, ID=file_id, pageId=page_id, mimetype=mimetype, local_filename=local_filename, url=url
            )
            return OcrdFileListModel.create(found)

        @app.post(path='/file', response_model=OcrdFileModel)
        async def add_file(
            file_grp: str = Form(),
            file_id: str = Form(),
            # Needs a default: `requests` drops None-valued form fields, so a file without
            # a page ID arrives without this field and would be rejected as missing
            page_id: Optional[str] = Form(None),
            mimetype: str = Form(),
            url: Optional[str] = Form(None),
            local_filename: Optional[str] = Form(None)
        ):
            """
            Add a file
            """
            # Validate
            file_resource = OcrdFileModel.create(
                file_grp=file_grp, file_id=file_id, page_id=page_id, mimetype=mimetype, url=url,
                local_filename=local_filename
            )
            # Add to workspace
            kwargs = file_resource.dict()
            workspace.add_file(**kwargs)
            return file_resource

        # ------------- #

        if self.is_uds:
            # Create socket and change to world-readable and -writable to avoid permission errors
            self.log.debug(f"chmod 0o666 {self.url}")
            server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            server.bind(self.url)  # creates the socket file
            # Only clean up the socket file at interpreter exit. Registering shutdown() would
            # os._exit() from inside atexit, skipping the remaining handlers (e.g. logging's)
            # and forcing exit status 0.
            atexit.register(self._remove_socket)
            server.close()
            chmod(self.url, 0o666)
            uvicorn_kwargs = {'uds': self.url}
        else:
            parsed = urlparse(self.url)
            uvicorn_kwargs = {'host': parsed.hostname, 'port': parsed.port}

        self.log.debug("Starting uvicorn")
        uvicorn.run(app, **uvicorn_kwargs)
462