Passed
Push — master ( abc48f...62dd14 )
by Konstantin
03:02
created

ocrd.mets_server.ClientSideOcrdMets.reload()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 2
dl 0
loc 2
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
"""
# METS server functionality
"""
4
import re
5
from os import environ, _exit, chmod
6
from io import BytesIO
7
from typing import Any, Dict, Optional, Union, List, Tuple
8
from pathlib import Path
9
from urllib.parse import urlparse
10
import socket
11
12
from fastapi import FastAPI, Request, File, Form, Response
13
from fastapi.responses import JSONResponse
14
from requests import request, Session as requests_session
15
from requests.exceptions import ConnectionError
16
from requests_unixsocket import Session as requests_unixsocket_session
17
from pydantic import BaseModel, Field, ValidationError
18
19
import uvicorn
20
21
from ocrd_models import OcrdMets, OcrdFile, ClientSideOcrdFile, OcrdAgent, ClientSideOcrdAgent
22
from ocrd_utils import initLogging, getLogger, deprecated_alias
23
24
#
25
# Models
26
#
27
28
class OcrdFileModel(BaseModel):
    """
    Serializable description of a single file entry, used as request and
    response payload of the ``/file`` endpoints.
    """
    file_grp : str = Field()
    file_id : str = Field()
    mimetype : str = Field()
    page_id : Optional[str] = Field()
    url : Optional[str] = Field()
    local_filename : Optional[str] = Field()

    @staticmethod
    def create(file_grp : str, file_id : str, page_id : Optional[str], url : Optional[str], local_filename : Optional[Union[str, Path]], mimetype : str):
        """
        Build an :py:class:`OcrdFileModel` from individual values.

        Args:
            file_grp (str): file group of the file
            file_id (str): ID of the file
            page_id (str, optional): page ID of the file
            url (str, optional): URL of the file
            local_filename (str or Path, optional): local path of the file;
                a ``Path`` is converted to ``str``, ``None`` stays ``None``
            mimetype (str): media type of the file

        Returns:
            the populated :py:class:`OcrdFileModel`
        """
        # Only stringify actual values: str(None) would yield the literal
        # string 'None' and be passed on as if it were a real filename.
        if local_filename is not None:
            local_filename = str(local_filename)
        return OcrdFileModel(
            file_grp=file_grp,
            file_id=file_id,
            page_id=page_id,
            mimetype=mimetype,
            url=url,
            local_filename=local_filename)
39
40
class OcrdAgentModel(BaseModel):
    """
    Serializable description of a single agent, used as request and
    response payload of the ``/agent`` endpoints.
    """
    name: str = Field()
    type: str = Field()
    role: str = Field()
    otherrole: Optional[str] = Field()
    othertype: str = Field()
    notes: Optional[List[Tuple[Dict[str, str], Optional[str]]]] = Field()

    @staticmethod
    def create(
        name : str,
        _type : str,
        role : str,
        otherrole : str,
        othertype : str,
        notes : List[Tuple[Dict[str, str], Optional[str]]],
    ):
        """
        Build an :py:class:`OcrdAgentModel` from individual values.

        The ``_type`` argument is stored in the model's ``type`` field; all
        other arguments are stored under their own name.
        """
        field_values = {
            'name': name,
            'type': _type,
            'role': role,
            'otherrole': otherrole,
            'othertype': othertype,
            'notes': notes,
        }
        return OcrdAgentModel(**field_values)
51
52
53
class OcrdFileListModel(BaseModel):
    """
    Serializable list of :py:class:`OcrdFileModel`, the response payload of
    ``GET /file``.
    """
    files: List[OcrdFileModel] = Field()

    @staticmethod
    def create(files : List[OcrdFile]):
        """
        Convert file objects into an :py:class:`OcrdFileListModel`.

        Reads ``fileGrp``, ``ID``, ``mimetype``, ``pageId``, ``url`` and
        ``local_filename`` from every element of ``files``.
        """
        def as_model(ocrd_file):
            # One model per file, mapping attribute names to model field names
            return OcrdFileModel.create(
                file_grp=ocrd_file.fileGrp,
                file_id=ocrd_file.ID,
                mimetype=ocrd_file.mimetype,
                page_id=ocrd_file.pageId,
                url=ocrd_file.url,
                local_filename=ocrd_file.local_filename,
            )

        return OcrdFileListModel(files=[as_model(ocrd_file) for ocrd_file in files])
