ospd.server   A
last analyzed

Complexity

Total Complexity 40

Size/Duplication

Total Lines 298
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 176
dl 0
loc 298
rs 9.2
c 0
b 0
f 0
wmc 40

21 Methods

Rating   Name   Duplication   Size   Complexity  
A Stream.__init__() 0 3 1
A TlsServer.handle_request() 0 7 1
A SocketServerMixin.__init__() 0 3 1
A RequestHandler.handle() 0 2 1
A Stream.read() 0 8 2
A BaseServer.close() 0 4 1
A SocketServerMixin.handle_request() 0 2 1
A BaseServer.__init__() 0 3 1
A BaseServer.handle_request() 0 3 1
A Stream.close() 0 10 2
B Stream.write() 0 26 5
A BaseServer.start() 0 3 1
A BaseServer._start_threading_server() 0 4 1
A TlsServer.start() 0 14 2
A UnixSocketServer._cleanup_socket() 0 3 2
A UnixSocketServer.__init__() 0 4 1
A UnixSocketServer.close() 0 3 1
A UnixSocketServer.handle_request() 0 5 1
A UnixSocketServer._create_parent_dirs() 0 4 1
A TlsServer.__init__() 0 29 4
A UnixSocketServer.start() 0 18 3

1 Function

Rating   Name   Duplication   Size   Complexity  
B validate_cacert_file() 0 24 6

How to fix   Complexity   

Complexity

Complex classes like ospd.server often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
# Copyright (C) 2014-2021 Greenbone Networks GmbH
2
#
3
# SPDX-License-Identifier: AGPL-3.0-or-later
4
#
5
# This program is free software: you can redistribute it and/or modify
6
# it under the terms of the GNU Affero General Public License as
7
# published by the Free Software Foundation, either version 3 of the
8
# License, or (at your option) any later version.
9
#
10
# This program is distributed in the hope that it will be useful,
11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
# GNU Affero General Public License for more details.
14
#
15
# You should have received a copy of the GNU Affero General Public License
16
# along with this program. If not, see <http://www.gnu.org/licenses/>.
17
18
"""
19
Module for serving and streaming data
20
"""
21
22
import logging
23
import socket
24
import ssl
25
import time
26
import threading
27
import socketserver
28
29
from abc import ABC, abstractmethod
30
from pathlib import Path
31
from typing import Callable, Optional, Tuple, Union
32
33
from ospd.errors import OspdError
34
35
logger = logging.getLogger(__name__)

DEFAULT_BUFSIZE = 1024


class Stream:
    """Wrapper around a connected client socket.

    Arguments:
        sock: Connected socket used to exchange data with the client.
        stream_timeout: Timeout handed to ``sock.settimeout()``.
    """

    def __init__(self, sock: socket.socket, stream_timeout: int):
        self.socket = sock
        self.socket.settimeout(stream_timeout)

    def close(self):
        """Close the stream

        Errors raised while shutting the connection down are logged at debug
        level and ignored; the socket is closed afterwards in any case.
        """
        try:
            self.socket.shutdown(socket.SHUT_RDWR)
        except OSError as e:
            logger.debug(
                "Ignoring error while shutting down the connection. %s", e
            )

        self.socket.close()

    def read(self, bufsize: Optional[int] = DEFAULT_BUFSIZE) -> bytes:
        """Read at maximum bufsize data from the stream

        Arguments:
            bufsize: Maximum number of bytes to read. ``None`` falls back to
                DEFAULT_BUFSIZE.

        Returns:
            The received bytes. An empty result means the client closed the
            connection.
        """
        if bufsize is None:
            # The annotation allows None, but recv() requires an integer.
            bufsize = DEFAULT_BUFSIZE

        data = self.socket.recv(bufsize)

        if not data:
            logger.debug('Client closed the connection')

        return data

    def write(self, data: bytes) -> bool:
        """Send data in chunks of at most DEFAULT_BUFSIZE to the client

        The position in ``data`` is advanced by the number of bytes the
        socket actually accepted, so partial sends do not lose data.

        Arguments:
            data: The bytes to send. Empty data is a no-op.

        Returns:
            True if all data has been sent, False if sending failed.
        """
        total = len(data)
        offset = 0

        while offset < total:
            chunk = data[offset : offset + DEFAULT_BUFSIZE]

            try:
                sent = self.socket.send(chunk)
            except (socket.error, BrokenPipeError) as e:
                logger.error("Error sending data to the client. %s", e)
                return False

            if not sent:
                # Nothing was accepted for a non-empty chunk. Bail out instead
                # of retrying the same chunk forever.
                logger.error(
                    "Error sending data to the client. %s",
                    "No data could be sent",
                )
                return False

            # send() may accept fewer bytes than requested; only skip those.
            offset += sent

        return True
