Passed
Push — main ( 37807b...392e9b )
by Jochen
04:24
created

weitersager.http   A

Complexity

Total Complexity 14

Size/Duplication

Total Lines 124
Duplicated Lines 0 %

Test Coverage

Coverage 85.92%

Importance

Changes 0
Metric Value
wmc 14
eloc 78
dl 0
loc 124
ccs 61
cts 71
cp 0.8592
rs 10
c 0
b 0
f 0

4 Methods

Rating   Name   Duplication   Size   Complexity  
A RequestHandler.version_string() 0 3 1
A RequestHandler._get_api_token() 0 10 3
A RequestHandler.__init__() 0 3 1
B RequestHandler.do_POST() 0 31 5

3 Functions

Rating   Name   Duplication   Size   Complexity  
A parse_json_message() 0 8 1
A create_server() 0 5 1
A start_receive_server() 0 16 2
1
"""
2
weitersager.http
3
~~~~~~~~~~~~~~~~
4
5
HTTP server to receive messages
6
7
:Copyright: 2007-2021 Jochen Kupperschmidt
8
:License: MIT, see LICENSE for details.
9
"""
10
11 1
from __future__ import annotations
12 1
from dataclasses import dataclass
13 1
from functools import partial
14 1
from http import HTTPStatus
15 1
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
16 1
import json
17 1
import logging
18 1
import sys
19 1
from typing import Optional
20
21 1
from .config import HttpConfig
22 1
from .signals import message_received
23 1
from .util import start_thread
24
25
26 1
# Logger named after this module; used to report invalid requests and the
# listening address.
logger = logging.getLogger(__name__)
27
28
29 1
@dataclass(frozen=True)
class Message:
    """A message received over HTTP: the target channel and its text."""

    # Name of the channel the text is meant for.
    channel: str
    # The message text itself.
    text: str


def parse_json_message(json_data: str) -> Message:
    """Extract a message from JSON.

    The document must be a JSON object with string values for the
    ``channel`` and ``text`` keys; any other keys are ignored.

    :raises ValueError: if ``json_data`` is not valid JSON, is not a JSON
        object, or the value of ``channel``/``text`` is not a string
    :raises KeyError: if ``channel`` or ``text`` is missing
    """
    data = json.loads(json_data)

    # Valid JSON that is not an object (array, string, number, `null`)
    # would otherwise fail with `TypeError` on subscription, which is not
    # among the errors callers handle.
    if not isinstance(data, dict):
        raise ValueError('JSON message must be an object.')

    channel = data['channel']
    text = data['text']

    for key, value in (('channel', channel), ('text', text)):
        if not isinstance(value, str):
            raise ValueError(f'Value for "{key}" must be a string.')

    return Message(channel=channel, text=text)
43
44
45 1
class RequestHandler(BaseHTTPRequestHandler):
    """Handler for messages submitted via HTTP.

    Accepts a JSON message via ``POST``. If API tokens are configured, the
    request must carry one of them in an ``Authorization: Token <token>``
    header.
    """

    def __init__(self, api_tokens: set[str], *args, **kwargs) -> None:
        # Assigned before calling the base initializer, because that one
        # processes the request right away (and thus calls `do_POST`).
        self.api_tokens = api_tokens
        super().__init__(*args, **kwargs)

    def do_POST(self) -> None:
        """Authorize the request, parse its message, and publish it."""
        if not self._is_authorized():
            return

        message = self._read_message()
        if message is None:
            self.send_error(HTTPStatus.BAD_REQUEST)
            return

        self.send_response(HTTPStatus.ACCEPTED)
        self.end_headers()

        message_received.send(
            channel_name=message.channel,
            text=message.text,
            source_address=self.client_address,
        )

    def _is_authorized(self) -> bool:
        """Check the request's API token if any tokens are configured.

        On failure, send a 401 (no token) or 403 (unknown token) response
        and return `False`.
        """
        if not self.api_tokens:
            # No tokens configured: authorization is disabled.
            return True

        api_token = self._get_api_token()
        if not api_token:
            self.send_response(HTTPStatus.UNAUTHORIZED)
            # RFC 7235 requires a 401 response to name the expected scheme.
            self.send_header('WWW-Authenticate', 'Token')
            self.end_headers()
            return False

        if api_token not in self.api_tokens:
            self.send_response(HTTPStatus.FORBIDDEN)
            self.end_headers()
            return False

        return True

    def _read_message(self) -> Optional[Message]:
        """Read and parse the request body; return `None` if it is invalid."""
        try:
            content_length = int(self.headers.get('Content-Length', 0))
            if content_length < 0:
                # `rfile.read()` with a negative size reads until the client
                # closes the connection, which would stall this handler.
                raise ValueError('Negative Content-Length')
            data = self.rfile.read(content_length).decode('utf-8')
            return parse_json_message(data)
        except (KeyError, TypeError, ValueError):
            # `TypeError` covers JSON that is valid but not an object.
            logger.info(
                'Invalid message received from %s.', self.address_string()
            )
            return None

    def _get_api_token(self) -> Optional[str]:
        """Return the token from an ``Authorization: Token <token>`` header.

        The scheme name is matched case-insensitively (RFC 7235, section
        2.1). Return `None` if the header is missing, lacks a token part,
        or uses a different scheme.
        """
        authorization_value = self.headers.get('Authorization')
        if not authorization_value:
            return None

        scheme, separator, token = authorization_value.partition(' ')
        if not separator or scheme.lower() != 'token':
            return None

        return token

    def version_string(self) -> str:
        """Return custom server version string."""
        return 'Weitersager'
99
100
101 1
def create_server(config: HttpConfig) -> ThreadingHTTPServer:
    """Bind a threaded HTTP server to the configured host and port.

    Each request is handled by a `RequestHandler` bound to the configured
    API tokens. Requests are not served until `serve_forever` is called.
    """
    return ThreadingHTTPServer(
        (config.host, config.port),
        partial(RequestHandler, config.api_tokens),
    )
106
107
108 1
def start_receive_server(config: HttpConfig) -> None:
    """Create the HTTP server and start serving in a separate thread.

    Exit the process with status 1 if the server cannot be created (for
    example, because the port cannot be opened).
    """
    try:
        server = create_server(config)
    except OSError as e:
        if e.errno is not None:
            sys.stderr.write(f'Error {e.errno:d}: {e.strerror}\n')
        else:
            # Not every `OSError` carries an error number; formatting `None`
            # with `:d` would raise `TypeError` here and hide the real error.
            sys.stderr.write(f'Error: {e}\n')
        sys.stderr.write(
            f'Probably no permission to open port {config.port}. '
            'Try to specify a port number above 1,024 (or even '
            '4,096) and up to 65,535.\n'
        )
        sys.exit(1)

    thread_name = server.__class__.__name__
    start_thread(server.serve_forever, thread_name)
    logger.info('Listening for HTTP requests on %s:%d.', *server.server_address)
124