|
1
|
|
|
# Copyright (C) 2014-2020 Greenbone Networks GmbH |
|
2
|
|
|
# |
|
3
|
|
|
# SPDX-License-Identifier: AGPL-3.0-or-later |
|
4
|
|
|
# |
|
5
|
|
|
# This program is free software: you can redistribute it and/or modify |
|
6
|
|
|
# it under the terms of the GNU Affero General Public License as |
|
7
|
|
|
# published by the Free Software Foundation, either version 3 of the |
|
8
|
|
|
# License, or (at your option) any later version. |
|
9
|
|
|
# |
|
10
|
|
|
# This program is distributed in the hope that it will be useful, |
|
11
|
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of |
|
12
|
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
|
13
|
|
|
# GNU Affero General Public License for more details. |
|
14
|
|
|
# |
|
15
|
|
|
# You should have received a copy of the GNU Affero General Public License |
|
16
|
|
|
# along with this program. If not, see <http://www.gnu.org/licenses/>. |
|
17
|
|
|
|
|
18
|
|
|
""" |
|
19
|
|
|
Module for serving and streaming data |
|
20
|
|
|
""" |
|
21
|
|
|
|
|
22
|
|
|
import logging |
|
23
|
|
|
import socket |
|
24
|
|
|
import ssl |
|
25
|
|
|
import time |
|
26
|
|
|
import threading |
|
27
|
|
|
import socketserver |
|
28
|
|
|
|
|
29
|
|
|
from abc import ABC, abstractmethod |
|
30
|
|
|
from pathlib import Path |
|
31
|
|
|
from typing import Callable, Optional, Tuple, Union |
|
32
|
|
|
|
|
33
|
|
|
from ospd.errors import OspdError |
|
34
|
|
|
|
|
35
|
|
|
# Module-wide logger.
logger = logging.getLogger(__name__)

# Default size in bytes for a single Stream.read() and for each chunk
# written by Stream.write().
DEFAULT_BUFSIZE = 1024
|
38
|
|
|
|
|
39
|
|
|
|
|
40
|
|
|
class Stream:
    """Thin wrapper around a connected client socket.

    The socket timeout is set once on construction, so every blocking
    ``recv``/``send`` performed through the stream is bounded by
    ``stream_timeout`` seconds.
    """

    def __init__(self, sock: socket.socket, stream_timeout: int):
        self.socket = sock
        self.socket.settimeout(stream_timeout)

    def close(self):
        """Shut down both directions of the connection and close the socket.

        Errors raised by ``shutdown`` (for example, when the peer has already
        disconnected) are logged at debug level and ignored. The socket is
        closed regardless.
        """
        try:
            self.socket.shutdown(socket.SHUT_RDWR)
        except OSError as e:
            logger.debug(
                "Ignoring error while shutting down the connection. %s", e
            )

        self.socket.close()

    def read(self, bufsize: Optional[int] = DEFAULT_BUFSIZE) -> bytes:
        """Read at most ``bufsize`` bytes from the stream.

        Arguments:
            bufsize: Maximum number of bytes to receive. ``None`` falls back
                to ``DEFAULT_BUFSIZE``, as the ``Optional`` annotation
                suggests.

        Returns:
            The received bytes. An empty result means the client closed the
            connection.
        """
        if bufsize is None:
            bufsize = DEFAULT_BUFSIZE

        data = self.socket.recv(bufsize)

        if not data:
            logger.debug('Client closed the connection')

        return data

    def write(self, data: bytes) -> bool:
        """Send ``data`` to the client in chunks of at most DEFAULT_BUFSIZE.

        ``socket.send`` may transmit fewer bytes than it was given. The write
        offset therefore advances by the number of bytes actually sent, so a
        partial send never skips any data.

        Arguments:
            data: The bytes to send.

        Returns:
            True if all of ``data`` was sent, False if sending failed.
        """
        total = len(data)
        offset = 0

        while offset < total:
            chunk = data[offset : offset + DEFAULT_BUFSIZE]
            try:
                sent = self.socket.send(chunk)
            except (socket.error, BrokenPipeError) as e:
                logger.error("Error sending data to the client. %s", e)
                return False

            if sent == 0:
                # A send that transmits nothing means the connection is
                # broken; stop instead of looping forever.
                logger.error(
                    "Error sending data to the client. %s",
                    "connection broken",
                )
                return False

            offset += sent

        return True
|
94
|
|
|
|
|
95
|
|
|
|
|
96
|
|
|
# Address of a TCP endpoint as a (host, port) tuple.
InetAddress = Tuple[str, int]

