Completed
Push — master ( 922de2...d10835 )
by
unknown
19s queued 10s
created

ospd.server.TlsServer.__init__()   A

Complexity

Conditions 4

Size

Total Lines 29
Code Lines 22

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 22
nop 7
dl 0
loc 29
rs 9.352
c 0
b 0
f 0
1
# Copyright (C) 2014-2020 Greenbone Networks GmbH
2
#
3
# SPDX-License-Identifier: AGPL-3.0-or-later
4
#
5
# This program is free software: you can redistribute it and/or modify
6
# it under the terms of the GNU Affero General Public License as
7
# published by the Free Software Foundation, either version 3 of the
8
# License, or (at your option) any later version.
9
#
10
# This program is distributed in the hope that it will be useful,
11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
# GNU Affero General Public License for more details.
14
#
15
# You should have received a copy of the GNU Affero General Public License
16
# along with this program. If not, see <http://www.gnu.org/licenses/>.
17
18
"""
19
Module for serving and streaming data
20
"""
21
22
import logging
23
import socket
24
import ssl
25
import time
26
import threading
27
import socketserver
28
29
from abc import ABC, abstractmethod
30
from pathlib import Path
31
from typing import Callable, Optional, Tuple, Union
32
33
from ospd.errors import OspdError
34
35
# Logger for this module; servers and streams below report through it.
logger = logging.getLogger(__name__)

# Size in bytes of a single socket read (default of ``Stream.read``) and of
# each chunk handed to the socket by ``Stream.write``.
DEFAULT_BUFSIZE = 2 ** 10
38
39
40
class Stream:
    """ Wrapper around a connected socket used to exchange data with a
    single client.
    """

    def __init__(self, sock: socket.socket, stream_timeout: int):
        """ Wrap *sock* and apply *stream_timeout* (in seconds, via
        ``settimeout``) to all blocking operations on it.
        """
        self.socket = sock
        self.socket.settimeout(stream_timeout)

    def close(self):
        """ Close the stream

        Both directions are shut down first. An OSError raised by the
        shutdown (e.g. because the peer already went away) is only logged;
        the socket is closed afterwards.
        """
        try:
            self.socket.shutdown(socket.SHUT_RDWR)
        except OSError as e:
            logger.debug(
                "Ignoring error while shutting down the connection. %s", e
            )

        self.socket.close()

    def read(self, bufsize: Optional[int] = DEFAULT_BUFSIZE) -> bytes:
        """ Read at maximum bufsize data from the stream

        Arguments:
            bufsize: Maximum number of bytes to receive. ``None`` falls back
                to DEFAULT_BUFSIZE (the annotation allows it, while
                ``recv(None)`` would raise a TypeError).

        Returns:
            The received bytes. An empty bytes object means the client
            closed the connection.
        """
        if bufsize is None:
            bufsize = DEFAULT_BUFSIZE

        data = self.socket.recv(bufsize)

        if not data:
            logger.debug('Client closed the connection')

        return data

    def write(self, data: bytes) -> bool:
        """ Send data in chunks of DEFAULT_BUFSIZE to the client

        ``socket.send`` may accept fewer bytes than offered, so the write
        offset only advances by the number of bytes actually sent. Nothing is
        skipped and nothing is sent twice.

        Returns:
            True if all data was sent, False if sending failed (the failure
            is logged).
        """
        # memoryview slicing avoids copying every chunk of large payloads.
        view = memoryview(data)
        total = len(view)
        offset = 0

        while offset < total:
            chunk = view[offset : offset + DEFAULT_BUFSIZE]
            try:
                sent = self.socket.send(chunk)
            except (socket.error, BrokenPipeError) as e:
                logger.error("Error sending data to the client. %s", e)
                return False

            if not sent:
                # No progress is possible; bail out instead of looping forever.
                logger.error(
                    "Error sending data to the client. "
                    "Socket accepted no data."
                )
                return False

            offset += sent

        return True
94
95
96
# An (address, port) pair as used for binding TCP socket servers.
InetAddress = Tuple[str, int]

# Signature of the callback a server invokes with the Stream of every newly
# accepted client connection.
StreamCallbackType = Callable[[Stream], None]
99
100
101
def validate_cacert_file(cacert: str):
    """ Check if provided file is a valid CA Certificate

    The first CA certificate loaded from *cacert* must be valid right now:
    its ``notBefore`` must not lie in the future and its ``notAfter`` must
    not lie in the past.

    Arguments:
        cacert: Path to the CA certificate file.

    Raises:
        OspdError: If the file cannot be read, does not contain a usable CA
            certificate, or the certificate is expired / not yet active.
    """
    try:
        context = ssl.create_default_context(cafile=cacert)
    except ssl.SSLError as e:
        # The file was read, but OpenSSL could not load a certificate from
        # it. ssl.SSLError derives from OSError, so it must be handled first.
        raise OspdError('CA Certificate is erroneous') from e
    except OSError as e:
        raise OspdError('CA Certificate not found') from e

    try:
        ca_cert = context.get_ca_certs()[0]
        not_after = ssl.cert_time_to_seconds(ca_cert['notAfter'])
        not_before = ssl.cert_time_to_seconds(ca_cert['notBefore'])
    except (KeyError, IndexError, ValueError) as e:
        # IndexError: no CA certificate loaded; KeyError: missing validity
        # field; ValueError: unparsable validity timestamp.
        raise OspdError('CA Certificate is erroneous') from e

    now = int(time.time())
    if not_after < now:
        raise OspdError('CA Certificate expired')

    if not_before > now:
        raise OspdError('CA Certificate not active yet')
