1
|
|
|
from asyncio import iscoroutine, get_event_loop |
2
|
|
|
from datetime import datetime |
3
|
|
|
from fastapi import UploadFile |
4
|
|
|
from functools import wraps |
5
|
|
|
from hashlib import md5 |
6
|
|
|
from pika import URLParameters |
7
|
|
|
from pymongo import MongoClient, uri_parser as mongo_uri_parser |
8
|
|
|
from re import compile as re_compile, match as re_match, split as re_split, sub as re_sub |
9
|
|
|
from requests import get as requests_get, Session as Session_TCP |
10
|
|
|
from requests_unixsocket import Session as Session_UDS |
11
|
|
|
from time import sleep |
12
|
|
|
from typing import Dict, List |
13
|
|
|
from uuid import uuid4 |
14
|
|
|
from yaml import safe_load |
15
|
|
|
|
16
|
|
|
from ocrd.resolver import Resolver |
17
|
|
|
from ocrd.workspace import Workspace |
18
|
|
|
from ocrd_utils import generate_range, REGEX_PREFIX |
19
|
|
|
from ocrd_validators import ProcessingServerConfigValidator |
20
|
|
|
from .rabbitmq_utils import OcrdResultMessage, RMQPublisher |
21
|
|
|
|
22
|
|
|
|
23
|
|
|
def call_sync(func):
    """
    Decorator that makes a coroutine function callable from synchronous code.

    The wrapped callable is invoked normally. When its result is a coroutine,
    that coroutine is run to completion on the loop returned by
    ``get_event_loop()`` and its final value is returned; any other result is
    handed back unchanged.

    Based on: https://gist.github.com/phizaz/20c36c6734878c6ec053245a477572ec
    """
    @wraps(func)
    def sync_wrapper(*args, **kwargs):
        outcome = func(*args, **kwargs)
        # Plain (non-coroutine) results need no event loop at all
        if not iscoroutine(outcome):
            return outcome
        loop = get_event_loop()
        return loop.run_until_complete(outcome)

    return sync_wrapper
32
|
|
|
|
33
|
|
|
|
34
|
|
|
def calculate_execution_time(start: datetime, end: datetime) -> int:
    """
    Return the time elapsed between `start` and `end` in milliseconds.

    The difference is computed as `end - start`, converted to milliseconds
    and truncated to an integer (so it is negative when `end` precedes `start`).
    """
    elapsed = end - start
    milliseconds = elapsed.total_seconds() * 1000
    return int(milliseconds)
40
|
|
|
|
41
|
|
|
|
42
|
|
|
def convert_url_to_uds_format(url: str) -> str:
    """
    Turn a socket path into an ``http+unix://`` URL.

    Every ``/`` in `url` is percent-encoded as ``%2F`` and the result is
    prefixed with the ``http+unix://`` scheme.
    """
    escaped_path = url.replace('/', '%2F')
    return "http+unix://" + escaped_path
44
|
|
|
|
45
|
|
|
|
46
|
|
|
def expand_page_ids(page_id: str) -> List:
    """
    Expand a comma-separated page id expression into a list of entries.

    Each comma-separated token is handled as follows:
    - a token starting with `REGEX_PREFIX` contributes a compiled regular
      expression built from the text after the prefix,
    - a token containing ``..`` is split once on ``..`` and both parts are
      passed to `generate_range`, whose items are added to the result,
    - any other token is added verbatim.

    An empty or missing `page_id` yields an empty list.
    """
    expanded = []
    if not page_id:
        return expanded
    for token in page_id.split(','):
        if token.startswith(REGEX_PREFIX):
            regex_source = token[len(REGEX_PREFIX):]
            expanded.append(re_compile(pattern=regex_source))
            continue
        if '..' in token:
            # Only the first '..' separates the range start from the range end
            range_start, _, range_end = token.partition('..')
            expanded.extend(generate_range(range_start, range_end))
            continue
        expanded.append(token)
    return expanded