# Callback invoked with a Stream for every new client connection; the
# servers in this module ignore its return value.
StreamCallbackType = Callable[[Stream], None]
|
99
|
|
|
|
|
100
|
|
|
|
|
101
|
|
|
def validate_cacert_file(cacert: str):
    """Check if the provided file holds a CA certificate that is valid now.

    Only the first certificate returned by ``SSLContext.get_ca_certs()`` is
    inspected.

    Arguments:
        cacert: Path to the CA certificate file.

    Raises:
        OspdError: If the file cannot be read, its content is not a usable
            CA certificate, or the certificate is expired or not yet valid.
    """
    try:
        context = ssl.create_default_context(cafile=cacert)
    except AttributeError:
        # Python version < 2.7.9 (no ssl.create_default_context)
        return
    except ssl.SSLError as e:
        # The file was read but its content could not be parsed as a
        # certificate. SSLError subclasses OSError, so this branch must come
        # before the generic OSError one below.
        raise OspdError('CA Certificate is erroneous') from e
    except OSError as e:
        raise OspdError('CA Certificate not found') from e

    try:
        ca_cert = context.get_ca_certs()[0]
        not_after = ssl.cert_time_to_seconds(ca_cert['notAfter'])
        not_before = ssl.cert_time_to_seconds(ca_cert['notBefore'])
    except (KeyError, IndexError, ValueError) as e:
        # IndexError: no CA certificate loaded; KeyError: missing validity
        # fields; ValueError: unparseable validity timestamp.
        raise OspdError('CA Certificate is erroneous') from e

    now = int(time.time())
    if not_after < now:
        raise OspdError('CA Certificate expired')

    if not_before > now:
        raise OspdError('CA Certificate not active yet')
|
125
|
|
|
|
|
126
|
|
|
|
|
127
|
|
|
class RequestHandler(socketserver.BaseRequestHandler):
    """Forward every accepted connection to the owning server object."""

    def handle(self):
        # self.server is the socketserver instance that accepted this
        # connection; it decides what to do with the raw request socket.
        owner = self.server
        owner.handle_request(self.request, self.client_address)
|
132
|
|
|
|
|
133
|
|
|
|
|
134
|
|
|
class BaseServer(ABC):
    """Common base for servers that hand client connections to a callback.

    Subclasses create a ``socketserver`` instance in :meth:`start`, store it
    in ``self.server`` and implement :meth:`handle_request`.
    """

    def __init__(self, stream_timeout: int):
        # Underlying socketserver instance; stays None until start() sets it.
        self.server = None
        # Timeout in seconds applied to each client stream.
        self.stream_timeout = stream_timeout

    @abstractmethod
    def start(self, stream_callback: StreamCallbackType):
        """ Starts a server with capabilities to handle multiple client
        connections simultaneously.
        If a new client connects the stream_callback is called with a Stream

        Arguments:
            stream_callback (function): Callback function to be called when
                a stream is ready
        """

    def close(self):
        """Shut down the server.

        Does nothing if the server was never created (``self.server`` is
        ``None``), for example because :meth:`start` failed to bind.
        """
        if self.server is None:
            return

        self.server.shutdown()
        self.server.server_close()

    @abstractmethod
    def handle_request(self, request, client_address):
        """ Handle an incoming client request"""

    def _start_threading_server(self):
        """Run ``self.server.serve_forever`` in a daemon thread."""
        server_thread = threading.Thread(
            target=self.server.serve_forever, daemon=True
        )
        server_thread.start()
|
163
|
|
|
|
|
164
|
|
|
|
|
165
|
|
|
class SocketServerMixin:
    """Mixin that ties a ``socketserver`` server to an ospd BaseServer.

    It is listed before the concrete ``socketserver`` class in the bases of
    the subclasses below, so its ``__init__`` and ``handle_request`` take
    precedence in the MRO.
    """

    # Use daemon mode to circumvent a memory leak
    # (reported at https://bugs.python.org/issue37193).
    #
    # Daemonic threads are killed immediately by the python interpreter
    # without waiting until they are finished.
    #
    # Maybe block_on_close = True could work too.
    # In that case the interpreter waits for the threads to finish but
    # doesn't track them in the _threads list.
    daemon_threads = True

    def __init__(self, server: BaseServer, address: Union[str, InetAddress]):
        # Keep the ospd server object; RequestHandler.handle() reaches it
        # through handle_request() below.
        self.server = server
        super().__init__(address, RequestHandler, bind_and_activate=True)

    def handle_request(self, request, client_address):
        # NOTE: this shadows socketserver.BaseServer.handle_request(), which
        # takes no arguments, with a different signature.
        ospd_server = self.server
        ospd_server.handle_request(request, client_address)
|
183
|
|
|
|
|
184
|
|
|
|
|
185
|
|
|
class ThreadedUnixSocketServer(
    SocketServerMixin, socketserver.ThreadingUnixStreamServer,
):
    """Unix domain stream server that handles each connection in a thread.

    Construction and request dispatching come from SocketServerMixin.
    """
|
189
|
|
|
|
|
190
|
|
|
|
|
191
|
|
|
class ThreadedTlsSocketServer(
    SocketServerMixin, socketserver.ThreadingTCPServer,
):
    """TCP server that handles each connection in a thread.

    The server itself speaks plain TCP; TLS wrapping of accepted sockets is
    done by TlsServer.handle_request().
    """
