Passed
Pull Request — master (#139)
by
unknown
01:27
created

ospd.server.TlsServer.handle_request()   A

Complexity

Conditions 1

Size

Total Lines 7
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 5
nop 3
dl 0
loc 7
rs 10
c 0
b 0
f 0
1
# Copyright (C) 2019 Greenbone Networks GmbH
2
#
3
# SPDX-License-Identifier: GPL-2.0-or-later
4
#
5
# This program is free software; you can redistribute it and/or
6
# modify it under the terms of the GNU General Public License
7
# as published by the Free Software Foundation; either version 2
8
# of the License, or (at your option) any later version.
9
#
10
# This program is distributed in the hope that it will be useful,
11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
# GNU General Public License for more details.
14
#
15
# You should have received a copy of the GNU General Public License
16
# along with this program; if not, write to the Free Software
17
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18
"""
19
Module for serving and streaming data
20
"""
21
22
import logging
23
import socket
24
import ssl
25
import time
26
import os
27
import threading
28
import socketserver
29
30
from abc import ABC, abstractmethod
31
from pathlib import Path
32
from typing import Callable, Optional, Tuple, Union
33
34
from ospd.errors import OspdError
35
36
logger = logging.getLogger(__name__)

DEFAULT_BUFSIZE = 1024


class Stream:
    """ Wrapper around a connected client socket.

    The socket timeout is set once at construction time and applies to all
    subsequent reads and writes on the socket.
    """

    def __init__(self, sock: socket.socket, stream_timeout: int):
        self.socket = sock
        self.socket.settimeout(stream_timeout)

    def close(self):
        """ Close the stream

        The socket is shut down for reading and writing and then closed.
        A failing shutdown (e.g. the peer is already gone) is logged and
        ignored so that the socket is closed in any case and the file
        descriptor is not leaked.
        """
        try:
            self.socket.shutdown(socket.SHUT_RDWR)
        except OSError as e:
            logger.debug("Ignoring error during shutdown of stream. %s", e)
        finally:
            self.socket.close()

    def read(self, bufsize: Optional[int] = DEFAULT_BUFSIZE) -> bytes:
        """ Read at maximum bufsize data from the stream

        Arguments:
            bufsize: Maximum number of bytes to receive. None falls back to
                DEFAULT_BUFSIZE.

        Returns:
            The received bytes. An empty bytes object means the client closed
            the connection.
        """
        if bufsize is None:
            bufsize = DEFAULT_BUFSIZE

        data = self.socket.recv(bufsize)

        if not data:
            logger.debug('Client closed the connection')

        return data

    def write(self, data: bytes):
        """ Send data in chunks of DEFAULT_BUFSIZE to the client

        socket.send() may transmit fewer bytes than it was handed, so the
        offset is advanced by the number of bytes actually sent and the rest
        is retried until all data has been delivered.

        Errors while sending are logged and abort the write.
        """
        total = len(data)
        offset = 0

        while offset < total:
            chunk_end = offset + DEFAULT_BUFSIZE
            try:
                sent = self.socket.send(data[offset:chunk_end])
            except socket.error as e:
                logger.error("Error sending data to the client. %s", e)
                return

            if not sent:
                # No progress is possible; stop instead of looping forever.
                logger.error(
                    "Error sending data to the client. No bytes could be sent"
                )
                return

            offset += sent
80
81
82
# Network endpoint of a TCP server as a (host, port) pair
InetAddress = Tuple[str, int]

# Callable that receives the Stream of a newly connected client
StreamCallbackType = Callable[[Stream], None]
85
86
87
def validate_cacert_file(cacert: str):
    """ Check if provided file is a valid CA Certificate

    Only the first CA certificate loaded from the file is inspected.

    Arguments:
        cacert: Path to the CA certificate file.

    Raises:
        OspdError: If the file can't be read, can't be parsed, contains no
            usable CA certificate, or the certificate is expired or not
            active yet.
    """
    try:
        context = ssl.create_default_context(cafile=cacert)
    except AttributeError:
        # Python version < 2.7.9
        return
    except ssl.SSLError as e:
        # ssl.SSLError is a subclass of OSError (and thus IOError), so it has
        # to be handled first. Otherwise an unparsable file would be reported
        # as "not found".
        raise OspdError('CA Certificate is erroneous') from e
    except IOError as e:
        raise OspdError('CA Certificate not found') from e

    try:
        ca_cert = context.get_ca_certs()[0]
        not_after = ssl.cert_time_to_seconds(ca_cert['notAfter'])
        not_before = ssl.cert_time_to_seconds(ca_cert['notBefore'])
    except (KeyError, IndexError, ValueError) as e:
        raise OspdError('CA Certificate is erroneous') from e

    now = int(time.time())
    if not_after < now:
        raise OspdError('CA Certificate expired')

    if not_before > now:
        raise OspdError('CA Certificate not active yet')
111
112
113
class RequestHandler(socketserver.BaseRequestHandler):
    """ Request handler class used by the socket servers.

    It does no work of its own: every request is forwarded together with the
    client address to handle_request() of the server it belongs to.
    """

    def handle(self):
        owner = self.server
        owner.handle_request(self.request, self.client_address)
118
119
120
class BaseServer(ABC):
    """ Common base class of the servers accepting client connections.

    Subclasses implement start() and handle_request(). ``self.server`` holds
    the underlying socketserver instance once a subclass has created it.
    """

    def __init__(self, stream_timeout: int):
        # Set by the subclass in start(); None until then.
        self.server = None
        # Timeout for the client streams created by the subclasses.
        self.stream_timeout = stream_timeout

    @abstractmethod
    def start(self, stream_callback: StreamCallbackType):
        """ Starts a server with capabilities to handle multiple client
        connections simultaneously.
        If a new client connects, stream_callback is called with a Stream.

        Arguments:
            stream_callback (function): Callback function to be called when
                a stream is ready
        """

    def close(self):
        """ Shutdown the server

        Stops the serving loop and closes the listening socket. If no server
        has been created (start() was never called or failed before creating
        it), there is nothing to shut down and the call does nothing instead
        of failing with an AttributeError on None.
        """
        if self.server is None:
            return

        self.server.shutdown()
        self.server.server_close()

    @abstractmethod
    def handle_request(self, request, client_address):
        """ Handle an incoming client request"""

    def _start_threading_server(self):
        """ Run serve_forever() of self.server in a background daemon thread.
        """
        server_thread = threading.Thread(target=self.server.serve_forever)
        server_thread.daemon = True
        server_thread.start()
149
150
151
class SocketServerMixin:
    """ Glue between a socketserver server class and a BaseServer.

    The listening socket is bound and activated on construction, using
    RequestHandler as the request handler class. Requests are forwarded to
    the wrapped BaseServer.

    NOTE(review): handle_request() shadows socketserver's own no-argument
    handle_request() method with a different signature.
    """

    def __init__(self, server: BaseServer, address: Union[str, InetAddress]):
        self.server = server
        handler_class = RequestHandler
        super().__init__(address, handler_class, bind_and_activate=True)

    def handle_request(self, request, client_address):
        target = self.server
        target.handle_request(request, client_address)
158
159
160
class ThreadedUnixSocketServer(
    SocketServerMixin,
    socketserver.ThreadingMixIn,
    socketserver.UnixStreamServer,
):
    """ Threaded Unix domain stream socket server.

    See socketserver.ThreadingMixIn for the threading behavior and
    SocketServerMixin for how requests are dispatched.
    """
166
167
168
class ThreadedTlsSocketServer(
    SocketServerMixin,
    socketserver.ThreadingMixIn,
    socketserver.TCPServer,
):
    """ Threaded TCP socket server.

    Despite the name, this class performs no TLS itself. It accepts plain
    TCP connections and hands them to the wrapped BaseServer via
    SocketServerMixin.
    """
172
173
174
class UnixSocketServer(BaseServer):
    """ Server for accepting connections via a Unix domain socket
    """

    def __init__(self, socket_path: str, socket_mode: str, stream_timeout: int):
        super().__init__(stream_timeout)
        self.socket_path = Path(socket_path)
        # socket_mode is parsed as an octal string, e.g. '700'
        self.socket_mode = int(socket_mode, 8)

    def _cleanup_socket(self):
        """ Remove the socket file if it exists."""
        if self.socket_path.exists():
            self.socket_path.unlink()

    def _create_parent_dirs(self):
        # create all parent directories for the socket path
        parent = self.socket_path.parent
        parent.mkdir(parents=True, exist_ok=True)

    def start(self, stream_callback: StreamCallbackType):
        """ Bind the Unix socket and serve it from a daemon thread.

        A stale socket file at socket_path is removed, and missing parent
        directories are created before binding.

        Raises:
            OspdError: If the socket can't be bound or its permissions can't
                be set.
        """
        self._cleanup_socket()
        self._create_parent_dirs()

        try:
            self.stream_callback = stream_callback
            self.server = ThreadedUnixSocketServer(self, str(self.socket_path))
            # The socket file only exists after the server constructor has
            # bound it, so its mode can't be applied any earlier. Apply it
            # before the serving thread is started to keep the window with
            # default permissions as small as possible.
            os.chmod(str(self.socket_path), self.socket_mode)
            self._start_threading_server()
        except OSError as e:
            logger.error("Couldn't bind socket on %s", str(self.socket_path))
            raise OspdError(
                "Couldn't bind socket on {}. {}".format(
                    str(self.socket_path), e
                )
            ) from e

    def close(self):
        """ Shutdown the server and remove the socket file."""
        super().close()
        self._cleanup_socket()

    def handle_request(self, request, client_address):
        """ Wrap the client socket in a Stream and pass it to the callback."""
        logger.debug("New connection from %s", str(self.socket_path))

        stream = Stream(request, self.stream_timeout)
        self.stream_callback(stream)
220
221
222
class TlsServer(BaseServer):
    """ Server for accepting TLS encrypted connections via a TCP socket

    Clients are required to present a certificate (ssl.CERT_REQUIRED) that
    verifies against the configured CA file.
    """

    def __init__(
        self,
        address: str,
        port: int,
        cert_file: str,
        key_file: str,
        ca_file: str,
        stream_timeout: int,
    ):
        super().__init__(stream_timeout)
        # (address, port) pair the TCP server binds to
        self.socket = (address, port)

        if not Path(cert_file).exists():
            raise OspdError('cert file {} not found'.format(cert_file))

        if not Path(key_file).exists():
            raise OspdError('key file {} not found'.format(key_file))

        if not Path(ca_file).exists():
            raise OspdError('CA file {} not found'.format(ca_file))

        validate_cacert_file(ca_file)

        # Despite the name, ssl.PROTOCOL_SSLv23 selects the highest
        # protocol version that both the client and server support. In modern
        # Python versions (>= 3.4) it supports TLS >= 1.0 with SSLv2 and SSLv3
        # being disabled. For Python >= 3.6, PROTOCOL_SSLv23 is an alias for
        # PROTOCOL_TLS. PROTOCOL_TLS_SERVER (also Python >= 3.6) is the
        # server-side only variant of it and is preferred, because plain
        # PROTOCOL_TLS is deprecated since Python 3.10.
        if hasattr(ssl, 'PROTOCOL_TLS_SERVER'):
            protocol = ssl.PROTOCOL_TLS_SERVER
        elif hasattr(ssl, 'PROTOCOL_TLS'):
            protocol = ssl.PROTOCOL_TLS
        else:
            protocol = ssl.PROTOCOL_SSLv23

        self.tls_context = ssl.SSLContext(protocol)
        self.tls_context.verify_mode = ssl.CERT_REQUIRED

        self.tls_context.load_cert_chain(cert_file, keyfile=key_file)
        self.tls_context.load_verify_locations(ca_file)

    def start(self, stream_callback: StreamCallbackType):
        """ Bind the TCP socket and serve it from a daemon thread.

        Raises:
            OspdError: If the socket can't be bound.
        """
        try:
            self.stream_callback = stream_callback
            self.server = ThreadedTlsSocketServer(self, self.socket)
            self._start_threading_server()
        except OSError as e:
            logger.error(
                "Couldn't bind socket on %s:%s", self.socket[0], self.socket[1]
            )
            raise OspdError(
                "Couldn't bind socket on {}:{}. {}".format(
                    self.socket[0], str(self.socket[1]), e
                )
            ) from e

    def handle_request(self, request, client_address):
        """ Perform the TLS handshake and pass the stream to the callback.

        Arguments:
            request: The accepted, not yet encrypted client socket.
            client_address: Address of the client, used for logging.
        """
        logger.debug("New connection from %s", client_address)

        # wrap_socket() performs the handshake right away, and the wrapped
        # socket inherits the timeout of the plain one. Set the timeout first
        # so a client that never completes the handshake can't block this
        # handler thread forever.
        request.settimeout(self.stream_timeout)

        try:
            req_socket = self.tls_context.wrap_socket(
                request, server_side=True
            )
        except OSError as e:
            # ssl.SSLError and socket.timeout are both OSError subclasses,
            # e.g. a client without a valid certificate ends up here.
            logger.error(
                "TLS handshake with %s failed. %s", client_address, e
            )
            return

        stream = Stream(req_socket, self.stream_timeout)
        self.stream_callback(stream)
289