125
126
127
class RequestHandler(socketserver.BaseRequestHandler):
    """ Class to handle the request.

    Each accepted connection is forwarded unchanged, together with the
    client address, to ``handle_request`` of the server that accepted it.
    """

    def handle(self):
        owner = self.server
        connection, peer = self.request, self.client_address
        owner.handle_request(connection, peer)
132
133
134
class BaseServer(ABC):
    """ Common base of the servers accepting client connections.

    Subclasses create a socketserver instance in ``start`` (stored in
    ``self.server``) and turn each accepted connection into a Stream in
    ``handle_request``.
    """

    def __init__(self, stream_timeout: int):
        # socketserver instance; stays None until start() has created it.
        self.server = None
        # Timeout handed to the Stream of every client connection.
        self.stream_timeout = stream_timeout

    @abstractmethod
    def start(self, stream_callback: StreamCallbackType):
        """ Starts a server with capabilities to handle multiple client
        connections simultaneously.
        If a new client connects the stream_callback is called with a Stream

        Arguments:
            stream_callback (function): Callback function to be called when
                a stream is ready
        """

    def close(self):
        """ Shutdown the server

        Does nothing if the server was never started or could not be
        created, instead of failing with an AttributeError on ``None``.
        """
        if self.server is None:
            return

        self.server.shutdown()
        self.server.server_close()

    @abstractmethod
    def handle_request(self, request, client_address):
        """ Handle an incoming client request"""

    def _start_threading_server(self):
        """ Run ``self.server.serve_forever`` in a background daemon thread.

        start() therefore returns immediately, and the thread does not keep
        the interpreter alive at exit.
        """
        server_thread = threading.Thread(target=self.server.serve_forever)
        server_thread.daemon = True
        server_thread.start()
163
164
165
class SocketServerMixin:
    """ Mixin tying a socketserver server class to an ospd BaseServer.

    It has to precede the socketserver class in the bases, so that its
    ``__init__`` (taking the ospd server and the address) runs first and
    hands the address plus RequestHandler on to the socketserver class.
    """

    # Request threads run in daemon mode to circumvent a memory leak
    # (reported at https://bugs.python.org/issue37193).
    #
    # Daemonic threads are killed immediately by the Python interpreter at
    # exit, without waiting until they are finished.
    #
    # NOTE(review): the original note suggested that block_on_close = True
    # could work too, with the interpreter waiting for the threads but not
    # tracking them in the _threads list. True is already the socketserver
    # default, though; block_on_close = False looks like the intended
    # alternative. Confirm before switching.
    daemon_threads = True

    def __init__(self, server: BaseServer, address: Union[str, InetAddress]):
        self.server = server
        super().__init__(address, RequestHandler, bind_and_activate=True)

    def handle_request(self, request, client_address):
        # NOTE: this shadows socketserver.BaseServer.handle_request(), which
        # takes no arguments. serve_forever() does not call that method, and
        # RequestHandler.handle() relies on this two-argument version.
        ospd_server = self.server
        ospd_server.handle_request(request, client_address)
183
184
185
class ThreadedUnixSocketServer(
    SocketServerMixin, socketserver.ThreadingUnixStreamServer
):
    """ Unix domain stream socket server handling every connection in its
    own thread and forwarding it to the owning ospd server.
    """
189
190
191
class ThreadedTlsSocketServer(
    SocketServerMixin, socketserver.ThreadingTCPServer
):
    """ TCP server handling every connection in its own thread.

    TLS is not applied here: the plain accepted socket is forwarded to the
    owning ospd server, whose handle_request (see TlsServer) wraps it.
    """