|
195
|
|
|
|
|
196
|
|
|
|
|
197
|
|
|
class UnixSocketServer(BaseServer):
    """ Server for accepting connections via a Unix domain socket
    """

    def __init__(self, socket_path: str, socket_mode: str, stream_timeout: int):
        super().__init__(stream_timeout)
        self.socket_path = Path(socket_path)
        # socket_mode is an octal permission string, e.g. "700".
        self.socket_mode = int(socket_mode, 8)
        # Set by start(); called with a Stream for every new connection.
        self.stream_callback = None

    def _cleanup_socket(self):
        """Remove a leftover socket file at ``socket_path``, if any.

        Unlinking directly (instead of checking ``exists()`` first) avoids a
        check-then-act race and also removes a dangling symlink, which
        ``exists()`` reports as missing.
        """
        try:
            self.socket_path.unlink()
        except FileNotFoundError:
            pass

    def _create_parent_dirs(self):
        """Create all parent directories of the socket path."""
        parent = self.socket_path.parent
        parent.mkdir(parents=True, exist_ok=True)

    def start(self, stream_callback: StreamCallbackType):
        """Bind the Unix socket and serve connections in a background thread.

        The socket file permissions are set to ``socket_mode`` before the
        serving thread starts. This narrows the window in which the socket
        exists with default permissions.

        Raises:
            OspdError: If the socket cannot be bound.
        """
        self._cleanup_socket()
        self._create_parent_dirs()

        self.stream_callback = stream_callback
        try:
            self.server = ThreadedUnixSocketServer(self, str(self.socket_path))
        except OSError as e:
            logger.error("Couldn't bind socket on %s", str(self.socket_path))
            raise OspdError(
                "Couldn't bind socket on {}. {}".format(
                    str(self.socket_path), e
                )
            ) from e

        if self.socket_path.exists():
            self.socket_path.chmod(self.socket_mode)

        self._start_threading_server()

    def close(self):
        """Shut down the server and remove the socket file.

        The socket file is removed even if shutting down the server raises.
        """
        try:
            super().close()
        finally:
            self._cleanup_socket()

    def handle_request(self, request, client_address):
        """Wrap the accepted connection in a Stream and pass it on."""
        logger.debug("New request from %s", str(self.socket_path))

        stream = Stream(request, self.stream_timeout)
        self.stream_callback(stream)
|
243
|
|
|
|
|
244
|
|
|
|
|
245
|
|
|
class TlsServer(BaseServer):
    """ Server for accepting TLS encrypted connections via a TCP socket
    """

    def __init__(
        self,
        address: str,
        port: int,
        cert_file: str,
        key_file: str,
        ca_file: str,
        stream_timeout: int,
    ):
        super().__init__(stream_timeout)
        # (host, port) tuple the TCP server binds to.
        self.socket = (address, port)
        # Set by start(); called with a Stream for every new connection.
        self.stream_callback = None

        if not Path(cert_file).exists():
            raise OspdError('cert file {} not found'.format(cert_file))

        if not Path(key_file).exists():
            raise OspdError('key file {} not found'.format(key_file))

        if not Path(ca_file).exists():
            raise OspdError('CA file {} not found'.format(ca_file))

        validate_cacert_file(ca_file)

        # PROTOCOL_SSLv23 is the legacy alias of PROTOCOL_TLS: negotiate the
        # highest protocol version supported by both sides.
        protocol = ssl.PROTOCOL_SSLv23
        self.tls_context = ssl.SSLContext(protocol)
        # Clients must present a certificate that verifies against ca_file.
        self.tls_context.verify_mode = ssl.CERT_REQUIRED

        self.tls_context.load_cert_chain(cert_file, keyfile=key_file)
        self.tls_context.load_verify_locations(ca_file)

    def start(self, stream_callback: StreamCallbackType):
        """Bind the TCP socket and serve connections in a background thread.

        Raises:
            OspdError: If the socket cannot be bound.
        """
        try:
            self.stream_callback = stream_callback
            self.server = ThreadedTlsSocketServer(self, self.socket)
            self._start_threading_server()
        except OSError as e:
            logger.error(
                "Couldn't bind socket on %s:%s", self.socket[0], self.socket[1]
            )
            raise OspdError(
                "Couldn't bind socket on {}:{}. {}".format(
                    self.socket[0], str(self.socket[1]), e
                )
            ) from e

    def handle_request(self, request, client_address):
        """Perform the TLS handshake and pass the resulting Stream on.

        A failed or timed-out handshake is logged and the connection is
        dropped. Socketserver closes the raw request socket afterwards.
        """
        logger.debug("New connection from %s", client_address)

        # wrap_socket() performs the handshake immediately, and the wrapped
        # socket inherits the raw socket's timeout. Setting the timeout first
        # stops a client that never finishes the handshake from blocking this
        # thread forever. A timeout of 0 would make the socket non-blocking,
        # which wrap_socket() rejects, so it is only applied when truthy.
        if self.stream_timeout:
            request.settimeout(self.stream_timeout)

        try:
            req_socket = self.tls_context.wrap_socket(request, server_side=True)
        except OSError as e:
            # ssl.SSLError and socket.timeout are both OSError subclasses.
            logger.error(
                "TLS handshake with %s failed. %s", client_address, e
            )
            return

        stream = Stream(req_socket, self.stream_timeout)
        self.stream_callback(stream)
|
301
|
|
|
|