Passed
Push — main ( ab694e...8e7205 )
by Jochen
02:05
created

weitersager.http.start_receive_server()   A

Complexity

Conditions 2

Size

Total Lines 16
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 1
CRAP Score 4.916

Importance

Changes 0
Metric Value
cc 2
eloc 11
nop 1
dl 0
loc 16
ccs 1
cts 10
cp 0.1
crap 4.916
rs 9.85
c 0
b 0
f 0
1
"""
2
weitersager.http
3
~~~~~~~~~~~~~~~~
4
5
HTTP server to receive messages
6
7
:Copyright: 2007-2022 Jochen Kupperschmidt
8
:License: MIT, see LICENSE for details.
9
"""
10
11 1
from __future__ import annotations
12 1
from dataclasses import dataclass
13 1
from functools import partial
14 1
from http import HTTPStatus
15 1
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
16 1
import json
17 1
import logging
18 1
import sys
19 1
from typing import Optional
20
21 1
from .config import HttpConfig
22 1
from .signals import message_received
23 1
from .util import start_thread
24
25
26 1
# Logger named after this module (``weitersager.http``).
logger = logging.getLogger(__name__)
27
28
29 1
@dataclass(frozen=True)
class Message:
    """A text message together with the name of its target channel.

    Instances are immutable (the dataclass is frozen).
    """

    # Name of the channel the message is meant for.
    channel: str
    # The message content.
    text: str
33
34
35 1
def parse_json_message(json_data: str) -> Message:
    """Extract message from JSON.

    The JSON document must be an object with string values for the keys
    ``channel`` and ``text``.

    :param json_data: the JSON-encoded message
    :return: the parsed message
    :raises ValueError: if the data is not valid JSON, is not a JSON object,
        or if ``channel`` or ``text`` is not a string
    :raises KeyError: if ``channel`` or ``text`` is missing
    """
    data = json.loads(json_data)

    # Anything but an object (e.g. a list, a string, or `null`) would make
    # the key lookups below raise `TypeError`, which the HTTP handler does
    # not treat as a bad request.
    if not isinstance(data, dict):
        raise ValueError('JSON data must be an object.')

    channel = data['channel']
    text = data['text']

    if not isinstance(channel, str) or not isinstance(text, str):
        raise ValueError('Channel and text must be strings.')

    return Message(channel=channel, text=text)
43
44
45 1
class RequestHandler(BaseHTTPRequestHandler):
    """Handler for messages submitted via HTTP.

    Accepts JSON-encoded messages via ``POST /`` and announces them through
    the ``message_received`` signal.
    """

    def __init__(self, api_tokens: set[str], *args, **kwargs) -> None:
        """Create a handler.

        :param api_tokens: tokens accepted in the ``Authorization`` header
            (``Token <value>``); if empty, no authentication is required
        """
        # Must be assigned before calling the base initializer, because the
        # stdlib request handler processes the request from within
        # `__init__`.
        self.api_tokens = api_tokens
        super().__init__(*args, **kwargs)

    def do_POST(self) -> None:
        """Handle a message submission.

        Responds with 404 for any path other than ``/``, 401 if a token is
        required but missing, 403 if the token is not accepted, 400 if the
        body is not a valid message, and 202 otherwise.
        """
        if self.path != '/':
            self.send_error(HTTPStatus.NOT_FOUND)
            return

        if self.api_tokens:
            api_token = self._get_api_token()
            if not api_token:
                self.send_error(HTTPStatus.UNAUTHORIZED)
                return

            if api_token not in self.api_tokens:
                self.send_error(HTTPStatus.FORBIDDEN)
                return

        try:
            content_length = self._read_content_length()
            # `UnicodeDecodeError` is a subclass of `ValueError`, so invalid
            # UTF-8 is reported as a bad request as well.
            data = self.rfile.read(content_length).decode('utf-8')
            message = parse_json_message(data)
        except (KeyError, ValueError):
            logger.info(
                'Invalid message received from %s.', self.address_string()
            )
            self.send_error(HTTPStatus.BAD_REQUEST)
            return

        self.send_response(HTTPStatus.ACCEPTED)
        self.end_headers()

        message_received.send(
            channel_name=message.channel,
            text=message.text,
            source_address=self.client_address,
        )

    def _read_content_length(self) -> int:
        """Return the declared length of the request body.

        A missing ``Content-Length`` header counts as zero.

        :raises ValueError: if the header is not an integer or is negative
        """
        content_length = int(self.headers.get('Content-Length', 0))
        # A negative length would turn `rfile.read()` into "read until the
        # client closes the connection".
        if content_length < 0:
            raise ValueError(f'Invalid Content-Length: {content_length}')
        return content_length

    def _get_api_token(self) -> Optional[str]:
        """Return the token from a ``Token <value>`` authorization header.

        Returns `None` if the header is missing or uses another scheme.
        """
        authorization_value = self.headers.get('Authorization')
        if not authorization_value:
            return None

        prefix = 'Token '
        if not authorization_value.startswith(prefix):
            return None

        return authorization_value[len(prefix) :]

    def version_string(self) -> str:
        """Return custom server version string."""
        return 'Weitersager'
101
102
103 1
def create_server(config: HttpConfig) -> ThreadingHTTPServer:
    """Build a threading HTTP server bound to the configured host and port.

    Each request is handled by a `RequestHandler` that receives the API
    tokens from `config` as its first argument.
    """
    handler_factory = partial(RequestHandler, config.api_tokens)
    return ThreadingHTTPServer((config.host, config.port), handler_factory)
108
109
110 1
def _describe_server_error(e: OSError, port: int) -> str:
    """Return a human-readable explanation of a server creation failure.

    :param e: the error raised while creating the server
    :param port: the port the server was supposed to listen on
    """
    import errno

    if e.errno is None:
        # Not every `OSError` carries an errno; formatting `None` with `:d`
        # would raise a `TypeError` and hide the actual error.
        description = f'Error: {e}\n'
    else:
        description = f'Error {e.errno:d}: {e.strerror}\n'

    # Only suggest a permission problem if the error actually is one (e.g.
    # not for "address already in use").
    if e.errno in (errno.EACCES, errno.EPERM):
        description += (
            f'Probably no permission to open port {port}. '
            'Try to specify a port number above 1,024 (or even '
            '4,096) and up to 65,535.\n'
        )

    return description


def start_receive_server(config: HttpConfig) -> None:
    """Create the HTTP server and start serving in a separate thread.

    If the server cannot be created, an explanation is written to stderr
    and `sys.exit(1)` is called.
    """
    try:
        server = create_server(config)
    except OSError as e:
        sys.stderr.write(_describe_server_error(e, config.port))
        sys.exit(1)

    thread_name = server.__class__.__name__
    start_thread(server.serve_forever, thread_name)
    logger.info('Listening for HTTP requests on %s:%d.', *server.server_address)
126