58
|
|
|
|
59
|
|
|
|
60
|
|
|
def generate_created_time() -> int:
    """
    Return the current time as a Unix timestamp in whole seconds.

    The timestamp is taken from a timezone-aware UTC datetime. A naive
    `datetime.utcnow()` must not be used here: `.timestamp()` interprets naive
    datetimes as local time, which shifts the result by the local UTC offset
    on every host whose timezone is not UTC.
    """
    # Function-scope import: at module level only `datetime` is imported
    from datetime import timezone

    return int(datetime.now(tz=timezone.utc).timestamp())
62
|
|
|
|
63
|
|
|
|
64
|
|
|
def generate_id() -> str:
    """
    Create a new identifier for a processing job.

    The identifier is the string form of a random UUID (version 4).
    Note, workspace_id and workflow_id in the reference
    WebAPI implementation are produced in the same manner
    """
    fresh_uuid = uuid4()
    return str(fresh_uuid)
71
|
|
|
|
72
|
|
|
|
73
|
|
|
async def generate_workflow_content(workflow: UploadFile, encoding: str = "utf-8"):
    """
    Read the uploaded workflow file and return its content as text.

    The complete upload is awaited and the resulting bytes are decoded
    with `encoding`.
    """
    raw_bytes = await workflow.read()
    return raw_bytes.decode(encoding)
75
|
|
|
|
76
|
|
|
|
77
|
|
|
def generate_workflow_content_hash(workflow_content: str, encoding: str = "utf-8"):
    """
    Return the hexadecimal MD5 digest of the workflow content.

    `workflow_content` is encoded with `encoding` before hashing.
    """
    encoded_content = workflow_content.encode(encoding)
    digest = md5(encoded_content)
    return digest.hexdigest()
79
|
|
|
|
80
|
|
|
|
81
|
|
|
def is_url_responsive(url: str, tries: int = 1, wait_time: int = 3) -> bool:
    """
    Probe `url` with HTTP GET requests until one of them answers with 200.

    Args:
        url: the address to request.
        tries: maximum number of attempts; zero or less performs no request.
        wait_time: seconds to sleep between two consecutive attempts.

    Returns:
        True as soon as a request returns status code 200, False once all
        attempts are used up.
    """
    while tries > 0:
        try:
            if requests_get(url).status_code == 200:
                return True
        except Exception:
            # Best-effort probe: a failing request is just an unsuccessful attempt.
            # It must still fall through to the bookkeeping below, otherwise an
            # unreachable URL would be retried forever.
            pass
        tries -= 1
        # No point in waiting once there is no attempt left
        if tries > 0:
            sleep(wait_time)
    return False
91
|
|
|
|
92
|
|
|
|
93
|
|
|
def validate_and_load_config(config_path: str) -> Dict:
    """
    Load the YAML configuration at `config_path` and validate it.

    The file is parsed with `safe_load` and the parsed object is checked by
    `ProcessingServerConfigValidator`. The parsed configuration is returned
    when the validation report is valid; otherwise an `Exception` listing the
    report errors is raised.
    """
    with open(config_path) as config_file:
        loaded_config = safe_load(config_file)
    validation_report = ProcessingServerConfigValidator.validate(loaded_config)
    if validation_report.is_valid:
        return loaded_config
    raise Exception(f'Processing-Server configuration file is invalid:\n{validation_report.errors}')
101
|
|
|
|
102
|
|
|
|
103
|
|
|
def verify_database_uri(mongodb_address: str) -> str:
    """
    Validate the format of a MongoDB URI.

    The address is handed to the pymongo URI parser with validation enabled.
    The unchanged address is returned on success; any parser failure is
    reported as a `ValueError` that embeds the address and the parser error.
    """
    try:
        # The parse result is not needed, only the validation side effect
        mongo_uri_parser.parse_uri(uri=mongodb_address, validate=True)
    except Exception as parse_failure:
        message = f"The MongoDB address '{mongodb_address}' is in wrong format, {parse_failure}"
        raise ValueError(message)
    else:
        return mongodb_address
