1
|
|
|
from asyncio import iscoroutine, get_event_loop |
2
|
|
|
from datetime import datetime |
3
|
|
|
from fastapi import UploadFile |
4
|
|
|
from functools import wraps |
5
|
|
|
from hashlib import md5 |
6
|
|
|
from json import loads |
7
|
|
|
from pathlib import Path |
8
|
|
|
from re import compile as re_compile, split as re_split |
9
|
|
|
from requests import get as requests_get, Session as Session_TCP |
10
|
|
|
from requests_unixsocket import Session as Session_UDS |
11
|
|
|
from time import sleep |
12
|
|
|
from typing import List |
13
|
|
|
from uuid import uuid4 |
14
|
|
|
|
15
|
|
|
from ocrd.resolver import Resolver |
16
|
|
|
from ocrd.workspace import Workspace |
17
|
|
|
from ocrd.mets_server import MpxReq |
18
|
|
|
from ocrd_utils import config, generate_range, REGEX_PREFIX, safe_filename, getLogger, resource_string |
19
|
|
|
from .constants import OCRD_ALL_TOOL_JSON, OCRD_ALL_TOOL_JSON_URL |
20
|
|
|
from .rabbitmq_utils import OcrdResultMessage |
21
|
|
|
|
22
|
|
|
|
23
|
|
|
def call_sync(func):
    """
    Decorator that lets a possibly-async callable be invoked synchronously.

    The wrapped callable is invoked as usual. If what it returns is a
    coroutine object, that coroutine is run to completion on the loop
    returned by get_event_loop() and its result is returned instead.
    Any other return value is handed back unchanged.
    """
    # Based on: https://gist.github.com/phizaz/20c36c6734878c6ec053245a477572ec
    @wraps(func)
    def sync_adapter(*args, **kwargs):
        outcome = func(*args, **kwargs)
        if not iscoroutine(outcome):
            return outcome
        loop = get_event_loop()
        return loop.run_until_complete(outcome)

    return sync_adapter
32
|
|
|
|
33
|
|
|
|
34
|
|
|
def calculate_execution_time(start: datetime, end: datetime) -> int:
    """
    Return the time elapsed from 'start' to 'end' in milliseconds.

    The value is converted with int(), so any fraction of a millisecond
    is dropped.
    """
    elapsed = end - start
    elapsed_ms = elapsed.total_seconds() * 1000
    return int(elapsed_ms)
40
|
|
|
|
41
|
|
|
|
42
|
|
|
def calculate_processing_request_timeout(amount_pages: int, timeout_per_page: float = 20.0) -> float:
    """
    Return the timeout for a processing request.

    The result is 'amount_pages' multiplied by 'timeout_per_page'.
    """
    total_timeout = amount_pages * timeout_per_page
    return total_timeout
44
|
|
|
|
45
|
|
|
|
46
|
|
|
def convert_url_to_uds_format(url: str) -> str:
    """
    Build an 'http+unix://' URL from 'url'.

    Every '/' in 'url' is replaced by '%2F' and the result is appended
    to the 'http+unix://' scheme prefix.
    """
    escaped = "%2F".join(url.split("/"))
    return "http+unix://" + escaped
48
|
|
|
|
49
|
|
|
|
50
|
|
|
def expand_page_ids(page_id: str) -> List:
    """
    Expand a comma-separated page id expression into a list.

    Each comma-separated token is handled as follows:
    - a token starting with REGEX_PREFIX contributes one compiled regular
      expression built from the remainder of the token,
    - a token containing '..' contributes whatever generate_range() returns
      for the parts before and after the first '..',
    - any other token is added unchanged.

    An empty or missing 'page_id' yields an empty list.
    """
    if not page_id:
        return []

    expanded = []
    for token in page_id.split(','):
        if token.startswith(REGEX_PREFIX):
            regex_source = token[len(REGEX_PREFIX):]
            expanded.append(re_compile(pattern=regex_source))
            continue
        if '..' in token:
            range_start, range_end = token.split('..', 1)
            expanded.extend(generate_range(range_start, range_end))
            continue
        expanded.append(token)
    return expanded
62
|
|
|
|
63
|
|
|
|
64
|
|
|
def generate_created_time() -> int:
    """
    Return the current time as a Unix timestamp in whole seconds.

    A timezone-aware UTC datetime is used on purpose: calling timestamp()
    on the naive value returned by datetime.utcnow() makes Python interpret
    that value as local time, which shifts the result by the local UTC
    offset on every host whose timezone is not UTC.
    """
    # Imported locally because only 'datetime' is imported at module level.
    from datetime import timezone

    return int(datetime.now(tz=timezone.utc).timestamp())