68
69
class OcrdFileGroupListModel(BaseModel):
    """
    Serializable list of file group names, the response payload of
    ``GET /file_groups``.
    """
    file_groups: List[str] = Field()

    @staticmethod
    def create(file_groups : List[str]):
        """
        Wrap a list of file group names in an :py:class:`OcrdFileGroupListModel`.
        """
        model = OcrdFileGroupListModel(file_groups=file_groups)
        return model
75
76
class OcrdAgentListModel(BaseModel):
    """
    Serializable list of :py:class:`OcrdAgentModel`, the response payload of
    ``GET /agent``.
    """
    agents: List[OcrdAgentModel] = Field()

    @staticmethod
    def create(agents : List[OcrdAgent]):
        """
        Convert agent objects into an :py:class:`OcrdAgentListModel`.

        Reads ``name``, ``type``, ``role``, ``otherrole``, ``othertype`` and
        ``notes`` from every element of ``agents``.
        """
        def as_model(agent):
            # The agent's ``type`` attribute is passed as ``_type``
            return OcrdAgentModel.create(
                name=agent.name,
                _type=agent.type,
                role=agent.role,
                otherrole=agent.otherrole,
                othertype=agent.othertype,
                notes=agent.notes,
            )

        return OcrdAgentListModel(agents=[as_model(agent) for agent in agents])
