Passed
Push — main ( a6ad84...c049b2 )
by Jochen
01:36
created

weitersager.http.RequestHandler.__init__()   A

Complexity

Conditions 2

Size

Total Lines 8
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 4
CRAP Score 2.032

Importance

Changes 0
Metric Value
cc 2
eloc 6
nop 4
dl 0
loc 8
ccs 4
cts 5
cp 0.8
crap 2.032
rs 10
c 0
b 0
f 0
1
"""
2
weitersager.http
3
~~~~~~~~~~~~~~~~
4
5
HTTP server to receive messages
6
7
:Copyright: 2007-2021 Jochen Kupperschmidt
8
:License: MIT, see LICENSE for details.
9
"""
10
11 1
from dataclasses import dataclass
import errno
from functools import partial
from http import HTTPStatus
from http.server import BaseHTTPRequestHandler, HTTPServer
import json
import sys
from typing import Optional, Set

from .config import HttpConfig
from .signals import message_received
from .util import log, start_thread
22
23
24 1
@dataclass(frozen=True)
class Message:
    """An immutable message consisting of a channel name and a text.

    Being frozen, instances cannot be modified after creation and are
    hashable.
    """

    # Name of the channel the text is addressed to.
    channel: str
    # The message text itself.
    text: str
28
29
30 1
def parse_json_message(json_data: str) -> Message:
31
    """Extract message from JSON."""
32 1
    data = json.loads(json_data)
33
34 1
    channel = data['channel']
35 1
    text = data['text']
36
37 1
    return Message(channel=channel, text=text)
38
39
40 1
class RequestHandler(BaseHTTPRequestHandler):
    """Handler for messages submitted via HTTP."""

    def __init__(
        self, *args, api_tokens: Optional[Set[str]] = None, **kwargs
    ) -> None:
        """Set up the handler with the accepted API tokens.

        :param api_tokens: tokens a client may present; if `None` or empty,
            requests are accepted without authorization
        """
        if api_tokens is None:
            api_tokens = set()
        # Must be assigned before delegating: the base class constructor
        # already processes the request (and thus calls `do_POST`).
        self.api_tokens = api_tokens

        super().__init__(*args, **kwargs)

    def do_POST(self) -> None:
        """Accept a JSON message and emit it via `message_received`.

        Responds with 401 if API tokens are configured but the request
        carries none, 403 if the token is not among the configured ones,
        400 if the body is not a valid message, and 202 otherwise.
        """
        valid_api_tokens = self.api_tokens
        if valid_api_tokens:
            api_token = self._get_api_token()
            if not api_token:
                self.send_response(HTTPStatus.UNAUTHORIZED)
                self.end_headers()
                return

            if api_token not in valid_api_tokens:
                self.send_response(HTTPStatus.FORBIDDEN)
                self.end_headers()
                return

        try:
            data = self._read_body()
            message = parse_json_message(data)
        except (KeyError, ValueError):
            log(f'Invalid message received from {self.address_string()}.')
            self.send_error(HTTPStatus.BAD_REQUEST)
            return

        # Respond before emitting the signal so the client's reply does not
        # depend on what the signal receivers do.
        self.send_response(HTTPStatus.ACCEPTED)
        self.end_headers()

        message_received.send(
            channel_name=message.channel,
            text=message.text,
            source_address=self.client_address,
        )

    def _read_body(self) -> str:
        """Read the request body, sized by `Content-Length`, as UTF-8 text.

        A missing `Content-Length` header is treated as an empty body.

        :raises ValueError: if `Content-Length` is not a non-negative
            integer, or if the body is not valid UTF-8
        """
        content_length = int(self.headers.get('Content-Length', 0))
        if content_length < 0:
            # `rfile.read()` with a negative size reads until EOF, which
            # would block the handler until the client closes the
            # connection.
            raise ValueError(f'Invalid Content-Length: {content_length:d}')

        return self.rfile.read(content_length).decode('utf-8')

    def _get_api_token(self) -> Optional[str]:
        """Return the token from an `Authorization: Token <token>` header.

        Return `None` if the header is missing, empty, or does not use the
        `Token` scheme.
        """
        authorization_value = self.headers.get('Authorization')
        if not authorization_value:
            return None

        prefix = 'Token '
        if not authorization_value.startswith(prefix):
            return None

        return authorization_value[len(prefix) :]

    def version_string(self) -> str:
        """Return custom server version string."""
        return 'Weitersager'
98
99
100 1
def create_server(config: HttpConfig) -> HTTPServer:
    """Build an HTTP server listening on the configured host and port.

    Each request is served by a `RequestHandler` that receives the
    configured API tokens.
    """
    handler_factory = partial(RequestHandler, api_tokens=config.api_tokens)
    server_address = (config.host, config.port)
    return HTTPServer(server_address, handler_factory)
105
106
107 1
def start_receive_server(config: HttpConfig) -> None:
    """Create the HTTP server and run it in a separate thread.

    If creating the server fails with an `OSError` (e.g. the port cannot be
    opened), an error is written to stderr and the process exits with
    status 1.
    """
    try:
        server = create_server(config)
    except OSError as e:
        if e.errno is not None:
            sys.stderr.write(f'Error {e.errno:d}: {e.strerror}\n')
        else:
            # `errno` is not set for every `OSError`; formatting `None`
            # with `:d` would raise `TypeError` and hide the real error.
            sys.stderr.write(f'Error: {e}\n')

        # Only suggest another port if this actually looks like a
        # permission problem (not, say, the address already being in use).
        if e.errno in (errno.EACCES, errno.EPERM):
            sys.stderr.write(
                f'Probably no permission to open port {config.port}. '
                'Try to specify a port number above 1,024 (or even '
                '4,096) and up to 65,535.\n'
            )
        sys.exit(1)

    thread_name = server.__class__.__name__
    start_thread(server.serve_forever, thread_name)
    log('Listening for HTTP requests on {}:{:d}.', *server.server_address)
123