66
|
|
|
|
67
|
|
|
|
68
|
|
|
def generate_id() -> str:
    """
    Create a fresh identifier to be used for processing job ids.

    The identifier is the string form of a newly generated uuid4.
    Note, workspace_id and workflow_id in the reference
    WebAPI implementation are produced in the same manner
    """
    new_uuid = uuid4()
    return str(new_uuid)
75
|
|
|
|
76
|
|
|
|
77
|
|
|
async def generate_workflow_content(workflow: UploadFile, encoding: str = "utf-8"):
    """
    Read the uploaded workflow file and return its content as text.

    The bytes obtained from awaiting workflow.read() are decoded
    with 'encoding'.
    """
    raw_content = await workflow.read()
    return raw_content.decode(encoding)
79
|
|
|
|
80
|
|
|
|
81
|
|
|
def generate_workflow_content_hash(workflow_content: str, encoding: str = "utf-8"):
    """
    Return the hexadecimal MD5 digest of 'workflow_content'.

    The text is encoded with 'encoding' before it is hashed.
    """
    content_bytes = workflow_content.encode(encoding)
    digest = md5(content_bytes)
    return digest.hexdigest()
83
|
|
|
|
84
|
|
|
|
85
|
|
|
def is_url_responsive(url: str, tries: int = 1, wait_time: int = 3) -> bool:
    """
    Check whether a GET request to 'url' is answered with HTTP status 200.

    Up to 'tries' attempts are made, pausing 'wait_time' seconds between
    consecutive attempts. An attempt counts as failed when the request
    raises or when the status code is not 200.

    Returns True on the first successful attempt and False once all
    attempts are used up (or when 'tries' is not positive).
    """
    while tries > 0:
        try:
            if requests_get(url).status_code == 200:
                return True
        except Exception:
            # A failing request is just an unsuccessful attempt. It must still
            # use up one try and pause below; jumping straight back to the
            # loop head would retry forever, without any pause, for as long
            # as the url stays unreachable.
            pass
        tries -= 1
        if tries > 0:
            # Only wait when another attempt is going to follow.
            sleep(wait_time)
    return False
95
|
|
|
|
96
|
|
|
|
97
|
|
|
def load_ocrd_all_tool_json(download_if_missing: bool = True):
    """
    Load the ocrd all tool json and return it as a parsed JSON object.

    The resource OCRD_ALL_TOOL_JSON of the 'ocrd' package is tried first.
    If loading or parsing it fails for any reason and 'download_if_missing'
    is set, the json is fetched from OCRD_ALL_TOOL_JSON_URL instead.

    Raises:
        Exception: the packaged resource could not be loaded and
            'download_if_missing' is False.
        ValueError: the download was not answered with HTTP status 200.
    """
    try:
        ocrd_all_tool_json = loads(resource_string('ocrd', OCRD_ALL_TOOL_JSON))
    except Exception as error:
        if not download_if_missing:
            # Same exception type and message as before; the original error
            # is attached explicitly as the cause.
            raise Exception(error) from error
        # The context manager closes the session (and its pooled
        # connections) as soon as the download is done.
        with Session_TCP() as session:
            response = session.get(OCRD_ALL_TOOL_JSON_URL, headers={"Accept": "application/json"})
            if response.status_code != 200:
                raise ValueError(f"Failed to download ocrd all tool json from: '{OCRD_ALL_TOOL_JSON_URL}'")
            ocrd_all_tool_json = response.json()
    return ocrd_all_tool_json
108
|
|
|
|
109
|
|
|
|
110
|
|
|
def post_to_callback_url(logger, callback_url: str, result_message: OcrdResultMessage):
    """
    Send the fields of 'result_message' as a JSON POST request to 'callback_url'.

    The posted JSON object carries the job_id, state, path_to_mets and
    workspace_id attributes of 'result_message'. Both the request and the
    received response are reported through 'logger' at info level.
    Nothing is returned.
    """
    logger.info(f'Posting result message to callback_url "{callback_url}"')
    headers = {"Content-Type": "application/json"}
    json_data = {
        "job_id": result_message.job_id,
        "state": result_message.state,
        "path_to_mets": result_message.path_to_mets,
        "workspace_id": result_message.workspace_id
    }
    # The context manager closes the session once the request has completed
    # instead of leaving its connections open until garbage collection.
    with Session_TCP() as session:
        response = session.post(url=callback_url, headers=headers, json=json_data)
    logger.info(f'Response from callback_url "{response}"')
121
|
|
|
|
122
|
|
|
|
123
|
|
|
def get_ocrd_workspace_instance(mets_path: str, mets_server_url: str = None) -> Workspace:
    """
    Resolve and return the workspace that belongs to 'mets_path'.

    When 'mets_server_url' is given, is_mets_server_running() is consulted
    first, passing the parent directory of 'mets_path' as the workspace
    directory.

    Raises:
        RuntimeError: 'mets_server_url' is given but the mets server is
            reported as not running.
    """
    if mets_server_url:
        workspace_dir = str(Path(mets_path).parent)
        server_alive = is_mets_server_running(mets_server_url=mets_server_url, ws_dir_path=workspace_dir)
        if not server_alive:
            raise RuntimeError(f'The mets server is not running: {mets_server_url}')
    resolver = Resolver()
    return resolver.workspace_from_url(mets_url=mets_path, mets_server_url=mets_server_url)
