Completed
Push — master ( 5762c0...bae074 )
by
unknown
15s queued 11s
created

ospd.server.Stream.read()   A

Complexity

Conditions 2

Size

Total Lines 9
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 5
nop 2
dl 0
loc 9
rs 10
c 0
b 0
f 0
1
# Copyright (C) 2019 Greenbone Networks GmbH
2
#
3
# SPDX-License-Identifier: GPL-2.0-or-later
4
#
5
# This program is free software; you can redistribute it and/or
6
# modify it under the terms of the GNU General Public License
7
# as published by the Free Software Foundation; either version 2
8
# of the License, or (at your option) any later version.
9
#
10
# This program is distributed in the hope that it will be useful,
11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
# GNU General Public License for more details.
14
#
15
# You should have received a copy of the GNU General Public License
16
# along with this program; if not, write to the Free Software
17
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18
"""
19
Module for serving and streaming data
20
"""
21
22
import logging
23
import select
24
import socket
25
import ssl
26
27
from abc import ABC, abstractmethod
28
from pathlib import Path
29
from typing import Callable, Optional
30
31
from ospd.errors import OspdError
32
33
logger = logging.getLogger(__name__)


DEFAULT_STREAM_TIMEOUT = 2  # two seconds
DEFAULT_BUFSIZE = 1024  # maximum number of bytes per recv()/send() call


class Stream:
    """ Thin wrapper around a connected socket used to talk to one client

    The wrapped socket gets DEFAULT_STREAM_TIMEOUT applied as its timeout
    when the stream is created.
    """

    def __init__(self, sock: socket.socket):
        """ Wrap sock and apply the default stream timeout to it

        Arguments:
            sock (socket.socket): Connected socket to read from and write to
        """
        self.socket = sock
        self.socket.settimeout(DEFAULT_STREAM_TIMEOUT)

    def close(self):
        """ Close the stream

        The socket is always closed, even if shutting the connection down
        fails (e.g. because the peer has already disconnected).
        """
        try:
            self.socket.shutdown(socket.SHUT_RDWR)
        except OSError as e:
            # shutdown() fails on a socket that is no longer connected; this
            # must not prevent the descriptor from being released below.
            logger.debug(
                "Ignoring error while shutting down the connection. %s", e
            )

        self.socket.close()

    def read(self, bufsize: Optional[int] = DEFAULT_BUFSIZE) -> bytes:
        """ Read at maximum bufsize data from the stream

        Arguments:
            bufsize (int): Maximum number of bytes to receive in this call

        Returns:
            The received bytes. An empty bytes object means the client closed
            the connection.
        """
        data = self.socket.recv(bufsize)

        if not data:
            logger.debug('Client closed the connection')

        return data

    def write(self, data: bytes):
        """ Send data in chunks of at most DEFAULT_BUFSIZE to the client

        socket.send() may accept fewer bytes than it was handed. The write
        position is therefore advanced by the number of bytes actually sent,
        so that no part of data is skipped. Nothing is sent for empty data.

        Arguments:
            data (bytes): Complete payload to transmit
        """
        total = len(data)
        offset = 0

        while offset < total:
            # send() returns how many bytes were really written; resume
            # right after them on the next iteration.
            offset += self.socket.send(data[offset : offset + DEFAULT_BUFSIZE])


# Signature of the callback that receives the Stream of a newly accepted client
StreamCallbackType = Callable[[Stream], None]
79
80
81
class Server(ABC):
    """ Abstract interface every connection-accepting server has to provide
    """

    @abstractmethod
    def bind(self):
        """ Start listening for incoming connections
        """

    @abstractmethod
    def select(
        self,
        stream_callback: StreamCallbackType,
        timeout: Optional[float] = None,
    ):
        """ Wait for incoming connections or until timeout is reached

        When a new client connects, stream_callback is invoked with a Stream
        for that client.

        Arguments:
            stream_callback (function): Callback function to be called when
                a stream is ready
            timeout (float): Timeout in seconds to wait for new streams
        """
102
103
104
class BaseServer(Server):
    """ Common select/close logic for servers built on a listening socket

    Subclasses are expected to store their listening socket in self.socket
    (it is None until then) and to implement _accept().
    """

    def __init__(self):
        # Listening socket; stays None until a subclass assigns one
        self.socket = None

    @abstractmethod
    def _accept(self) -> Stream:
        """ Accept the pending connection and return it wrapped in a Stream
        """

    def select(
        self,
        stream_callback: StreamCallbackType,
        timeout: Optional[float] = None,
    ):
        """ Wait for a new connection and pass it to stream_callback

        Arguments:
            stream_callback (function): Called with the Stream of the newly
                accepted connection
            timeout (float): Seconds to wait for a connection; None is passed
                through to select.select() unchanged
        """
        inputs = [self.socket]

        readable, _, _ = select.select(inputs, [], inputs, timeout)

        # timeout has fired if readable is empty otherwise a new connection is
        # available
        if readable:
            stream = self._accept()
            stream_callback(stream)

    def close(self):
        """ Shut down and close the listening socket, if there is one

        The socket is closed even when shutting it down raises OSError, so
        the file descriptor is not leaked.
        """
        if self.socket:
            try:
                self.socket.shutdown(socket.SHUT_RDWR)
            except OSError as e:
                # A failing shutdown() must not keep the descriptor open.
                logger.debug(
                    "Ignoring error while shutting down the server socket. %s",
                    e,
                )

            self.socket.close()