84
85
#
86
# Client
87
#
88
89
90
class ClientSideOcrdMets():
    """
    Partial substitute for :py:class:`ocrd_models.ocrd_mets.OcrdMets` which provides
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.find_files`,
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.find_all_files`,
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.add_agent`,
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.agents` and
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.add_file` by querying a
    :py:class:`ocrd.mets_server.OcrdMetsServer` via HTTP.
    """

    def __init__(self, url):
        """
        Args:
            url (str): ``http://`` or ``https://`` URL of a METS server
                reachable via TCP; any other value is treated as the path of
                a Unix domain socket
        """
        # Use the same scheme test as OcrdMetsServer.is_uds, so that an
        # https:// URL is not mistaken for a socket path.
        protocol = 'tcp' if url.startswith(('http://', 'https://')) else 'uds'
        self.log = getLogger(f'ocrd.mets_client[{url}]')
        # For a socket, encode the path (slashes percent-encoded) into a
        # http+unix:// URL for the requests_unixsocket session
        self.url = url if protocol == 'tcp' else f'http+unix://{url.replace("/", "%2F")}'
        self.session = requests_session() if protocol == 'tcp' else requests_unixsocket_session()

    def __getattr__(self, name):
        # Only invoked for attributes that normal lookup did not find, i.e.
        # everything this partial substitute does not implement.
        raise NotImplementedError(f"ClientSideOcrdMets has no access to '{name}' - try without METS server")

    def __str__(self):
        return f'<ClientSideOcrdMets[url={self.url}]>'

    @property
    def workspace_path(self):
        """
        Response text of ``GET /workspace_path``.
        """
        return self.session.request('GET', f'{self.url}/workspace_path').text

    def reload(self):
        """
        Send ``POST /reload`` and return the response text.
        """
        return self.session.request('POST', f'{self.url}/reload').text

    @deprecated_alias(ID="file_id")
    @deprecated_alias(pageId="page_id")
    @deprecated_alias(fileGrp="file_grp")
    def find_files(self, **kwargs):
        """
        Query ``GET /file`` with ``kwargs`` as search parameters and yield a
        :py:class:`ocrd_models.ClientSideOcrdFile` for every file in the response.

        This is a generator: the request is only sent once iteration starts.
        """
        self.log.debug('find_files(%s)', kwargs)
        # Translate legacy keyword names to the parameter names of the endpoint
        if 'pageId' in kwargs:
            kwargs['page_id'] = kwargs.pop('pageId')
        if 'ID' in kwargs:
            kwargs['file_id'] = kwargs.pop('ID')
        if 'fileGrp' in kwargs:
            kwargs['file_grp'] = kwargs.pop('fileGrp')
        r = self.session.request('GET', f'{self.url}/file', params={**kwargs})
        for f in r.json()['files']:
            yield ClientSideOcrdFile(None, ID=f['file_id'], pageId=f['page_id'], fileGrp=f['file_grp'], url=f['url'], local_filename=f['local_filename'], mimetype=f['mimetype'])

    def find_all_files(self, *args, **kwargs):
        """
        Like :py:meth:`find_files`, but return a list instead of a generator.
        """
        return list(self.find_files(*args, **kwargs))

    def add_agent(self, *args, **kwargs):
        """
        Send ``POST /agent`` with an :py:class:`OcrdAgentModel` built from
        ``kwargs`` and return the response object. Positional arguments are ignored.
        """
        return self.session.request('POST', f'{self.url}/agent', json=OcrdAgentModel.create(**kwargs).dict())

    @property
    def agents(self):
        """
        List of :py:class:`ocrd_models.ClientSideOcrdAgent` built from the
        response of ``GET /agent``.
        """
        agent_dicts = self.session.request('GET', f'{self.url}/agent').json()['agents']
        for agent_dict in agent_dicts:
            # The payload key is 'type', the constructor keyword is '_type'
            agent_dict['_type'] = agent_dict.pop('type')
        return [ClientSideOcrdAgent(None, **agent_dict) for agent_dict in agent_dicts]

    @property
    def unique_identifier(self):
        """
        Response text of ``GET /unique_identifier``.
        """
        return self.session.request('GET', f'{self.url}/unique_identifier').text

    @property
    def file_groups(self):
        """
        The ``file_groups`` list from the response of ``GET /file_groups``.
        """
        return self.session.request('GET', f'{self.url}/file_groups').json()['file_groups']

    @deprecated_alias(pageId="page_id")
    @deprecated_alias(ID="file_id")
    def add_file(self, file_grp, content=None, file_id=None, url=None, local_filename=None, mimetype=None, page_id=None, **kwargs):
        """
        Send ``POST /file`` for the described file and return a matching
        :py:class:`ocrd_models.ClientSideOcrdFile`.

        ``content`` and any additional keyword arguments are accepted but not
        transmitted.

        Raises:
            RuntimeError: if the server answers with an error status
        """
        data = OcrdFileModel.create(
            file_id=file_id,
            file_grp=file_grp,
            page_id=page_id,
            mimetype=mimetype,
            url=url,
            local_filename=local_filename)
        r = self.session.request('POST', f'{self.url}/file', data=data.dict())
        # Do not pretend the file was added when the server rejected the request
        if not r.ok:
            raise RuntimeError(f"Failed to add file ({data}): {r.text}")
        return ClientSideOcrdFile(
                None,
                ID=file_id,
                fileGrp=file_grp,
                url=url,
                pageId=page_id,
                mimetype=mimetype,
                local_filename=local_filename)

    def save(self):
        """
        Send ``PUT`` to the server's base URL.
        """
        self.session.request('PUT', self.url)

    def stop(self):
        """
        Send ``DELETE`` to the server's base URL, tolerating a dropped connection.
        """
        try:
            self.session.request('DELETE', self.url)
        except ConnectionError:
            # Expected because we exit the process without returning
            pass
