1
|
|
|
# Copyright (C) 2014-2021 Greenbone Networks GmbH |
2
|
|
|
# |
3
|
|
|
# SPDX-License-Identifier: AGPL-3.0-or-later |
4
|
|
|
# |
5
|
|
|
# This program is free software: you can redistribute it and/or modify |
6
|
|
|
# it under the terms of the GNU Affero General Public License as |
7
|
|
|
# published by the Free Software Foundation, either version 3 of the |
8
|
|
|
# License, or (at your option) any later version. |
9
|
|
|
# |
10
|
|
|
# This program is distributed in the hope that it will be useful, |
11
|
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of |
12
|
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
13
|
|
|
# GNU Affero General Public License for more details. |
14
|
|
|
# |
15
|
|
|
# You should have received a copy of the GNU Affero General Public License |
16
|
|
|
# along with this program. If not, see <http://www.gnu.org/licenses/>. |
17
|
|
|
|
18
|
|
|
""" |
19
|
|
|
Module for serving and streaming data |
20
|
|
|
""" |
21
|
|
|
|
22
|
|
|
import logging |
23
|
|
|
import socket |
24
|
|
|
import ssl |
25
|
|
|
import time |
26
|
|
|
import threading |
27
|
|
|
import socketserver |
28
|
|
|
|
29
|
|
|
from abc import ABC, abstractmethod |
30
|
|
|
from pathlib import Path |
31
|
|
|
from typing import Callable, Optional, Tuple, Union |
32
|
|
|
|
33
|
|
|
from ospd.errors import OspdError |
34
|
|
|
|
35
|
|
|
logger = logging.getLogger(__name__)

DEFAULT_BUFSIZE = 1024


class Stream:
    """Thin wrapper around a connected (optionally TLS wrapped) client socket.

    The socket's timeout is set to ``stream_timeout`` seconds on construction,
    so blocking reads and writes fail with ``socket.timeout`` after that time.
    """

    def __init__(self, sock: socket.socket, stream_timeout: int):
        self.socket = sock
        self.socket.settimeout(stream_timeout)

    def close(self):
        """Shut down both directions of the connection and close the socket.

        Errors raised by ``shutdown()`` (e.g. the peer already went away) are
        logged at debug level and ignored; the socket is closed in any case.
        """
        try:
            self.socket.shutdown(socket.SHUT_RDWR)
        except OSError as e:
            logger.debug(
                "Ignoring error while shutting down the connection. %s", e
            )

        self.socket.close()

    def read(self, bufsize: Optional[int] = DEFAULT_BUFSIZE) -> bytes:
        """Read at most ``bufsize`` bytes from the stream.

        ``None`` selects DEFAULT_BUFSIZE. An empty result means the client
        closed the connection.
        """
        if bufsize is None:
            bufsize = DEFAULT_BUFSIZE

        data = self.socket.recv(bufsize)

        if not data:
            logger.debug('Client closed the connection')

        return data

    def write(self, data: bytes) -> bool:
        """Send data in chunks of DEFAULT_BUFSIZE to the client.

        Every chunk is passed to ``sendall()``, which retries partial sends, so
        no bytes are skipped even if the socket accepts less than a full chunk.
        Chunking also keeps the socket timeout per chunk instead of applying it
        to the whole payload (``sendall`` treats the timeout as a total limit).

        Returns:
            True if all data was sent, False if a socket error occurred (the
            error is logged).
        """
        view = memoryview(data)  # zero-copy slicing of the payload

        for offset in range(0, len(view), DEFAULT_BUFSIZE):
            try:
                self.socket.sendall(view[offset : offset + DEFAULT_BUFSIZE])
            except (socket.error, BrokenPipeError) as e:
                logger.error("Error sending data to the client. %s", e)
                return False

        return True


# Signature of the callback invoked with a Stream for every new client.
StreamCallbackType = Callable[[Stream], None]

# (host, port) pair for TCP based servers.
InetAddress = Tuple[str, int]
96
|
|
|
|
97
|
|
|
|
98
|
|
|
def validate_cacert_file(cacert: str):
    """Check if provided file is a valid CA Certificate.

    Only the first CA certificate loaded from the file is checked.

    Raises:
        OspdError: if the file cannot be read ("not found"), does not contain
            a usable CA certificate ("erroneous"), or the certificate is
            expired or not valid yet.
    """
    try:
        context = ssl.create_default_context(cafile=cacert)
    except ssl.SSLError:
        # The file was read but holds no loadable certificate. ssl.SSLError
        # is a subclass of OSError, so it must be handled before OSError.
        raise OspdError('CA Certificate is erroneous') from None
    except OSError:
        raise OspdError('CA Certificate not found') from None

    try:
        ca_cert = context.get_ca_certs()[0]
        not_after = ssl.cert_time_to_seconds(ca_cert['notAfter'])
        not_before = ssl.cert_time_to_seconds(ca_cert['notBefore'])
    except (KeyError, IndexError, ValueError):
        # No CA cert loaded, missing validity fields or unparsable dates.
        raise OspdError('CA Certificate is erroneous') from None

    now = int(time.time())
    if not_after < now:
        raise OspdError('CA Certificate expired')

    if not_before > now:
        raise OspdError('CA Certificate not active yet')