128
|
|
|
|
129
|
|
|
|
130
|
|
|
def get_ocrd_workspace_physical_pages(mets_path: str, mets_server_url: str = None) -> List[str]:
    """
    Return the 'physical_pages' of the mets of the workspace for 'mets_path'.

    The workspace is obtained through get_ocrd_workspace_instance(), with
    'mets_server_url' passed along unchanged.
    """
    workspace = get_ocrd_workspace_instance(mets_path=mets_path, mets_server_url=mets_server_url)
    return workspace.mets.physical_pages
132
|
|
|
|
133
|
|
|
|
134
|
|
|
def is_mets_server_running(mets_server_url: str, ws_dir_path: str = None) -> bool:
    """
    Check whether the mets server at 'mets_server_url' responds.

    A url starting with 'http://' or 'https://' is contacted with a TCP
    session. Any other url is contacted with a unix socket session after
    being converted with convert_url_to_uds_format().

    If the (possibly converted) url contains 'tcp_mets', a request built by
    MpxReq.workspace_path(ws_dir_path) is POSTed to it and the server counts
    as running when the "text" field of the JSON reply is truthy. Without
    'ws_dir_path' the answer is False and no request is sent.
    Otherwise '<url>/workspace_path' is fetched with GET and the server
    counts as running when the status code is 200.

    Exceptions raised while performing the request are not propagated:
    an OSError during the GET silently yields False, any other exception
    is logged with its traceback and yields False.
    """
    is_tcp = mets_server_url.startswith(("http://", "https://"))
    session = Session_TCP() if is_tcp else Session_UDS()
    if not is_tcp:
        mets_server_url = convert_url_to_uds_format(mets_server_url)
    # 'with' closes the session on every exit path below, so repeated checks
    # do not leave connections open.
    with session:
        try:
            if 'tcp_mets' in mets_server_url:
                if not ws_dir_path:
                    return False
                reply = session.post(url=mets_server_url, json=MpxReq.workspace_path(ws_dir_path))
                return bool(reply.json()["text"])
            try:
                response = session.get(url=f"{mets_server_url}/workspace_path")
                return response.status_code == 200
            except OSError:
                # Treated as the ordinary "server is not reachable" answer,
                # hence no log entry.
                return False
        except Exception:
            getLogger("ocrd_network.utils").exception("Unexpected exception in is_mets_server_running: ")
            return False
157
|
|
|
|
158
|
|
|
|
159
|
|
|
def stop_mets_server(mets_server_url: str, ws_dir_path: str = None) -> bool:
    """
    Ask the mets server at 'mets_server_url' to stop.

    A url starting with 'http://' or 'https://' is contacted with a TCP
    session. Any other url is contacted with a unix socket session after
    being converted with convert_url_to_uds_format().

    If the (possibly converted) url contains 'tcp_mets', a request built by
    MpxReq.stop(ws_dir_path) is POSTed to it; without 'ws_dir_path' the
    answer is False and no request is sent. Otherwise a DELETE request is
    sent to '<url>/'.

    Returns True when the request was answered with status code 200.
    Any exception raised while sending the request yields False.
    """
    is_tcp = mets_server_url.startswith(("http://", "https://"))
    session = Session_TCP() if is_tcp else Session_UDS()
    if not is_tcp:
        mets_server_url = convert_url_to_uds_format(mets_server_url)
    # 'with' closes the session on every exit path below.
    with session:
        try:
            if 'tcp_mets' in mets_server_url:
                if not ws_dir_path:
                    return False
                response = session.post(url=mets_server_url, json=MpxReq.stop(ws_dir_path))
            else:
                response = session.delete(url=f"{mets_server_url}/")
        except Exception:
            # Best effort: a failing stop request is reported as "not stopped".
            return False
    return response.status_code == 200
174
|
|
|
|
175
|
|
|
|
176
|
|
|
def get_uds_path(ws_dir_path: str) -> Path:
    """
    Return the path of the '.sock' file that belongs to 'ws_dir_path'.

    The file name is safe_filename(ws_dir_path) followed by '.sock' and is
    placed below config.OCRD_NETWORK_SOCKETS_ROOT_DIR.
    """
    sockets_root = config.OCRD_NETWORK_SOCKETS_ROOT_DIR
    socket_name = f"{safe_filename(ws_dir_path)}.sock"
    return Path(sockets_root, socket_name)
178
|
|
|
|