110
|
|
|
|
111
|
|
|
|
112
|
|
|
def verify_and_parse_mq_uri(rabbitmq_address: str):
    """
    Validate a RabbitMQ address and split it into its connection parts.

    The address must match a generic URI pattern, otherwise a `ValueError`
    is raised. A matching address is parsed by pika's `URLParameters` and
    returned as a dict with the keys `username`, `password`, `host`, `port`
    and `vhost`.

    Check the full list of available parameters in the docs here:
    https://pika.readthedocs.io/en/stable/_modules/pika/connection.html#URLParameters
    """
    uri_pattern = r"^(?:([^:\/?#\s]+):\/{2})?(?:([^@\/?#\s]+)@)?([^\/?#\s]+)?(?:\/([^?#\s]*))?(?:[?]([^#\s]+))?\S*$"
    if re_match(pattern=uri_pattern, string=rabbitmq_address) is None:
        raise ValueError(f"The message queue server address is in wrong format: '{rabbitmq_address}'")

    connection_params = URLParameters(rabbitmq_address)
    credentials = connection_params.credentials
    return {
        'username': credentials.username,
        'password': credentials.password,
        'host': connection_params.host,
        'port': connection_params.port,
        'vhost': connection_params.virtual_host
    }
132
|
|
|
|
133
|
|
|
|
134
|
|
|
def verify_rabbitmq_available(host: str, port: int, vhost: str, username: str, password: str) -> None:
    """
    Wait until a RabbitMQ connection with the given parameters can be opened.

    Up to 15 connection attempts are made with a throw-away `RMQPublisher`,
    sleeping 2 seconds after every failed attempt. The function returns as
    soon as one attempt succeeds and raises a `RuntimeError` when all of
    them fail.
    """
    for _ in range(15):
        try:
            probe_publisher = RMQPublisher(host=host, port=port, vhost=vhost)
            probe_publisher.authenticate_and_connect(username=username, password=password)
        except Exception:
            # Not reachable (yet): give the server some time before retrying
            sleep(2)
            continue
        # TODO: Disconnect the dummy_publisher here before returning...
        return
    raise RuntimeError(f'Cannot connect to RabbitMQ host: {host}, port: {port}, '
                       f'vhost: {vhost}, username: {username}')
158
|
|
|
|
159
|
|
|
|
160
|
|
|
def verify_mongodb_available(mongo_url: str) -> None:
    """
    Check that a MongoDB server can be reached at `mongo_url`.

    A client with a server selection timeout of 5000 ms is created and the
    `ismaster` admin command is sent through it. The client is always closed
    again, whether the check succeeds or not.

    Raises:
        RuntimeError: if creating the client or running the command fails.
            The URL in the message has the span between the first ':' and
            the following '@' replaced by '****' so the password is not
            exposed; the original error is attached as the cause.
    """
    client = None
    try:
        client = MongoClient(mongo_url, serverSelectionTimeoutMS=5000.0)
        client.admin.command("ismaster")
    except Exception as error:
        raise RuntimeError(f'Cannot connect to MongoDB: {re_sub(r":[^@]+@", ":****@", mongo_url)}') from error
    finally:
        # The client is only a probe; do not leave its connections open
        if client is not None:
            client.close()
174
|
|
|
|
175
|
|
|
|
176
|
|
|
def download_ocrd_all_tool_json(ocrd_all_url: str):
    """
    Download the ocrd all tool json document from `ocrd_all_url`.

    A GET request with an `Accept: application/json` header is sent and the
    decoded JSON body of the response is returned.

    Raises:
        ValueError: if `ocrd_all_url` is empty, or if the response status
            code is not 200.
    """
    if not ocrd_all_url:
        raise ValueError('The URL of ocrd all tool json is empty')
    headers = {'Accept': 'application/json'}
    # The context manager closes the session, releasing its pooled connections
    with Session_TCP() as session:
        response = session.get(ocrd_all_url, headers=headers)
    if response.status_code != 200:
        raise ValueError(f"Failed to download ocrd all tool json from: '{ocrd_all_url}'")
    return response.json()