195
196
197
class UnixSocketServer(BaseServer):
    """ Server for accepting connections via a Unix domain socket
    """

    def __init__(self, socket_path: str, socket_mode: str, stream_timeout: int):
        """
        Arguments:
            socket_path: Filesystem path of the Unix domain socket.
            socket_mode: Permissions applied to the socket file, given as an
                octal string (e.g. "700").
            stream_timeout: Timeout passed to the Stream of every client.
        """
        super().__init__(stream_timeout)
        self.socket_path = Path(socket_path)
        self.socket_mode = int(socket_mode, 8)
        # Set by start(); called with a Stream for every new connection.
        self.stream_callback = None

    def _cleanup_socket(self):
        """ Remove a leftover file at socket_path, if there is one. """
        # EAFP instead of exists() + unlink(): avoids the check/use race and
        # also removes a dangling symlink, for which exists() returns False
        # but which would still make the subsequent bind fail.
        try:
            self.socket_path.unlink()
        except FileNotFoundError:
            pass

    def _create_parent_dirs(self):
        """ Create all missing parent directories of socket_path. """
        self.socket_path.parent.mkdir(parents=True, exist_ok=True)

    def start(self, stream_callback: StreamCallbackType):
        """ Bind the Unix socket and serve connections in a background thread.

        A stale socket file is removed and missing parent directories are
        created first. Once bound, the socket file gets socket_mode applied.

        Raises:
            OspdError: If the socket could not be bound.
        """
        self._cleanup_socket()
        self._create_parent_dirs()

        try:
            self.stream_callback = stream_callback
            self.server = ThreadedUnixSocketServer(self, str(self.socket_path))
            self._start_threading_server()
        except OSError as e:
            logger.error("Couldn't bind socket on %s", str(self.socket_path))
            raise OspdError(
                "Couldn't bind socket on {}. {}".format(
                    str(self.socket_path), e
                )
            ) from e

        if self.socket_path.exists():
            self.socket_path.chmod(self.socket_mode)

    def close(self):
        """ Shutdown the server (if it was started) and remove the socket
        file.
        """
        # Guard against close() without a successful start(): self.server is
        # still None then and shutting it down would raise AttributeError.
        if self.server is not None:
            super().close()
        self._cleanup_socket()

    def handle_request(self, request, client_address):
        """ Wrap the accepted connection in a Stream and pass it to the
        callback given to start().
        """
        logger.debug("New request from %s", str(self.socket_path))

        stream = Stream(request, self.stream_timeout)
        self.stream_callback(stream)
243
244
245
class TlsServer(BaseServer):
    """ Server for accepting TLS encrypted connections via a TCP socket

    Clients have to present a certificate that verifies against the
    configured CA file (``ssl.CERT_REQUIRED``).
    """

    def __init__(
        self,
        address: str,
        port: int,
        cert_file: str,
        key_file: str,
        ca_file: str,
        stream_timeout: int,
    ):
        """
        Arguments:
            address: Address to listen on.
            port: TCP port to listen on.
            cert_file: Path to the server certificate.
            key_file: Path to the private key of the server certificate.
            ca_file: Path to the CA certificate used to verify clients.
            stream_timeout: Timeout for each client stream; it also bounds
                the TLS handshake.

        Raises:
            OspdError: If one of the files does not exist or the CA
                certificate is not valid.
        """
        super().__init__(stream_timeout)
        # (address, port) pair the TCP server binds to.
        self.socket = (address, port)
        # Set by start(); called with a Stream for every new connection.
        self.stream_callback = None

        if not Path(cert_file).exists():
            raise OspdError('cert file {} not found'.format(cert_file))

        if not Path(key_file).exists():
            raise OspdError('key file {} not found'.format(key_file))

        if not Path(ca_file).exists():
            raise OspdError('CA file {} not found'.format(ca_file))

        validate_cacert_file(ca_file)

        # PROTOCOL_TLS_SERVER (Python 3.6+) is the server-side form of the
        # deprecated PROTOCOL_SSLv23 alias, which stays as fallback for older
        # Python versions only.
        protocol = (
            ssl.PROTOCOL_TLS_SERVER
            if hasattr(ssl, 'PROTOCOL_TLS_SERVER')
            else ssl.PROTOCOL_SSLv23
        )
        self.tls_context = ssl.SSLContext(protocol)
        # Never accept SSLv2/SSLv3, independent of the defaults of the
        # running Python version (already set by default since 3.6).
        self.tls_context.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3
        self.tls_context.verify_mode = ssl.CERT_REQUIRED

        self.tls_context.load_cert_chain(cert_file, keyfile=key_file)
        self.tls_context.load_verify_locations(ca_file)

    def start(self, stream_callback: StreamCallbackType):
        """ Bind (address, port) and serve connections in a background thread.

        Raises:
            OspdError: If the socket could not be bound.
        """
        try:
            self.stream_callback = stream_callback
            self.server = ThreadedTlsSocketServer(self, self.socket)
            self._start_threading_server()
        except OSError as e:
            logger.error(
                "Couldn't bind socket on %s:%s", self.socket[0], self.socket[1]
            )
            raise OspdError(
                "Couldn't bind socket on {}:{}. {}".format(
                    self.socket[0], str(self.socket[1]), e
                )
            ) from e

    def handle_request(self, request, client_address):
        """ Perform the TLS handshake on a freshly accepted connection and
        pass the resulting Stream to the callback given to start().

        A failed or timed out handshake (e.g. a client without a valid
        certificate) is logged, and the stream callback is not invoked.
        """
        logger.debug("New connection from %s", client_address)

        # wrap_socket() performs the handshake immediately. Without a timeout
        # on the raw socket, a client that connects but never completes the
        # handshake would block this handler thread forever.
        request.settimeout(self.stream_timeout)

        try:
            req_socket = self.tls_context.wrap_socket(request, server_side=True)
        except OSError as e:  # includes ssl.SSLError and socket.timeout
            logger.warning(
                "TLS handshake with %s failed. %s", client_address, e
            )
            return

        stream = Stream(req_socket, self.stream_timeout)
        self.stream_callback(stream)
301