Passed
Pull Request — master (#1220)
by
unknown
07:04
created

ocrd.mets_server   B

Complexity

Total Complexity 50

Size/Duplication

Total Lines 468
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 50
eloc 314
dl 0
loc 468
rs 8.4
c 0
b 0
f 0

23 Methods

Rating   Name   Duplication   Size   Complexity  
A OcrdFileModel.create() 0 8 1
A OcrdFileGroupListModel.create() 0 3 1
A OcrdAgentModel.create() 0 6 1
A OcrdAgentListModel.create() 0 7 1
A OcrdFileListModel.create() 0 11 1
A ClientSideOcrdMets.unique_identifier() 0 11 2
A ClientSideOcrdMets.workspace_path() 0 13 2
A ClientSideOcrdMets.session() 0 3 2
A ClientSideOcrdMets.add_agent() 0 11 2
A ClientSideOcrdMets.reload() 0 13 2
A ClientSideOcrdMets.__getattr__() 0 2 1
B ClientSideOcrdMets.find_files() 0 29 6
A ClientSideOcrdMets.stop() 0 18 3
A ClientSideOcrdMets.save() 0 14 2
B ClientSideOcrdMets.__init__() 0 19 8
A ClientSideOcrdMets.agents() 0 15 3
A ClientSideOcrdMets.find_all_files() 0 2 1
A ClientSideOcrdMets.file_groups() 0 11 2
A ClientSideOcrdMets.__str__() 0 2 1
A OcrdMetsServer.shutdown() 0 7 3
A ClientSideOcrdMets.add_file() 0 23 2
A OcrdMetsServer.__init__() 0 5 1
B OcrdMetsServer.startup() 0 125 2

How to fix   Complexity   

Complexity

Complex classes like ocrd.mets_server often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
"""
2
# METS server functionality
3
"""
4
import re
5
from os import _exit, chmod
6
from typing import Dict, Optional, Union, List, Tuple
7
from pathlib import Path
8
from urllib.parse import urlparse
9
import socket
10
import atexit
11
12
from fastapi import FastAPI, Request, Form, Response, requests
13
from fastapi.responses import JSONResponse
14
from requests import Session as requests_session
15
from requests.exceptions import ConnectionError
16
from requests_unixsocket import Session as requests_unixsocket_session
17
from pydantic import BaseModel, Field, ValidationError
18
19
import uvicorn
20
21
from ocrd_models import OcrdFile, ClientSideOcrdFile, OcrdAgent, ClientSideOcrdAgent
22
from ocrd_utils import getLogger, deprecated_alias
23
24
25
#
26
# Models
27
#
28
29
30
class OcrdFileModel(BaseModel):
    # Schema for a single METS file entry as exchanged between
    # ClientSideOcrdMets and OcrdMetsServer.
    # (Kept as comments on purpose: a class docstring would become part of
    # the generated schema description.)
    file_grp: str = Field()
    file_id: str = Field()
    mimetype: str = Field()
    page_id: Optional[str] = Field()
    url: Optional[str] = Field()
    local_filename: Optional[str] = Field()

    @staticmethod
    def create(
        file_grp: str, file_id: str, page_id: Optional[str], url: Optional[str],
        local_filename: Optional[Union[str, Path]], mimetype: str
    ):
        """
        Build an :py:class:`OcrdFileModel` from individual file attributes.

        ``local_filename`` may be given as ``str`` or ``Path`` and is stored
        as ``str``. A missing (``None``) local filename stays ``None``.
        """
        # str(None) would produce the literal string "None", which would then
        # be treated as a real path downstream - only stringify actual values.
        if local_filename is not None:
            local_filename = str(local_filename)
        return OcrdFileModel(
            file_grp=file_grp, file_id=file_id, page_id=page_id, mimetype=mimetype, url=url,
            local_filename=local_filename
        )
47
48
49
class OcrdAgentModel(BaseModel):
    # Schema for a single METS agent entry.
    name: str = Field()
    type: str = Field()
    role: str = Field()
    otherrole: Optional[str] = Field()
    othertype: str = Field()
    notes: Optional[List[Tuple[Dict[str, str], Optional[str]]]] = Field()

    @staticmethod
    def create(
        name: str, _type: str, role: str, otherrole: str, othertype: str,
        notes: List[Tuple[Dict[str, str], Optional[str]]]
    ):
        """
        Build an :py:class:`OcrdAgentModel` from individual agent attributes.

        The ``_type`` argument is stored in the model's ``type`` field.
        """
        agent_fields = {
            'name': name,
            'type': _type,
            'role': role,
            'otherrole': otherrole,
            'othertype': othertype,
            'notes': notes,
        }
        return OcrdAgentModel(**agent_fields)
63
64
65
class OcrdFileListModel(BaseModel):
    # Schema wrapping a list of file entries.
    files: List[OcrdFileModel] = Field()

    @staticmethod
    def create(files: List[OcrdFile]):
        """
        Convert a list of file objects into an :py:class:`OcrdFileListModel`.

        Reads ``fileGrp``, ``ID``, ``mimetype``, ``pageId``, ``url`` and
        ``local_filename`` from every element of ``files``.
        """
        file_models = []
        for ocrd_file in files:
            file_models.append(
                OcrdFileModel.create(
                    file_grp=ocrd_file.fileGrp,
                    file_id=ocrd_file.ID,
                    mimetype=ocrd_file.mimetype,
                    page_id=ocrd_file.pageId,
                    url=ocrd_file.url,
                    local_filename=ocrd_file.local_filename,
                )
            )
        return OcrdFileListModel(files=file_models)
79
80
81
class OcrdFileGroupListModel(BaseModel):
    # Schema wrapping a list of file group names.
    file_groups: List[str] = Field()

    @staticmethod
    def create(file_groups: List[str]):
        """
        Wrap the given file group names in an :py:class:`OcrdFileGroupListModel`.
        """
        group_list_model = OcrdFileGroupListModel(file_groups=file_groups)
        return group_list_model
87
88
89
class OcrdAgentListModel(BaseModel):
    # Schema wrapping a list of agent entries.
    agents: List[OcrdAgentModel] = Field()

    @staticmethod
    def create(agents: List[OcrdAgent]):
        """
        Convert a list of agent objects into an :py:class:`OcrdAgentListModel`.

        Reads ``name``, ``type``, ``role``, ``otherrole``, ``othertype`` and
        ``notes`` from every element of ``agents``.
        """
        agent_models = []
        for agent in agents:
            agent_models.append(
                OcrdAgentModel.create(
                    name=agent.name,
                    _type=agent.type,
                    role=agent.role,
                    otherrole=agent.otherrole,
                    othertype=agent.othertype,
                    notes=agent.notes,
                )
            )
        return OcrdAgentListModel(agents=agent_models)
101
102
103
#
104
# Client
105
#
106
107
108
class ClientSideOcrdMets:
109
    """
110
    Partial substitute for :py:class:`ocrd_models.ocrd_mets.OcrdMets` which provides for
111
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.find_files`,
112
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.find_all_files`, and
113
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.add_agent`,
114
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.agents`,
115
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.add_file` to query via HTTP a
116
    :py:class:`ocrd.mets_server.OcrdMetsServer`.
117
    """
118
119
    def __init__(self, url, workspace_path: Optional[str] = None):
        """
        Set up a client for the METS server reachable at ``url``.

        Args:
            url: either an ``http://`` / ``https://`` URL (TCP) or the file
                system path of a Unix domain socket.
            workspace_path: directory of the workspace; mandatory when the
                URL points at a ``tcp_mets`` proxy endpoint (multiplexing mode).

        Raises:
            ValueError: in multiplexing mode when ``workspace_path`` is not set.
        """
        # Accept both HTTP schemes as TCP, matching the check in OcrdMetsServer.
        # Previously an https:// URL was mistaken for a socket path.
        self.protocol = 'tcp' if url.startswith(('http://', 'https://')) else 'uds'
        self.log = getLogger(f'ocrd.mets_client[{url}]')
        # requests_unixsocket expects the socket path percent-encoded into the host part
        self.url = url if self.protocol == 'tcp' else f'http+unix://{url.replace("/", "%2F")}'
        self.ws_dir_path = workspace_path if workspace_path else None

        # TODO: Replace the `tcp_mets` constant with a variable that is imported from the ProcessingServer
        # Set if communication with the OcrdMetsServer happens over the ProcessingServer
        # The received root URL must be in the form: http://PS_host:PS_port/tcp_mets
        self.multiplexing_mode = self.protocol == 'tcp' and 'tcp_mets' in self.url
        self.ps_proxy_url = url if self.multiplexing_mode else None

        if self.multiplexing_mode and not self.ws_dir_path:
            # Must be set since this path is the way to multiplex among multiple workspaces on the PS side
            raise ValueError("ClientSideOcrdMets runs in multiplexing mode but the workspace dir path is not set!")
138
139
    @property
    def session(self) -> Union[requests_session, requests_unixsocket_session]:
        """
        Return a session object matching the protocol in use.

        A new session is created on every access.
        """
        if self.protocol == 'tcp':
            return requests_session()
        return requests_unixsocket_session()
142
143
    def __getattr__(self, name):
        """
        Reject access to any attribute this client does not define.

        Python only calls this hook after normal attribute lookup has failed.
        """
        message = f"ClientSideOcrdMets has no access to '{name}' - try without METS server"
        raise NotImplementedError(message)
145
146
    def __str__(self):
        """
        Return a short description of this client including its server URL.
        """
        return '<ClientSideOcrdMets[url={}]>'.format(self.url)
148
149
    def save(self):
        """
        Ask the METS server to write its changes to the file system
        """
        if self.multiplexing_mode:
            # Wrap the PUT in a POST envelope addressed to the proxy endpoint
            proxied_request = {
                "workspace_path": self.ws_dir_path,
                "method_type": "PUT",
                "request_url": "",
                "request_data": {}
            }
            self.session.request(method="POST", url=self.ps_proxy_url, json=proxied_request)
        else:
            self.session.request(method='PUT', url=self.url)
163
164
    def stop(self):
        """
        Ask the METS server to shut down

        A ``ConnectionError`` raised by the request is ignored.
        """
        try:
            if self.multiplexing_mode:
                # Wrap the DELETE in a POST envelope addressed to the proxy endpoint
                proxied_request = {
                    "workspace_path": self.ws_dir_path,
                    "method_type": "DELETE",
                    "request_url": "",
                    "request_data": {}
                }
                self.session.request(method="POST", url=self.ps_proxy_url, json=proxied_request)
            else:
                self.session.request(method='DELETE', url=self.url)
        except ConnectionError:
            # Expected because we exit the process without returning
            pass
182
183
    def reload(self):
        """
        Ask the METS server to re-read the mets file from the file system

        Returns the text of the server's response.
        """
        if self.multiplexing_mode:
            proxied_request = {
                "workspace_path": self.ws_dir_path,
                "method_type": "POST",
                "request_url": "reload",
                "request_data": {}
            }
            response = self.session.request(method="POST", url=self.ps_proxy_url, json=proxied_request)
        else:
            response = self.session.request(method='POST', url=f'{self.url}/reload')
        return response.text
196
197
    @property
    def unique_identifier(self):
        """
        Fetch the ``unique_identifier`` resource from the server as text.
        """
        if self.multiplexing_mode:
            proxied_request = {
                "workspace_path": self.ws_dir_path,
                "method_type": "GET",
                "request_url": "unique_identifier",
                "request_data": {}
            }
            response = self.session.request(method="POST", url=self.ps_proxy_url, json=proxied_request)
        else:
            response = self.session.request(method='GET', url=f'{self.url}/unique_identifier')
        return response.text
208
209
    @property
    def workspace_path(self):
        """
        Fetch the workspace path from the server as text.

        The value received is also stored in ``self.ws_dir_path``.
        """
        if self.multiplexing_mode:
            # The envelope is built with the currently known path before it gets overwritten below
            proxied_request = {
                "workspace_path": self.ws_dir_path,
                "method_type": "GET",
                "request_url": "workspace_path",
                "request_data": {}
            }
            response = self.session.request(method="POST", url=self.ps_proxy_url, json=proxied_request)
        else:
            response = self.session.request(method='GET', url=f'{self.url}/workspace_path')
        self.ws_dir_path = response.text
        return self.ws_dir_path
222
223
    @property
    def file_groups(self):
        """
        Fetch the ``file_groups`` entry of the server's JSON response.
        """
        if self.multiplexing_mode:
            proxied_request = {
                "workspace_path": self.ws_dir_path,
                "method_type": "GET",
                "request_url": "file_groups",
                "request_data": {}
            }
            response = self.session.request(method="POST", url=self.ps_proxy_url, json=proxied_request)
        else:
            response = self.session.request(method='GET', url=f'{self.url}/file_groups')
        return response.json()['file_groups']
234
235
    @property
    def agents(self):
        """
        Fetch the agents from the server as a list of ``ClientSideOcrdAgent``.
        """
        if self.multiplexing_mode:
            proxied_request = {
                "workspace_path": self.ws_dir_path,
                "method_type": "GET",
                "request_url": "agent",
                "request_data": {}
            }
            response = self.session.request(method="POST", url=self.ps_proxy_url, json=proxied_request)
        else:
            response = self.session.request(method='GET', url=f'{self.url}/agent')
        agent_entries = response.json()['agents']
        # The wire format uses the key 'type'; the agent constructor is called with '_type'
        for entry in agent_entries:
            entry['_type'] = entry.pop('type')
        client_agents = []
        for entry in agent_entries:
            client_agents.append(ClientSideOcrdAgent(None, **entry))
        return client_agents
250
251
    def add_agent(self, *args, **kwargs):
        """
        Send a new agent, built from the keyword arguments, to the server.

        Positional arguments are accepted but not used. Returns the response
        object of the request.
        """
        agent_payload = OcrdAgentModel.create(**kwargs).dict()
        if self.multiplexing_mode:
            proxied_request = {
                "workspace_path": self.ws_dir_path,
                "method_type": "POST",
                "request_url": "agent",
                "request_data": agent_payload
            }
            return self.session.request(method="POST", url=self.ps_proxy_url, json=proxied_request)
        return self.session.request(method='POST', url=f'{self.url}/agent', json=agent_payload)
262
263
    @deprecated_alias(ID="file_id")
    @deprecated_alias(pageId="page_id")
    @deprecated_alias(fileGrp="file_grp")
    def find_files(self, **kwargs):
        """
        Query the server for files matching the keyword filters.

        The legacy keys ``pageId``, ``ID`` and ``fileGrp`` are renamed to
        ``page_id``, ``file_id`` and ``file_grp`` before the request is sent.
        Yields one ``ClientSideOcrdFile`` per entry of the response's ``files`` list.
        """
        self.log.debug('find_files(%s)', kwargs)
        legacy_to_current = (
            ('pageId', 'page_id'),
            ('ID', 'file_id'),
            ('fileGrp', 'file_grp'),
        )
        for legacy_key, current_key in legacy_to_current:
            if legacy_key in kwargs:
                kwargs[current_key] = kwargs.pop(legacy_key)

        if self.multiplexing_mode:
            proxied_request = {
                "workspace_path": self.ws_dir_path,
                "method_type": "GET",
                "request_url": "file",
                "request_data": {
                    "params": dict(kwargs)
                }
            }
            response = self.session.request(method="POST", url=self.ps_proxy_url, json=proxied_request)
        else:
            response = self.session.request(method='GET', url=f'{self.url}/file', params=dict(kwargs))

        for entry in response.json()['files']:
            yield ClientSideOcrdFile(
                None,
                ID=entry['file_id'],
                pageId=entry['page_id'],
                fileGrp=entry['file_grp'],
                url=entry['url'],
                local_filename=entry['local_filename'],
                mimetype=entry['mimetype'],
            )
293
294
    def find_all_files(self, *args, **kwargs):
        """
        Like :py:meth:`find_files`, but return all results as a list.
        """
        matches = self.find_files(*args, **kwargs)
        return list(matches)
296
297
    @deprecated_alias(pageId="page_id")
    @deprecated_alias(ID="file_id")
    def add_file(
        self, file_grp, content=None, file_id=None, url=None, local_filename=None, mimetype=None, page_id=None, **kwargs
    ):
        """
        Send a new file entry to the server and return its client-side representation.

        ``content`` and any additional keyword arguments are accepted but not
        transmitted by this method.

        Returns:
            A ``ClientSideOcrdFile`` built from the arguments passed in.

        Raises:
            RuntimeError: if the server answers with a non-success status.
        """
        data = OcrdFileModel.create(
            file_id=file_id, file_grp=file_grp, page_id=page_id, mimetype=mimetype, url=url,
            local_filename=local_filename
        )

        if not self.multiplexing_mode:
            # The direct endpoint takes form fields, hence `data=` rather than `json=`
            r = self.session.request(method='POST', url=f'{self.url}/file', data=data.dict())
        else:
            request_body = {
                "workspace_path": self.ws_dir_path,
                "method_type": "POST",
                "request_url": "file",
                "request_data": data.dict()
            }
            r = self.session.request(method="POST", url=self.ps_proxy_url, json=request_body)

        # Do not pretend the file was added when the server rejected the request
        if not r.ok:
            raise RuntimeError(
                f"Failed to add file '{file_id}' to file group '{file_grp}' ({r.status_code}): {r.text}"
            )

        return ClientSideOcrdFile(
            None, ID=file_id, fileGrp=file_grp, url=url, pageId=page_id, mimetype=mimetype,
            local_filename=local_filename
        )
321
322
323
#
324
# Server
325
#
326
327
328
class OcrdMetsServer:
329
    def __init__(self, workspace, url):
        """
        Remember the workspace to serve and the address to serve it on.

        Any ``url`` that does not start with ``http://`` or ``https://`` is
        treated as the path of a Unix domain socket.
        """
        self.workspace = workspace
        self.url = url
        self.is_uds = not url.startswith(('http://', 'https://'))
        self.log = getLogger(f'ocrd.mets_server[{self.url}]')
334
335
    def shutdown(self):
        """
        Remove a leftover Unix domain socket file, if any, and exit the process.
        """
        socket_path = Path(self.url) if self.is_uds else None
        if socket_path is not None and socket_path.exists():
            self.log.warning(f'UDS socket {self.url} still exists, removing it')
            socket_path.unlink()
        # os._exit because uvicorn catches SystemExit raised by sys.exit
        _exit(0)
342
343
    def startup(self):
        """
        Build the FastAPI application for the workspace and serve it with uvicorn.

        Registers the exception handlers and HTTP routes, then listens either
        on a Unix domain socket (when ``self.is_uds`` is set) or on the host
        and port parsed from ``self.url``. The call blocks in ``uvicorn.run``.
        """
        self.log.info("Starting up METS server")

        workspace = self.workspace

        app = FastAPI(
            title="OCR-D METS Server",
            description="Providing simultaneous write-access to mets.xml for OCR-D",
        )

        # The following handlers translate selected exceptions into HTTP 400 responses

        @app.exception_handler(ValidationError)
        async def exception_handler_validation_error(request: Request, exc: ValidationError):
            return JSONResponse(status_code=400, content=exc.errors())

        @app.exception_handler(FileExistsError)
        async def exception_handler_file_exists(request: Request, exc: FileExistsError):
            return JSONResponse(status_code=400, content=str(exc))

        @app.exception_handler(re.error)
        async def exception_handler_invalid_regex(request: Request, exc: re.error):
            return JSONResponse(status_code=400, content=f'invalid regex: {exc}')

        @app.put(path='/')
        def save():
            """
            Write current changes to the file system
            """
            return workspace.save_mets()

        @app.delete(path='/')
        async def stop():
            """
            Stop the mets server
            """
            getLogger('ocrd.models.ocrd_mets').info(f'Shutting down METS Server {self.url}')
            workspace.save_mets()
            # shutdown() ends the process via os._exit, so no response is sent back
            self.shutdown()

        @app.post(path='/reload')
        async def workspace_reload_mets():
            """
            Reload mets file from the file system
            """
            workspace.reload_mets()
            return Response(content=f'Reloaded from {workspace.directory}', media_type="text/plain")

        # Plain-text endpoint returning the unique identifier of the METS
        @app.get(path='/unique_identifier', response_model=str)
        async def unique_identifier():
            return Response(content=workspace.mets.unique_identifier, media_type='text/plain')

        # Plain-text endpoint returning the workspace directory
        @app.get(path='/workspace_path', response_model=str)
        async def workspace_path():
            return Response(content=workspace.directory, media_type="text/plain")

        # JSON endpoint listing the file groups of the METS
        @app.get(path='/file_groups', response_model=OcrdFileGroupListModel)
        async def file_groups():
            return {'file_groups': workspace.mets.file_groups}

        # JSON endpoint listing the agents of the METS
        @app.get(path='/agent', response_model=OcrdAgentListModel)
        async def agents():
            return OcrdAgentListModel.create(workspace.mets.agents)

        # JSON endpoint adding an agent to the METS; echoes the agent received
        @app.post(path='/agent', response_model=OcrdAgentModel)
        async def add_agent(agent: OcrdAgentModel):
            kwargs = agent.dict()
            # The model field is called 'type', the METS method is called with '_type'
            kwargs['_type'] = kwargs.pop('type')
            workspace.mets.add_agent(**kwargs)
            return agent

        @app.get(path="/file", response_model=OcrdFileListModel)
        async def find_files(
            file_grp: Optional[str] = None,
            file_id: Optional[str] = None,
            page_id: Optional[str] = None,
            mimetype: Optional[str] = None,
            local_filename: Optional[str] = None,
            url: Optional[str] = None
        ):
            """
            Find files in the mets
            """
            found = workspace.mets.find_all_files(
                fileGrp=file_grp, ID=file_id, pageId=page_id, mimetype=mimetype, local_filename=local_filename, url=url
            )
            return OcrdFileListModel.create(found)

        @app.post(path='/file', response_model=OcrdFileModel)
        async def add_file(
            file_grp: str = Form(),
            file_id: str = Form(),
            page_id: Optional[str] = Form(),
            mimetype: str = Form(),
            url: Optional[str] = Form(None),
            local_filename: Optional[str] = Form(None)
        ):
            """
            Add a file
            """
            # Validate
            file_resource = OcrdFileModel.create(
                file_grp=file_grp, file_id=file_id, page_id=page_id, mimetype=mimetype, url=url,
                local_filename=local_filename
            )
            # Add to workspace
            kwargs = file_resource.dict()
            workspace.add_file(**kwargs)
            return file_resource

        # ------------- #

        if self.is_uds:
            # Create socket and change to world-readable and -writable to avoid permission errors
            self.log.debug(f"chmod 0o666 {self.url}")
            # The context manager closes the temporary socket even if bind() fails
            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as server:
                server.bind(self.url)  # creates the socket file
                atexit.register(self.shutdown)
            chmod(self.url, 0o666)
            uvicorn_kwargs = {'uds': self.url}
        else:
            parsed = urlparse(self.url)
            uvicorn_kwargs = {'host': parsed.hostname, 'port': parsed.port}

        self.log.debug("Starting uvicorn")
        uvicorn.run(app, **uvicorn_kwargs)
468