122
|
|
|
|
123
|
|
|
|
124
|
|
|
class RequestHandler(socketserver.BaseRequestHandler):
    """Forward every accepted connection to the server that accepted it.

    ``self.server`` is the socketserver instance that created this handler;
    its ``handle_request(request, client_address)`` does the actual work.
    """

    def handle(self):
        forward = self.server.handle_request
        forward(self.request, self.client_address)
129
|
|
|
|
130
|
|
|
|
131
|
|
|
class BaseServer(ABC):
    """Common base for the ospd servers.

    Subclasses create a socketserver instance in start(), store it in
    ``self.server`` and run it via _start_threading_server().
    """

    def __init__(self, stream_timeout: int):
        # The socketserver instance; stays None until start() succeeded.
        self.server = None
        # Timeout (seconds) applied to the socket of every client Stream.
        self.stream_timeout = stream_timeout
        # Callback receiving a Stream per client connection; set by start().
        self.stream_callback = None

    @abstractmethod
    def start(self, stream_callback: "StreamCallbackType"):
        """Starts a server with capabilities to handle multiple client
        connections simultaneously.
        If a new client connects the stream_callback is called with a Stream

        Arguments:
            stream_callback (function): Callback function to be called when
                a stream is ready
        """

    def close(self):
        """Shutdown the server.

        Does nothing if the server was never started (``self.server`` is
        None), so it is safe to call from cleanup code after a failed start().
        """
        if self.server is None:
            return

        self.server.shutdown()
        self.server.server_close()

    @abstractmethod
    def handle_request(self, request, client_address):
        """Handle an incoming client request"""

    def _start_threading_server(self):
        """Run ``self.server.serve_forever`` in a background thread."""
        server_thread = threading.Thread(target=self.server.serve_forever)
        # Daemon thread: the interpreter does not wait for it on exit.
        server_thread.daemon = True
        server_thread.start()
160
|
|
|
|
161
|
|
|
|
162
|
|
|
class SocketServerMixin:
    """Glue between a threading socketserver class and an ospd BaseServer.

    Used as the first base class next to a socketserver server. Connections
    are handled by RequestHandler, which calls handle_request() on this
    object; the call is forwarded to the ospd server given to the constructor.
    """

    # Handler threads run in daemon mode to work around a memory leak
    # (reported at https://bugs.python.org/issue37193).
    #
    # The interpreter kills daemonic threads immediately on exit, without
    # waiting for them to finish.
    #
    # Setting block_on_close = False might work as well: the handler threads
    # are then not tracked in the _threads list, and (being non-daemonic) the
    # interpreter waits for them to finish on exit.
    daemon_threads = True

    def __init__(self, server: BaseServer, address: Union[str, InetAddress]):
        # Remember the ospd server before the socket gets bound and activated.
        self.server = server
        super().__init__(address, RequestHandler, True)

    def handle_request(self, request, client_address):
        # Intentionally shadows socketserver.BaseServer.handle_request();
        # RequestHandler.handle() invokes this with the accepted request.
        owner = self.server
        owner.handle_request(request, client_address)
180
|
|
|
|
181
|
|
|
|
182
|
|
|
class ThreadedUnixSocketServer(
    SocketServerMixin,
    socketserver.ThreadingUnixStreamServer,
):
    """Threaded Unix domain stream socket server for ospd.

    SocketServerMixin is listed first so that its ``__init__`` and
    ``handle_request`` take precedence in the method resolution order.
    """
187
|
|
|
|
188
|
|
|
|
189
|
|
|
class ThreadedTlsSocketServer(
    SocketServerMixin,
    socketserver.ThreadingTCPServer,
):
    """Threaded TCP server for ospd; TLS wrapping happens per request.

    SocketServerMixin is listed first so that its ``__init__`` and
    ``handle_request`` take precedence in the method resolution order.
    """
