1
|
|
|
# Copyright (C) 2014-2020 Greenbone Networks GmbH |
2
|
|
|
# |
3
|
|
|
# SPDX-License-Identifier: AGPL-3.0-or-later |
4
|
|
|
# |
5
|
|
|
# This program is free software: you can redistribute it and/or modify |
6
|
|
|
# it under the terms of the GNU Affero General Public License as |
7
|
|
|
# published by the Free Software Foundation, either version 3 of the |
8
|
|
|
# License, or (at your option) any later version. |
9
|
|
|
# |
10
|
|
|
# This program is distributed in the hope that it will be useful, |
11
|
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of |
12
|
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
13
|
|
|
# GNU Affero General Public License for more details. |
14
|
|
|
# |
15
|
|
|
# You should have received a copy of the GNU Affero General Public License |
16
|
|
|
# along with this program. If not, see <http://www.gnu.org/licenses/>. |
17
|
|
|
|
18
|
|
|
""" |
19
|
|
|
Module for serving and streaming data |
20
|
|
|
""" |
21
|
|
|
|
22
|
|
|
import logging |
23
|
|
|
import socket |
24
|
|
|
import ssl |
25
|
|
|
import time |
26
|
|
|
import threading |
27
|
|
|
import socketserver |
28
|
|
|
|
29
|
|
|
from abc import ABC, abstractmethod |
30
|
|
|
from pathlib import Path |
31
|
|
|
from typing import Callable, Optional, Tuple, Union |
32
|
|
|
|
33
|
|
|
from ospd.errors import OspdError |
34
|
|
|
|
35
|
|
|
logger = logging.getLogger(__name__) |
36
|
|
|
|
37
|
|
|
DEFAULT_BUFSIZE = 1024 |
38
|
|
|
|
39
|
|
|
|
40
|
|
|
class Stream: |
41
|
|
|
def __init__(self, sock: socket.socket, stream_timeout: int): |
42
|
|
|
self.socket = sock |
43
|
|
|
self.socket.settimeout(stream_timeout) |
44
|
|
|
|
45
|
|
|
def close(self): |
46
|
|
|
""" Close the stream |
47
|
|
|
""" |
48
|
|
|
try: |
49
|
|
|
self.socket.shutdown(socket.SHUT_RDWR) |
50
|
|
|
except OSError as e: |
51
|
|
|
logger.debug( |
52
|
|
|
"Ignoring error while shutting down the connection. %s", e |
53
|
|
|
) |
54
|
|
|
|
55
|
|
|
self.socket.close() |
56
|
|
|
|
57
|
|
|
def read(self, bufsize: Optional[int] = DEFAULT_BUFSIZE) -> bytes: |
58
|
|
|
""" Read at maximum bufsize data from the stream |
59
|
|
|
""" |
60
|
|
|
data = self.socket.recv(bufsize) |
61
|
|
|
|
62
|
|
|
if not data: |
63
|
|
|
logger.debug('Client closed the connection') |
64
|
|
|
|
65
|
|
|
return data |
66
|
|
|
|
67
|
|
|
def write(self, data: bytes) -> bool: |
68
|
|
|
""" Send data in chunks of DEFAULT_BUFSIZE to the client |
69
|
|
|
""" |
70
|
|
|
b_start = 0 |
71
|
|
|
b_end = DEFAULT_BUFSIZE |
72
|
|
|
ret_success = True |
73
|
|
|
|
74
|
|
|
while True: |
75
|
|
|
if b_end > len(data): |
76
|
|
|
try: |
77
|
|
|
self.socket.send(data[b_start:]) |
78
|
|
|
except (socket.error, BrokenPipeError) as e: |
79
|
|
|
logger.error("Error sending data to the client. %s", e) |
80
|
|
|
ret_success = False |
81
|
|
|
finally: |
82
|
|
|
return ret_success # pylint: disable=lost-exception |
83
|
|
|
|
84
|
|
|
try: |
85
|
|
|
b_sent = self.socket.send(data[b_start:b_end]) |
86
|
|
|
except (socket.error, BrokenPipeError) as e: |
87
|
|
|
logger.error("Error sending data to the client. %s", e) |
88
|
|
|
return False |
89
|
|
|
|
90
|
|
|
b_start = b_end |
91
|
|
|
b_end += b_sent |
92
|
|
|
|
93
|
|
|
return ret_success |
94
|
|
|
|
95
|
|
|
|
96
|
|
|
StreamCallbackType = Callable[[Stream], None] |
97
|
|
|
|
98
|
|
|
InetAddress = Tuple[str, int] |
99
|
|
|
|
100
|
|
|
|
101
|
|
|
def validate_cacert_file(cacert: str): |
102
|
|
|
""" Check if provided file is a valid CA Certificate """ |
103
|
|
|
try: |
104
|
|
|
context = ssl.create_default_context(cafile=cacert) |
105
|
|
|
except AttributeError: |
106
|
|
|
# Python version < 2.7.9 |
107
|
|
|
return |
108
|
|
|
except IOError: |
109
|
|
|
raise OspdError('CA Certificate not found') |
110
|
|
|
|
111
|
|
|
try: |
112
|
|
|
not_after = context.get_ca_certs()[0]['notAfter'] |
113
|
|
|
not_after = ssl.cert_time_to_seconds(not_after) |
114
|
|
|
not_before = context.get_ca_certs()[0]['notBefore'] |
115
|
|
|
not_before = ssl.cert_time_to_seconds(not_before) |
116
|
|
|
except (KeyError, IndexError): |
117
|
|
|
raise OspdError('CA Certificate is erroneous') |
118
|
|
|
|
119
|
|
|
now = int(time.time()) |
120
|
|
|
if not_after < now: |
121
|
|
|
raise OspdError('CA Certificate expired') |
122
|
|
|
|
123
|
|
|
if not_before > now: |
124
|
|
|
raise OspdError('CA Certificate not active yet') |
125
|
|
|
|
126
|
|
|
|
127
|
|
|
class RequestHandler(socketserver.BaseRequestHandler): |
128
|
|
|
""" Class to handle the request.""" |
129
|
|
|
|
130
|
|
|
def handle(self): |
131
|
|
|
self.server.handle_request(self.request, self.client_address) |
132
|
|
|
|
133
|
|
|
|
134
|
|
|
class BaseServer(ABC): |
135
|
|
|
def __init__(self, stream_timeout: int): |
136
|
|
|
self.server = None |
137
|
|
|
self.stream_timeout = stream_timeout |
138
|
|
|
|
139
|
|
|
@abstractmethod |
140
|
|
|
def start(self, stream_callback: StreamCallbackType): |
141
|
|
|
""" Starts a server with capabilities to handle multiple client |
142
|
|
|
connections simultaneously. |
143
|
|
|
If a new client connects the stream_callback is called with a Stream |
144
|
|
|
|
145
|
|
|
Arguments: |
146
|
|
|
stream_callback (function): Callback function to be called when |
147
|
|
|
a stream is ready |
148
|
|
|
""" |
149
|
|
|
|
150
|
|
|
def close(self): |
151
|
|
|
""" Shutdown the server""" |
152
|
|
|
self.server.shutdown() |
153
|
|
|
self.server.server_close() |
154
|
|
|
|
155
|
|
|
@abstractmethod |
156
|
|
|
def handle_request(self, request, client_address): |
157
|
|
|
""" Handle an incoming client request""" |
158
|
|
|
|
159
|
|
|
def _start_threading_server(self): |
160
|
|
|
server_thread = threading.Thread(target=self.server.serve_forever) |
161
|
|
|
server_thread.daemon = True |
162
|
|
|
server_thread.start() |
163
|
|
|
|
164
|
|
|
|
165
|
|
|
class SocketServerMixin: |
166
|
|
|
# Use daemon mode to circrumvent a memory leak |
167
|
|
|
# (reported at https://bugs.python.org/issue37193). |
168
|
|
|
# |
169
|
|
|
# Daemonic threads are killed immediately by the python interpreter without |
170
|
|
|
# waiting for until they are finished. |
171
|
|
|
# |
172
|
|
|
# Maybe block_on_close = True could work too. |
173
|
|
|
# In that case the interpreter waits for the threads to finish but doesn't |
174
|
|
|
# track them in the _threads list. |
175
|
|
|
daemon_threads = True |
176
|
|
|
|
177
|
|
|
def __init__(self, server: BaseServer, address: Union[str, InetAddress]): |
178
|
|
|
self.server = server |
179
|
|
|
super().__init__(address, RequestHandler, bind_and_activate=True) |
180
|
|
|
|
181
|
|
|
def handle_request(self, request, client_address): |
182
|
|
|
self.server.handle_request(request, client_address) |
183
|
|
|
|
184
|
|
|
|
185
|
|
|
class ThreadedUnixSocketServer( |
186
|
|
|
SocketServerMixin, socketserver.ThreadingUnixStreamServer, |
187
|
|
|
): |
188
|
|
|
pass |
189
|
|
|
|
190
|
|
|
|
191
|
|
|
class ThreadedTlsSocketServer( |
192
|
|
|
SocketServerMixin, socketserver.ThreadingTCPServer, |
193
|
|
|
): |
194
|
|
|
pass |
195
|
|
|
|
196
|
|
|
|
197
|
|
|
class UnixSocketServer(BaseServer): |
198
|
|
|
""" Server for accepting connections via a Unix domain socket |
199
|
|
|
""" |
200
|
|
|
|
201
|
|
|
def __init__(self, socket_path: str, socket_mode: str, stream_timeout: int): |
202
|
|
|
super().__init__(stream_timeout) |
203
|
|
|
self.socket_path = Path(socket_path) |
204
|
|
|
self.socket_mode = int(socket_mode, 8) |
205
|
|
|
|
206
|
|
|
def _cleanup_socket(self): |
207
|
|
|
if self.socket_path.exists(): |
208
|
|
|
self.socket_path.unlink() |
209
|
|
|
|
210
|
|
|
def _create_parent_dirs(self): |
211
|
|
|
# create all parent directories for the socket path |
212
|
|
|
parent = self.socket_path.parent |
213
|
|
|
parent.mkdir(parents=True, exist_ok=True) |
214
|
|
|
|
215
|
|
|
def start(self, stream_callback: StreamCallbackType): |
216
|
|
|
self._cleanup_socket() |
217
|
|
|
self._create_parent_dirs() |
218
|
|
|
|
219
|
|
|
try: |
220
|
|
|
self.stream_callback = stream_callback |
221
|
|
|
self.server = ThreadedUnixSocketServer(self, str(self.socket_path)) |
222
|
|
|
self._start_threading_server() |
223
|
|
|
except OSError as e: |
224
|
|
|
logger.error("Couldn't bind socket on %s", str(self.socket_path)) |
225
|
|
|
raise OspdError( |
226
|
|
|
"Couldn't bind socket on {}. {}".format( |
227
|
|
|
str(self.socket_path), e |
228
|
|
|
) |
229
|
|
|
) |
230
|
|
|
|
231
|
|
|
if self.socket_path.exists(): |
232
|
|
|
self.socket_path.chmod(self.socket_mode) |
233
|
|
|
|
234
|
|
|
def close(self): |
235
|
|
|
super().close() |
236
|
|
|
self._cleanup_socket() |
237
|
|
|
|
238
|
|
|
def handle_request(self, request, client_address): |
239
|
|
|
logger.debug("New request from %s", str(self.socket_path)) |
240
|
|
|
|
241
|
|
|
stream = Stream(request, self.stream_timeout) |
242
|
|
|
self.stream_callback(stream) |
243
|
|
|
|
244
|
|
|
|
245
|
|
|
class TlsServer(BaseServer): |
246
|
|
|
""" Server for accepting TLS encrypted connections via a TCP socket |
247
|
|
|
""" |
248
|
|
|
|
249
|
|
|
def __init__( |
250
|
|
|
self, |
251
|
|
|
address: str, |
252
|
|
|
port: int, |
253
|
|
|
cert_file: str, |
254
|
|
|
key_file: str, |
255
|
|
|
ca_file: str, |
256
|
|
|
stream_timeout: int, |
257
|
|
|
): |
258
|
|
|
super().__init__(stream_timeout) |
259
|
|
|
self.socket = (address, port) |
260
|
|
|
|
261
|
|
|
if not Path(cert_file).exists(): |
262
|
|
|
raise OspdError('cert file {} not found'.format(cert_file)) |
263
|
|
|
|
264
|
|
|
if not Path(key_file).exists(): |
265
|
|
|
raise OspdError('key file {} not found'.format(key_file)) |
266
|
|
|
|
267
|
|
|
if not Path(ca_file).exists(): |
268
|
|
|
raise OspdError('CA file {} not found'.format(ca_file)) |
269
|
|
|
|
270
|
|
|
validate_cacert_file(ca_file) |
271
|
|
|
|
272
|
|
|
protocol = ssl.PROTOCOL_SSLv23 |
273
|
|
|
self.tls_context = ssl.SSLContext(protocol) |
274
|
|
|
self.tls_context.verify_mode = ssl.CERT_REQUIRED |
275
|
|
|
|
276
|
|
|
self.tls_context.load_cert_chain(cert_file, keyfile=key_file) |
277
|
|
|
self.tls_context.load_verify_locations(ca_file) |
278
|
|
|
|
279
|
|
|
def start(self, stream_callback: StreamCallbackType): |
280
|
|
|
try: |
281
|
|
|
self.stream_callback = stream_callback |
282
|
|
|
self.server = ThreadedTlsSocketServer(self, self.socket) |
283
|
|
|
self._start_threading_server() |
284
|
|
|
except OSError as e: |
285
|
|
|
logger.error( |
286
|
|
|
"Couldn't bind socket on %s:%s", self.socket[0], self.socket[1] |
287
|
|
|
) |
288
|
|
|
raise OspdError( |
289
|
|
|
"Couldn't bind socket on {}:{}. {}".format( |
290
|
|
|
self.socket[0], str(self.socket[1]), e |
291
|
|
|
) |
292
|
|
|
) |
293
|
|
|
|
294
|
|
|
def handle_request(self, request, client_address): |
295
|
|
|
logger.debug("New connection from %s", client_address) |
296
|
|
|
|
297
|
|
|
req_socket = self.tls_context.wrap_socket(request, server_side=True) |
298
|
|
|
|
299
|
|
|
stream = Stream(req_socket, self.stream_timeout) |
300
|
|
|
self.stream_callback(stream) |
301
|
|
|
|