131
132
133
class UnixSocketServer(BaseServer):
    """ Server for accepting connections via a Unix domain socket
    """

    def __init__(self, socket_path: str):
        """ Remember the filesystem path the socket will be bound to

        Arguments:
            socket_path (str): Path of the Unix domain socket file
        """
        super().__init__()
        self.socket_path = Path(socket_path)

    def _cleanup_socket(self):
        """ Remove the socket file if it exists
        """
        if self.socket_path.exists():
            self.socket_path.unlink()

    def _create_parent_dirs(self):
        """ Create all missing parent directories of the socket path
        """
        parent = self.socket_path.parent
        parent.mkdir(parents=True, exist_ok=True)

    def bind(self):
        """ Create the Unix domain socket and start listening on it

        An already existing socket file is removed first and missing parent
        directories are created.

        Raises:
            OspdError: If the socket can't be bound to the socket path
        """
        self._cleanup_socket()
        self._create_parent_dirs()

        bindsocket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)

        try:
            bindsocket.bind(str(self.socket_path))
        except socket.error as e:
            # Release the descriptor of the unusable socket and keep the
            # original error attached as the cause.
            bindsocket.close()
            raise OspdError(
                "Couldn't bind socket on {}".format(self.socket_path)
            ) from e

        logger.info(
            'Unix domain socket server listening on %s', self.socket_path
        )

        bindsocket.listen(0)
        bindsocket.setblocking(False)

        self.socket = bindsocket

    def _accept(self) -> Stream:
        """ Accept the pending connection and return it wrapped in a Stream
        """
        new_socket, _addr = self.socket.accept()

        logger.debug("New connection from %s", self.socket_path)

        return Stream(new_socket)

    def close(self):
        """ Close the listening socket and remove the socket file

        The socket file is removed even if closing the socket raises.
        """
        try:
            super().close()
        finally:
            self._cleanup_socket()
183
184
185
class TlsServer(BaseServer):
    """ Server for accepting TLS encrypted connections via a TCP socket
    """

    def __init__(
        self,
        address: str,
        port: int,
        cert_file: str,
        key_file: str,
        ca_file: str,
    ):
        """ Validate the TLS files and prepare the TLS context

        Arguments:
            address (str): Address to bind the TCP socket to
            port (int): Port to bind the TCP socket to
            cert_file (str): Path to the server certificate file
            key_file (str): Path to the private key file of the certificate
            ca_file (str): Path to the CA file used to verify client
                certificates

        Raises:
            OspdError: If the cert, key or CA file does not exist
        """
        super().__init__()
        self.address = address
        self.port = port

        if not Path(cert_file).exists():
            raise OspdError('cert file {} not found'.format(cert_file))

        if not Path(key_file).exists():
            raise OspdError('key file {} not found'.format(key_file))

        if not Path(ca_file).exists():
            raise OspdError('CA file {} not found'.format(ca_file))

        # Despite the name, ssl.PROTOCOL_SSLv23 selects the highest
        # protocol version that both the client and server support. In modern
        # Python versions (>= 3.4) it supports TLS >= 1.0 with SSLv2 and SSLv3
        # being disabled. For Python > 3.5, PROTOCOL_SSLv23 is an alias for
        # PROTOCOL_TLS which should be used once compatibility with Python 3.5
        # is no longer desired.

        if hasattr(ssl, 'PROTOCOL_TLS'):
            protocol = ssl.PROTOCOL_TLS
        else:
            protocol = ssl.PROTOCOL_SSLv23

        self.tls_context = ssl.SSLContext(protocol)
        # Clients have to present a certificate that verifies against ca_file
        self.tls_context.verify_mode = ssl.CERT_REQUIRED

        self.tls_context.load_cert_chain(cert_file, keyfile=key_file)
        self.tls_context.load_verify_locations(ca_file)

    def _accept(self) -> Stream:
        """ Accept the pending TCP connection and wrap it for server side TLS

        Returns:
            A Stream around the TLS wrapped client socket
        """
        new_socket, addr = self.socket.accept()

        logger.debug("New connection from %s:%s", addr[0], addr[1])

        ssl_socket = self.tls_context.wrap_socket(new_socket, server_side=True)

        return Stream(ssl_socket)

    def bind(self):
        """ Create the TCP socket and start listening on address and port

        Raises:
            OspdError: If the socket can't be bound to address and port
        """
        bindsocket = socket.socket()

        try:
            bindsocket.bind((self.address, self.port))
        except socket.error as e:
            # Release the descriptor of the unusable socket. The failure is
            # reported the same way UnixSocketServer.bind() reports it
            # instead of leaving the server without a listening socket.
            bindsocket.close()
            logger.error(
                "Couldn't bind socket on %s:%s", self.address, self.port
            )
            raise OspdError(
                "Couldn't bind socket on {}:{}".format(self.address, self.port)
            ) from e

        logger.info('TLS server listening on %s:%s', self.address, self.port)

        bindsocket.listen(0)
        bindsocket.setblocking(False)

        self.socket = bindsocket
253