194
|
|
|
|
195
|
|
|
|
196
|
|
|
class UnixSocketServer(BaseServer):
    """Server for accepting connections via a Unix domain socket"""

    def __init__(self, socket_path: str, socket_mode: str, stream_timeout: int):
        super().__init__(stream_timeout)
        self.socket_path = Path(socket_path)
        # socket_mode is an octal permission string, e.g. "700".
        self.socket_mode = int(socket_mode, 8)
        # Set by start(); called with a Stream for every client.
        self.stream_callback = None

    def _cleanup_socket(self):
        """Remove a left-over socket file at socket_path, if there is one.

        Unlinking directly (instead of checking exists() first) avoids a
        check-then-act race and also removes a dangling symlink, which
        exists() reports as missing but which would still block bind().
        """
        try:
            self.socket_path.unlink()
        except FileNotFoundError:
            pass

    def _create_parent_dirs(self):
        # create all parent directories for the socket path
        parent = self.socket_path.parent
        parent.mkdir(parents=True, exist_ok=True)

    def start(self, stream_callback: "StreamCallbackType"):
        """Bind the Unix socket, start serving and apply socket_mode.

        Raises:
            OspdError: if the socket cannot be bound.
        """
        self._cleanup_socket()
        self._create_parent_dirs()

        self.stream_callback = stream_callback
        try:
            self.server = ThreadedUnixSocketServer(self, str(self.socket_path))
            self._start_threading_server()
        except OSError as e:
            logger.error("Couldn't bind socket on %s", str(self.socket_path))
            raise OspdError(
                "Couldn't bind socket on {}. {}".format(
                    str(self.socket_path), e
                )
            ) from e

        if self.socket_path.exists():
            self.socket_path.chmod(self.socket_mode)

    def close(self):
        """Shut down the server and remove the socket file."""
        super().close()
        self._cleanup_socket()

    def handle_request(self, request, client_address):
        """Wrap the accepted connection in a Stream and pass it on."""
        logger.debug("New request from %s", str(self.socket_path))

        stream = Stream(request, self.stream_timeout)
        self.stream_callback(stream)
241
|
|
|
|
242
|
|
|
|
243
|
|
|
class TlsServer(BaseServer):
    """Server for accepting TLS encrypted connections via a TCP socket.

    Clients have to present a certificate that verifies against the
    configured CA file (``verify_mode`` is ``ssl.CERT_REQUIRED``).
    """

    def __init__(
        self,
        address: str,
        port: int,
        cert_file: str,
        key_file: str,
        ca_file: str,
        stream_timeout: int,
    ):
        """Check the certificate files and prepare the TLS context.

        Raises:
            OspdError: if one of the files is missing or the CA certificate
                fails validate_cacert_file().
        """
        super().__init__(stream_timeout)
        # (address, port) the TCP server binds to. The attribute name is kept
        # for backward compatibility; it does not hold a socket object.
        self.socket = (address, port)

        if not Path(cert_file).exists():
            raise OspdError('cert file {} not found'.format(cert_file))

        if not Path(key_file).exists():
            raise OspdError('key file {} not found'.format(key_file))

        if not Path(ca_file).exists():
            raise OspdError('CA file {} not found'.format(ca_file))

        validate_cacert_file(ca_file)

        # PROTOCOL_TLS_SERVER negotiates the highest version supported by both
        # sides, like the deprecated PROTOCOL_SSLv23 alias, but is limited to
        # server-side sockets, which is all this context is used for.
        self.tls_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        self.tls_context.verify_mode = ssl.CERT_REQUIRED

        self.tls_context.load_cert_chain(cert_file, keyfile=key_file)
        self.tls_context.load_verify_locations(ca_file)

    def start(self, stream_callback: "StreamCallbackType"):
        """Bind the TCP socket and start serving in a background thread.

        Raises:
            OspdError: if the socket cannot be bound.
        """
        self.stream_callback = stream_callback
        try:
            self.server = ThreadedTlsSocketServer(self, self.socket)
            self._start_threading_server()
        except OSError as e:
            logger.error(
                "Couldn't bind socket on %s:%s", self.socket[0], self.socket[1]
            )
            raise OspdError(
                "Couldn't bind socket on {}:{}. {}".format(
                    self.socket[0], str(self.socket[1]), e
                )
            ) from e

    def handle_request(self, request, client_address):
        """Perform the TLS handshake and pass the wrapped socket on."""
        logger.debug("New connection from %s", client_address)

        # Bound the handshake by the stream timeout so a client that connects
        # but never completes the handshake cannot block this handler thread
        # forever. A falsy timeout (0/None) keeps the previous blocking
        # handshake; wrap_socket() rejects non-blocking sockets.
        request.settimeout(self.stream_timeout or None)
        try:
            req_socket = self.tls_context.wrap_socket(
                request, server_side=True
            )
        except OSError as e:
            # ssl.SSLError and socket.timeout are both OSError subclasses.
            # The raw request socket is closed by socketserver afterwards.
            logger.warning(
                "TLS handshake with %s failed. %s", client_address, e
            )
            return

        stream = Stream(req_socket, self.stream_timeout)
        self.stream_callback(stream)
298
|
|
|
|