184
|
|
|
|
185
|
|
|
|
186
|
|
|
def post_to_callback_url(logger, callback_url: str, result_message: OcrdResultMessage):
    """
    Send the result of a processing job to a callback URL.

    The `job_id`, `state`, `path_to_mets` and `workspace_id` attributes of
    `result_message` are posted to `callback_url` as a JSON body. The request
    and the received response object are logged on `logger` at info level.
    Nothing is returned and the response status is not evaluated.
    """
    logger.info(f'Posting result message to callback_url "{callback_url}"')
    headers = {"Content-Type": "application/json"}
    json_data = {
        "job_id": result_message.job_id,
        "state": result_message.state,
        "path_to_mets": result_message.path_to_mets,
        "workspace_id": result_message.workspace_id
    }
    # The context manager closes the session, releasing its pooled connections
    with Session_TCP() as session:
        response = session.post(url=callback_url, headers=headers, json=json_data)
    logger.info(f'Response from callback_url "{response}"')
197
|
|
|
|
198
|
|
|
|
199
|
|
|
def get_ocrd_workspace_instance(mets_path: str, mets_server_url: str = None) -> Workspace:
    """
    Create a workspace object for the METS file at `mets_path`.

    When `mets_server_url` is given, the METS server behind it must respond,
    otherwise a `RuntimeError` is raised. The workspace is then resolved from
    `mets_path`, passing the (possibly empty) server URL along.
    """
    server_requested = bool(mets_server_url)
    if server_requested and not is_mets_server_running(mets_server_url=mets_server_url):
        raise RuntimeError(f'The mets server is not running: {mets_server_url}')
    resolver = Resolver()
    return resolver.workspace_from_url(mets_url=mets_path, mets_server_url=mets_server_url)
204
|
|
|
|
205
|
|
|
|
206
|
|
|
def get_ocrd_workspace_physical_pages(mets_path: str, mets_server_url: str = None) -> List[str]:
    """
    Return the physical pages listed in the METS of a workspace.

    The workspace is obtained through `get_ocrd_workspace_instance` with the
    same arguments, and the `physical_pages` of its METS are returned.
    """
    workspace = get_ocrd_workspace_instance(mets_path=mets_path, mets_server_url=mets_server_url)
    return workspace.mets.physical_pages
208
|
|
|
|
209
|
|
|
|
210
|
|
|
def is_mets_server_running(mets_server_url: str) -> bool:
    """
    Check whether a METS server answers at `mets_server_url`.

    URLs starting with ``http://`` or ``https://`` are requested over TCP;
    anything else is treated as a unix domain socket path and converted to
    the ``http+unix://`` form first. The server counts as running when a GET
    request to its ``/workspace_path`` endpoint returns status code 200.

    Returns:
        True on a 200 response, False on any other status or when the
        request raises.
    """
    uses_tcp = mets_server_url.startswith(("http://", "https://"))
    if not uses_tcp:
        mets_server_url = convert_url_to_uds_format(mets_server_url)
    # The context manager closes the session, releasing its pooled connections
    with (Session_TCP() if uses_tcp else Session_UDS()) as session:
        try:
            response = session.get(url=f"{mets_server_url}/workspace_path")
        except Exception:
            # An unreachable server is an expected outcome, not an error
            return False
    return response.status_code == 200
220
|
|
|
|
221
|
|
|
|
222
|
|
|
def stop_mets_server(mets_server_url: str) -> bool:
    """
    Ask the METS server at `mets_server_url` to shut down.

    URLs starting with ``http://`` or ``https://`` are requested over TCP;
    anything else is treated as a unix domain socket path and converted to
    the ``http+unix://`` form first. A DELETE request is sent to the root
    path of the server.

    Returns:
        True when the server answers with status code 200, False on any
        other status or when the request raises.
    """
    uses_tcp = mets_server_url.startswith(("http://", "https://"))
    if not uses_tcp:
        mets_server_url = convert_url_to_uds_format(mets_server_url)
    # The context manager closes the session, releasing its pooled connections
    with (Session_TCP() if uses_tcp else Session_UDS()) as session:
        try:
            response = session.delete(url=f"{mets_server_url}/")
        except Exception:
            # A server that cannot be reached could not be stopped by this call
            return False
    return response.status_code == 200
232
|
|
|
|