91
92
93
# Type of the callback the servers invoke for a client connection: it is
# called with the Stream wrapping that connection and returns nothing.
StreamCallbackType = Callable[[Stream], None]
94
95
# A TCP endpoint given as a (host, port) tuple.
InetAddress = Tuple[str, int]
96
97
98
def validate_cacert_file(cacert: str):
    """Check that the given file holds a currently valid CA certificate

    Only the first CA certificate loaded from the file is inspected.

    Arguments:
        cacert: Path to the CA certificate file.

    Raises:
        OspdError: If the file can not be loaded, if the certificate data
            lacks the validity fields, if the certificate is expired or if
            it is not valid yet.
    """
    try:
        ctx = ssl.create_default_context(cafile=cacert)
    except AttributeError:
        # Python version < 2.7.9
        return
    except IOError:
        raise OspdError('CA Certificate not found') from None

    try:
        first_cert = ctx.get_ca_certs()[0]
        expires = ssl.cert_time_to_seconds(first_cert['notAfter'])
        starts = ssl.cert_time_to_seconds(first_cert['notBefore'])
    except (KeyError, IndexError):
        raise OspdError('CA Certificate is erroneous') from None

    current = int(time.time())

    if current > expires:
        raise OspdError('CA Certificate expired')

    if current < starts:
        raise OspdError('CA Certificate not active yet')
122
123
124
class RequestHandler(socketserver.BaseRequestHandler):
    """Request handler delegating each request to its server object."""

    def handle(self):
        """Forward the request and the client address to the server's
        handle_request method."""
        request, peer = self.request, self.client_address
        self.server.handle_request(request, peer)
129
130
131
class BaseServer(ABC):
    """Abstract base class of the servers in this module.

    Arguments:
        stream_timeout: Timeout value stored for the streams created by the
            concrete server.
    """

    def __init__(self, stream_timeout: int):
        # Set by start() of the concrete class; None until then.
        self.server = None
        self.stream_timeout = stream_timeout

    @abstractmethod
    def start(self, stream_callback: "StreamCallbackType"):
        """Starts a server with capabilities to handle multiple client
        connections simultaneously.
        If a new client connects the stream_callback is called with a Stream
        for that connection.

        Arguments:
            stream_callback (function): Callback function to be called when
                a stream is ready
        """

    def close(self):
        """Shutdown the server

        Does nothing if no server has been created, e.g. because start()
        has not been called or did not succeed.
        """
        if self.server is None:
            return

        self.server.shutdown()
        self.server.server_close()

    @abstractmethod
    def handle_request(self, request, client_address):
        """Handle an incoming client request

        Arguments:
            request: The request object of the client connection.
            client_address: The address of the client.
        """

    def _start_threading_server(self):
        """Run the serve_forever loop of self.server in a daemon thread."""
        server_thread = threading.Thread(target=self.server.serve_forever)
        server_thread.daemon = True
        server_thread.start()
160
161
162
class SocketServerMixin:
    """Mixin connecting a socketserver class with a BaseServer.

    It has to be listed before the socketserver class in the bases, because
    its constructor passes the address and the RequestHandler class on via
    super().

    Arguments:
        server: Server object the incoming requests are delegated to.
        address: Address to bind to, either a string or an InetAddress.
    """

    # Run in daemon mode to circumvent a memory leak
    # (reported at https://bugs.python.org/issue37193).
    #
    # The python interpreter kills daemonic threads immediately and doesn't
    # wait until they are finished.
    #
    # block_on_close = True might be an alternative. With that setting the
    # interpreter waits for the threads to finish but doesn't track them in
    # the _threads list.
    daemon_threads = True

    def __init__(self, server: BaseServer, address: Union[str, InetAddress]):
        # Must be stored before the base class initializer runs.
        self.server = server
        super().__init__(address, RequestHandler, bind_and_activate=True)

    def handle_request(self, request, client_address):
        """Delegate the request to the server passed to the constructor."""
        delegate = self.server
        delegate.handle_request(request, client_address)
180
181
182
class ThreadedUnixSocketServer(
    SocketServerMixin,
    socketserver.ThreadingUnixStreamServer,
):
    """Threading Unix stream server combined with the SocketServerMixin."""
187
188
189
class ThreadedTlsSocketServer(
    SocketServerMixin,
    socketserver.ThreadingTCPServer,
):
    """Threading TCP server combined with the SocketServerMixin."""
