Passed
Pull Request — master (#136)
by Juan José
01:17
created

ospd.server.TlsServer.bind()   A

Complexity

Conditions 2

Size

Total Lines 16
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 12
nop 1
dl 0
loc 16
rs 9.8
c 0
b 0
f 0
1
# Copyright (C) 2019 Greenbone Networks GmbH
2
#
3
# SPDX-License-Identifier: GPL-2.0-or-later
4
#
5
# This program is free software; you can redistribute it and/or
6
# modify it under the terms of the GNU General Public License
7
# as published by the Free Software Foundation; either version 2
8
# of the License, or (at your option) any later version.
9
#
10
# This program is distributed in the hope that it will be useful,
11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
# GNU General Public License for more details.
14
#
15
# You should have received a copy of the GNU General Public License
16
# along with this program; if not, write to the Free Software
17
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18
"""
19
Module for serving and streaming data
20
"""
21
22
import logging
23
import socket
24
import ssl
25
import time
26
import os
27
import threading
28
import socketserver
29
30
from abc import ABC, abstractmethod
31
from pathlib import Path
32
from typing import Callable, Optional
33
34
from ospd.errors import OspdError
35
36
logger = logging.getLogger(__name__)

# Chunk size used by Stream.read() (default) and Stream.write().
DEFAULT_BUFSIZE = 1024


class Stream:
    """ Wrapper around a connected client socket.

    Applies a timeout to the socket and offers chunked reads and writes.
    """

    def __init__(self, sock: socket.socket, stream_timeout: int):
        self.socket = sock
        # Blocking operations on the socket raise socket.timeout after
        # stream_timeout seconds.
        self.socket.settimeout(stream_timeout)

    def close(self):
        """ Close the stream

        shutdown() raises OSError (e.g. ENOTCONN) when the peer already
        dropped the connection. That must not keep the descriptor open, so
        the error is only logged and close() always runs.
        """
        try:
            self.socket.shutdown(socket.SHUT_RDWR)
        except OSError as e:
            logger.debug(
                "Ignoring error while shutting down the connection. %s", e
            )
        finally:
            self.socket.close()

    def read(self, bufsize: Optional[int] = DEFAULT_BUFSIZE) -> bytes:
        """ Read at maximum bufsize data from the stream

        Arguments:
            bufsize: Maximum number of bytes to read. None selects
                DEFAULT_BUFSIZE.

        Returns the received bytes; an empty result is logged as the client
        having closed the connection.
        """
        if bufsize is None:
            bufsize = DEFAULT_BUFSIZE

        data = self.socket.recv(bufsize)

        if not data:
            logger.debug('Client closed the connection')

        return data

    def write(self, data: bytes):
        """ Send data in chunks of DEFAULT_BUFSIZE to the client

        socket.send() may transmit fewer bytes than it was given, so the
        offset is advanced by the count it reports and the rest is retried.
        Send errors are logged and abort the write.
        """
        total = len(data)
        b_start = 0

        while b_start < total:
            chunk = data[b_start:b_start + DEFAULT_BUFSIZE]
            try:
                b_sent = self.socket.send(chunk)
            except socket.error as e:
                logger.error("Error sending data to the client. %s", e)
                return

            if not b_sent:
                # Zero bytes accepted for a non-empty chunk means the
                # connection is broken; retrying would loop forever.
                logger.error(
                    "Error sending data to the client. %s",
                    "connection broken",
                )
                return

            b_start += b_sent


# Signature of the callback invoked with each new client Stream.
StreamCallbackType = Callable[[Stream], None]
84
85
def validate_cacert_file(cacert: str):
    """ Check if provided file is a valid CA Certificate

    Only the first CA certificate loaded from the file is checked.

    Arguments:
        cacert: Path to the CA certificate file.

    Raises:
        OspdError: if the file cannot be opened, cannot be parsed, holds no
            usable CA certificate, or the certificate is expired or not yet
            valid.
    """
    try:
        context = ssl.create_default_context(cafile=cacert)
    except ssl.SSLError as e:
        # ssl.SSLError is an OSError subclass, so it must be handled first:
        # the file was readable but its content is not a loadable cert.
        raise OspdError('CA Certificate is erroneous') from e
    except OSError as e:
        raise OspdError('CA Certificate not found') from e

    ca_certs = context.get_ca_certs()
    try:
        first_cert = ca_certs[0]
        not_after = ssl.cert_time_to_seconds(first_cert['notAfter'])
        not_before = ssl.cert_time_to_seconds(first_cert['notBefore'])
    except (KeyError, IndexError, ValueError) as e:
        # IndexError: no CA certificate loaded; KeyError: validity fields
        # missing; ValueError: unparsable validity timestamp.
        raise OspdError('CA Certificate is erroneous') from e

    now = int(time.time())
    if not_after < now:
        raise OspdError('CA Certificate expired')

    if not_before > now:
        raise OspdError('CA Certificate not active yet')
109
110
111
def start_server(stream_callback, stream_timeout, newsocket, tls_ctx=None):
    """ Starts listening and creates a new thread for each new client
    connection.

    Arguments:
            stream_callback (function): Callback function to be called when
                a stream is ready
            stream_timeout (int): Timeout applied to every client socket.
            newsocket (path to socket or socket tuple): The tuple with
                address and port or the path to the socket for unix domain
                sockets.
            tls_ctx (ssl.SSLContext, optional): When given, a TCP server is
                bound to the (address, port) tuple and each client
                connection is wrapped with this context. Otherwise a unix
                domain socket server is bound to the path.

    Returns the created server object. Serving runs in a daemon thread.

    Raises OspdError if the socket could not be bound.
    """

    class ThreadedRequestHandler(socketserver.BaseRequestHandler):
        """ Class to handle the request."""

        def handle(self):
            if tls_ctx:
                logger.debug(
                    "New connection from %s:%s",
                    self.client_address[0],
                    self.client_address[1],
                )
                # Bound the handshake by the stream timeout so a client that
                # connects but never completes it cannot pin this thread.
                self.request.settimeout(stream_timeout)
                try:
                    req_socket = tls_ctx.wrap_socket(
                        self.request, server_side=True
                    )
                except OSError as e:
                    # Covers ssl.SSLError (e.g. missing/invalid client
                    # certificate) and socket.timeout; socketserver closes
                    # the raw request socket after handle() returns.
                    logger.error(
                        "TLS handshake with %s:%s failed. %s",
                        self.client_address[0],
                        self.client_address[1],
                        e,
                    )
                    return
            else:
                req_socket = self.request
                logger.debug("New connection from %s", newsocket)

            stream = Stream(req_socket, stream_timeout)
            stream_callback(stream)

    class ThreadedUnixSockServer(
            socketserver.ThreadingMixIn,
            socketserver.UnixStreamServer,
    ):
        pass

    class ThreadedTlsSockServer(
            socketserver.ThreadingMixIn,
            socketserver.TCPServer,
    ):
        pass

    if tls_ctx:
        try:
            server = ThreadedTlsSockServer(newsocket, ThreadedRequestHandler)
        except OSError as e:
            logger.error(
                "Couldn't bind socket on %s:%s", newsocket[0], newsocket[1]
            )
            raise OspdError(
                "Couldn't bind socket on {}:{}. {}".format(
                    newsocket[0], newsocket[1], e
                )
            ) from e
    else:
        try:
            server = ThreadedUnixSockServer(
                str(newsocket), ThreadedRequestHandler
            )
        except OSError as e:
            logger.error("Couldn't bind socket on %s", str(newsocket))
            raise OspdError(
                "Couldn't bind socket on {}. {}".format(str(newsocket), e)
            ) from e

    server_thread = threading.Thread(target=server.serve_forever)
    server_thread.daemon = True
    server_thread.start()

    return server
178
179
180
class BaseServer(ABC):
    """ Common base for servers that hand each client connection to a
    stream callback.
    """

    def __init__(self, stream_timeout):
        # Server object created by start(); None until start() succeeds.
        self.server = None
        # Timeout applied to client streams.
        self.stream_timeout = stream_timeout

    @abstractmethod
    def start(
        self,
        stream_callback: StreamCallbackType,
    ):
        """ Starts a server with capabilities to handle multiple client
        connections simultaneously.
        If a new client connects the stream_callback is called with a Stream
        Arguments:
            stream_callback (function): Callback function to be called when
                a stream is ready
        """

    def close(self):
        """ Shutdown the server

        Does nothing if the server was never started.
        """
        if self.server is None:
            return

        # shutdown() stops the serve_forever() loop and waits for it to
        # finish; server_close() then releases the listening socket.
        self.server.shutdown()
        self.server.server_close()
202
203
204
class UnixSocketServer(BaseServer):
    """ Server for accepting connections via a Unix domain socket
    """

    def __init__(self, socket_path, socket_mode, stream_timeout: int):
        """
        Arguments:
            socket_path (str or Path): Location of the unix domain socket.
            socket_mode (str or int): Permission bits applied to the socket
                file, either as an octal string such as '0700' or as an
                int such as 0o700.
            stream_timeout (int): Timeout for client streams.
        """
        super().__init__(stream_timeout)
        self.socket_path = Path(socket_path)
        # int(x, 8) raises TypeError for a non-string x, so an already
        # numeric mode is taken as-is.
        if isinstance(socket_mode, int):
            self.socket_mode = socket_mode
        else:
            self.socket_mode = int(socket_mode, 8)

    def _cleanup_socket(self):
        """ Remove an existing file at the socket path, e.g. a socket left
        behind by a previous run, so that binding can succeed.
        """
        if self.socket_path.exists():
            self.socket_path.unlink()

    def _create_parent_dirs(self):
        """ Create all missing parent directories of the socket path. """
        parent = self.socket_path.parent
        parent.mkdir(parents=True, exist_ok=True)

    def start(
            self,
            stream_callback: StreamCallbackType,
    ):
        """ Start serving on the unix domain socket.

        A leftover file at the socket path is removed, missing parent
        directories are created, and once bound the socket file gets the
        configured permission bits.
        """
        self._cleanup_socket()
        self._create_parent_dirs()

        self.server = start_server(
            stream_callback,
            self.stream_timeout,
            self.socket_path
        )

        # NOTE(review): the socket is created with umask-derived
        # permissions and only tightened here, leaving a short window in
        # which it may be reachable with looser permissions.
        if self.socket_path.exists():
            os.chmod(str(self.socket_path), self.socket_mode)

    def close(self):
        """ Shut down the server and remove the socket file. """
        super().close()
        self._cleanup_socket()
241
242
class TlsServer(BaseServer):
    """ Server for accepting TLS encrypted connections via a TCP socket

    Clients must present a certificate that verifies against the
    configured CA file.
    """

    def __init__(
        self,
        address: str,
        port: int,
        cert_file: str,
        key_file: str,
        ca_file: str,
        stream_timeout: int,
    ):
        super().__init__(stream_timeout)
        # (address, port) tuple the TCP server is bound to.
        self.socket = (address, port)

        if not Path(cert_file).exists():
            raise OspdError('cert file {} not found'.format(cert_file))

        if not Path(key_file).exists():
            raise OspdError('key file {} not found'.format(key_file))

        if not Path(ca_file).exists():
            raise OspdError('CA file {} not found'.format(ca_file))

        validate_cacert_file(ca_file)

        # PROTOCOL_TLS_SERVER (Python >= 3.6) negotiates the highest TLS
        # version both sides support, like PROTOCOL_TLS, but is the
        # server-side constant; PROTOCOL_TLS is deprecated since Python 3.10
        # and emits a DeprecationWarning. Older interpreters fall back to
        # PROTOCOL_TLS and then to its alias PROTOCOL_SSLv23 (which, despite
        # the name, has SSLv2/SSLv3 disabled on Python >= 3.4).
        if hasattr(ssl, 'PROTOCOL_TLS_SERVER'):
            protocol = ssl.PROTOCOL_TLS_SERVER
        elif hasattr(ssl, 'PROTOCOL_TLS'):
            protocol = ssl.PROTOCOL_TLS
        else:
            protocol = ssl.PROTOCOL_SSLv23

        self.tls_context = ssl.SSLContext(protocol)
        # Require a client certificate signed by the configured CA.
        self.tls_context.verify_mode = ssl.CERT_REQUIRED

        self.tls_context.load_cert_chain(cert_file, keyfile=key_file)
        self.tls_context.load_verify_locations(ca_file)

    def start(
        self,
        stream_callback: StreamCallbackType,
    ):
        """ Start listening on (address, port); each accepted client
        connection is wrapped with this server's TLS context.
        """
        self.server = start_server(
            stream_callback,
            self.stream_timeout,
            self.socket,
            tls_ctx=self.tls_context
        )
298