186
187
#
188
# Server
189
#
190
191
class OcrdMetsServer():
    """
    HTTP server (a FastAPI app run by uvicorn) giving access to the METS of
    a workspace, either on a TCP address or on a Unix domain socket.
    """

    def __init__(self, workspace, url):
        """
        Args:
            workspace: workspace object whose METS is served
            url (str): ``http://`` or ``https://`` URL to listen on; any other
                value is treated as the path of a Unix domain socket
        """
        self.workspace = workspace
        self.url = url
        self.is_uds = not (url.startswith('http://') or url.startswith('https://'))
        self.log = getLogger(f'ocrd.mets_server[{self.url}]')

    def shutdown(self):
        """
        Remove the socket file (if serving on a Unix domain socket) and
        terminate the process.
        """
        self.log.info("Shutting down METS server")
        if self.is_uds:
            try:
                Path(self.url).unlink()
            except FileNotFoundError:
                # Socket file is already gone: nothing to clean up, but the
                # process must still be terminated below.
                self.log.debug("Socket file %s already removed", self.url)
        # os._exit because uvicorn catches SystemExit raised by sys.exit
        _exit(0)

    def startup(self):
        """
        Define the HTTP endpoints and run the server with uvicorn.
        """
        self.log.info("Starting up METS server")

        workspace = self.workspace

        app = FastAPI(
            title="OCR-D METS Server",
            description="Providing simultaneous write-access to mets.xml for OCR-D",
        )

        # Map errors raised in the endpoints to HTTP 400 responses

        @app.exception_handler(ValidationError)
        async def exception_handler_validation_error(request: Request, exc: ValidationError):
            return JSONResponse(status_code=400, content=exc.errors())

        @app.exception_handler(FileExistsError)
        async def exception_handler_file_exists(request: Request, exc: FileExistsError):
            return JSONResponse(status_code=400, content=str(exc))

        @app.exception_handler(re.error)
        async def exception_handler_invalid_regex(request: Request, exc: re.error):
            return JSONResponse(status_code=400, content=f'invalid regex: {exc}')

        @app.get("/file", response_model=OcrdFileListModel)
        async def find_files(
            file_grp : Optional[str] = None,
            file_id : Optional[str] = None,
            page_id : Optional[str] = None,
            mimetype : Optional[str] = None,
            local_filename : Optional[str] = None,
            url : Optional[str] = None,
        ):
            """
            Find files in the mets
            """
            found = workspace.mets.find_all_files(fileGrp=file_grp, ID=file_id, pageId=page_id, mimetype=mimetype, local_filename=local_filename, url=url)
            return OcrdFileListModel.create(found)

        @app.put('/')
        def save():
            """
            Save the mets
            """
            return workspace.save_mets()

        @app.post('/file', response_model=OcrdFileModel)
        async def add_file(
            file_grp : str = Form(),
            file_id : str = Form(),
            # Default None so the field is really optional: the client does
            # not transmit form fields whose value is None
            page_id : Optional[str] = Form(None),
            mimetype : str = Form(),
            url : Optional[str] = Form(None),
            local_filename : Optional[str] = Form(None),
        ):
            """
            Add a file
            """
            # Validate
            file_resource = OcrdFileModel.create(file_grp=file_grp, file_id=file_id, page_id=page_id, mimetype=mimetype, url=url, local_filename=local_filename)
            # Add to workspace
            kwargs = file_resource.dict()
            workspace.add_file(**kwargs)
            return file_resource

        @app.get('/file_groups', response_model=OcrdFileGroupListModel)
        async def file_groups():
            """
            List the file groups of the mets
            """
            return {'file_groups': workspace.mets.file_groups}

        @app.post('/agent', response_model=OcrdAgentModel)
        async def add_agent(agent : OcrdAgentModel):
            """
            Add an agent
            """
            kwargs = agent.dict()
            # The model field is 'type', the keyword expected by add_agent is '_type'
            kwargs['_type'] = kwargs.pop('type')
            workspace.mets.add_agent(**kwargs)
            return agent

        @app.get('/agent', response_model=OcrdAgentListModel)
        async def agents():
            """
            List the agents of the mets
            """
            return OcrdAgentListModel.create(workspace.mets.agents)

        @app.get('/unique_identifier', response_model=str)
        async def unique_identifier():
            """
            Return the unique identifier of the mets as plain text
            """
            return Response(content=workspace.mets.unique_identifier, media_type='text/plain')

        @app.get('/workspace_path', response_model=str)
        async def workspace_path():
            """
            Return the directory of the workspace as plain text
            """
            return Response(content=workspace.directory, media_type="text/plain")

        @app.post('/reload')
        async def workspace_reload_mets():
            """
            Reload the mets of the workspace
            """
            workspace.reload_mets()
            return Response(content=f'Reloaded from {workspace.directory}', media_type="text/plain")

        @app.delete('/')
        async def stop():
            """
            Stop the server
            """
            getLogger('ocrd.models.ocrd_mets').info('Shutting down')
            workspace.save_mets()
            self.shutdown()

        # ------------- #

        if self.is_uds:
            # Create socket and change to world-readable and -writable to avoid
            # permission errors
            self.log.debug(f"chmod 0o666 {self.url}")
            # Context manager closes the socket even if bind() fails
            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as server:
                server.bind(self.url)  # creates the socket file
            chmod(self.url, 0o666)
            uvicorn_kwargs = {'uds': self.url}
        else:
            parsed = urlparse(self.url)
            uvicorn_kwargs = {'host': parsed.hostname, 'port': parsed.port}

        self.log.debug("Starting uvicorn")
        uvicorn.run(app, **uvicorn_kwargs)
320