194
195
196
class UnixSocketServer(BaseServer):
    """Server for accepting connections via a Unix domain socket

    Arguments:
        socket_path: File system path of the Unix domain socket.
        socket_mode: Permissions for the socket file as an octal string.
        stream_timeout: Timeout used for the streams of the client
            connections.
    """

    def __init__(self, socket_path: str, socket_mode: str, stream_timeout: int):
        super().__init__(stream_timeout)
        self.socket_path = Path(socket_path)
        # The mode is given as an octal string and therefore parsed base 8.
        self.socket_mode = int(socket_mode, 8)

    def _cleanup_socket(self):
        """Remove the socket file if there is one."""
        # Unlink directly instead of checking for existence first: this
        # avoids a race between check and removal and also removes a
        # dangling symlink, which an existence check would not report.
        try:
            self.socket_path.unlink()
        except (FileNotFoundError, NotADirectoryError):
            pass

    def _create_parent_dirs(self):
        """Create all parent directories for the socket path."""
        parent = self.socket_path.parent
        parent.mkdir(parents=True, exist_ok=True)

    def start(self, stream_callback: StreamCallbackType):
        """Bind the socket and serve connections in a background thread.

        A left over socket file is removed and missing parent directories
        are created before binding. Afterwards the permissions of the socket
        file are set to the configured mode.

        Arguments:
            stream_callback: Called with a Stream for every handled request.

        Raises:
            OspdError: If the socket could not be bound.
        """
        self._cleanup_socket()
        self._create_parent_dirs()

        try:
            self.stream_callback = stream_callback
            self.server = ThreadedUnixSocketServer(self, str(self.socket_path))
            self._start_threading_server()
        except OSError as e:
            logger.error("Couldn't bind socket on %s", str(self.socket_path))
            raise OspdError(
                "Couldn't bind socket on {}. {}".format(
                    str(self.socket_path), e
                )
            ) from e

        if self.socket_path.exists():
            self.socket_path.chmod(self.socket_mode)

    def close(self):
        """Shutdown the server and remove the socket file."""
        super().close()
        self._cleanup_socket()

    def handle_request(self, request, client_address):
        """Wrap the request in a Stream and pass it to the stream callback."""
        logger.debug("New request from %s", str(self.socket_path))

        stream = Stream(request, self.stream_timeout)
        self.stream_callback(stream)
241
242
243
class TlsServer(BaseServer):
    """Server for accepting TLS encrypted connections via a TCP socket

    Client certificates are required and verified against the given CA
    file.

    Arguments:
        address: Address to bind to.
        port: TCP port to bind to.
        cert_file: Path to the server certificate file.
        key_file: Path to the private key file of the server certificate.
        ca_file: Path to the CA certificate file.
        stream_timeout: Timeout used for the TLS handshake and the streams
            of the client connections.

    Raises:
        OspdError: If one of the files does not exist or the CA certificate
            did not pass validation.
    """

    def __init__(
        self,
        address: str,
        port: int,
        cert_file: str,
        key_file: str,
        ca_file: str,
        stream_timeout: int,
    ):
        super().__init__(stream_timeout)
        self.socket = (address, port)

        if not Path(cert_file).exists():
            raise OspdError('cert file {} not found'.format(cert_file))

        if not Path(key_file).exists():
            raise OspdError('key file {} not found'.format(key_file))

        if not Path(ca_file).exists():
            raise OspdError('CA file {} not found'.format(ca_file))

        validate_cacert_file(ca_file)

        # PROTOCOL_SSLv23 is a deprecated alias. The context is only used
        # with server_side=True, so the server-only protocol constant fits.
        protocol = ssl.PROTOCOL_TLS_SERVER
        self.tls_context = ssl.SSLContext(protocol)
        self.tls_context.verify_mode = ssl.CERT_REQUIRED

        self.tls_context.load_cert_chain(cert_file, keyfile=key_file)
        self.tls_context.load_verify_locations(ca_file)

    def start(self, stream_callback: StreamCallbackType):
        """Bind the TCP socket and serve connections in a background thread.

        Arguments:
            stream_callback: Called with a Stream for every handled request.

        Raises:
            OspdError: If the socket could not be bound.
        """
        try:
            self.stream_callback = stream_callback
            self.server = ThreadedTlsSocketServer(self, self.socket)
            self._start_threading_server()
        except OSError as e:
            logger.error(
                "Couldn't bind socket on %s:%s", self.socket[0], self.socket[1]
            )
            raise OspdError(
                "Couldn't bind socket on {}:{}. {}".format(
                    self.socket[0], str(self.socket[1]), e
                )
            ) from e

    def handle_request(self, request, client_address):
        """Wrap the request in TLS and pass a Stream to the stream callback."""
        logger.debug("New connection from %s", client_address)

        # wrap_socket() runs the TLS handshake. Set the timeout beforehand so
        # a client that connects but never completes the handshake can not
        # block the handling thread without limit.
        request.settimeout(self.stream_timeout)
        req_socket = self.tls_context.wrap_socket(request, server_side=True)

        stream = Stream(req_socket, self.stream_timeout)
        self.stream_callback(stream)
298