Passed
Push — main ( eee3ee...668435 )
by Jochen
02:07
created

weitersager.http.Application.on_channel_token()   A

Complexity

Conditions 2

Size

Total Lines 16
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 1
CRAP Score 4.5185

Importance

Changes 0
Metric Value
cc 2
eloc 11
nop 3
dl 0
loc 16
ccs 1
cts 7
cp 0.1429
crap 4.5185
rs 9.85
c 0
b 0
f 0
1
"""
2
weitersager.http
3
~~~~~~~~~~~~~~~~
4
5
HTTP server to receive messages
6
7
:Copyright: 2007-2022 Jochen Kupperschmidt
8
:License: MIT, see LICENSE for details.
9
"""
10
11 1
from __future__ import annotations
12 1
from http import HTTPStatus
13 1
import logging
14 1
import sys
15 1
from typing import Optional
16 1
from wsgiref.simple_server import make_server, ServerHandler, WSGIServer
17
18 1
from werkzeug.datastructures import Headers
19 1
from werkzeug.exceptions import abort, HTTPException
20 1
from werkzeug.routing import Map, Rule
21 1
from werkzeug.wrappers import Request, Response
22
23 1
from .config import HttpConfig
24 1
from .signals import message_received
25 1
from .util import start_thread
26
27
28 1
# Module-level logger, named after this module (``weitersager.http``).
logger = logging.getLogger(__name__)
29
30
31 1
def create_app(
    api_tokens: set[str], channel_tokens_to_channel_names: dict[str, str]
) -> Application:
    """Build the WSGI application.

    :param api_tokens: tokens accepted on the root endpoint; an empty set
        disables token checking there
    :param channel_tokens_to_channel_names: maps per-channel URL tokens to
        the channel names messages are delivered to
    """
    app = Application(
        api_tokens=api_tokens,
        channel_tokens_to_channel_names=channel_tokens_to_channel_names,
    )
    return app
35
36
37 1
class Application:
    """WSGI application that turns incoming HTTP requests into messages.

    Routed endpoints:

    - ``/``: expects a JSON body with ``channel`` and ``text``. If API
      tokens are configured, the request must carry one of them in an
      ``Authorization: Token <token>`` header (missing -> 401, unknown
      -> 403).
    - ``/ct/<channel_token>``: expects a JSON body with ``text``; the
      target channel is looked up from the token in the URL (unknown
      token -> 404).

    Accepted messages are emitted through the ``message_received`` signal
    and answered with ``202 Accepted``.
    """

    def __init__(
        self,
        api_tokens: set[str],
        channel_tokens_to_channel_names: dict[str, str],
    ) -> None:
        self._api_tokens = api_tokens
        self._channel_tokens_to_channel_names = channel_tokens_to_channel_names

        rules = [
            Rule('/', endpoint='root'),
            Rule('/ct/<channel_token>', endpoint='channel_token'),
        ]
        self._url_map = Map(rules)

    def __call__(self, environ, start_response):
        """WSGI entry point."""
        return self.wsgi_app(environ, start_response)

    def wsgi_app(self, environ, start_response):
        """Wrap the environ in a request, dispatch it, and emit the response."""
        response = self.dispatch_request(Request(environ))
        return response(environ, start_response)

    def dispatch_request(self, request: Request):
        """Route the request to the matching ``on_<endpoint>`` handler.

        HTTP errors raised while matching or handling (e.g. via ``abort``)
        are returned instead of propagating.
        """
        urls = self._url_map.bind_to_environ(request.environ)
        try:
            endpoint, view_args = urls.match()
            return getattr(self, f'on_{endpoint}')(request, **view_args)
        except HTTPException as http_error:
            # The exception object is used directly as the WSGI response
            # by `wsgi_app`.
            return http_error

    def on_root(self, request: Request) -> Response:
        """Accept a message for the channel named in the JSON payload."""
        if self._api_tokens:
            token = _get_api_token(request.headers)
            if not token:
                # No usable credentials were supplied at all.
                abort(HTTPStatus.UNAUTHORIZED)
            if token not in self._api_tokens:
                # Credentials were supplied but are not recognized.
                abort(HTTPStatus.FORBIDDEN)

        payload = _extract_payload(request, {'channel', 'text'})
        return self._accept(request, payload['channel'], payload['text'])

    def on_channel_token(
        self, request: Request, channel_token: str
    ) -> Response:
        """Accept a message for the channel mapped to ``channel_token``."""
        channel_name = self._channel_tokens_to_channel_names.get(channel_token)
        if channel_name is None:
            abort(HTTPStatus.NOT_FOUND)

        payload = _extract_payload(request, {'text'})
        return self._accept(request, channel_name, payload['text'])

    def _accept(
        self, request: Request, channel_name: str, text: str
    ) -> Response:
        """Emit ``message_received`` for the message and answer with 202."""
        message_received.send(
            channel_name=channel_name,
            text=text,
            source_ip_address=request.remote_addr,
        )
        return Response('', status=HTTPStatus.ACCEPTED)
106
107
108 1
def _get_api_token(headers: Headers) -> Optional[str]:
109 1
    authorization_value = headers.get('Authorization')
110 1
    if not authorization_value:
111 1
        return None
112
113 1
    prefix = 'Token '
114 1
    if not authorization_value.startswith(prefix):
115
        return None
116
117 1
    return authorization_value[len(prefix) :]
118
119
120 1
def _extract_payload(request: Request, keys: set[str]) -> dict[str, str]:
121
    """Extract values for given keys from JSON payload."""
122 1
    if not request.is_json:
123
        abort(HTTPStatus.UNSUPPORTED_MEDIA_TYPE)
124
125 1
    payload = request.json
126 1
    if payload is None:
127
        abort(HTTPStatus.BAD_REQUEST)
128
129 1
    data = {}
130 1
    try:
131 1
        for key in keys:
132 1
            data[key] = payload[key]
133 1
    except KeyError:
134 1
        abort(HTTPStatus.BAD_REQUEST)
135
136 1
    return data
137
138
139
# Override value of `Server:` header sent by wsgiref.
140 1
# This patches the class attribute at import time, so it affects every
# wsgiref `ServerHandler` in this process, not only the server created here.
ServerHandler.server_software = 'Weitersager'
141
142
143 1
def create_server(config: HttpConfig) -> WSGIServer:
    """Build the WSGI server for the configured host and port.

    The server is returned without being started; the caller decides how
    to run it.
    """
    application = create_app(
        config.api_tokens, config.channel_tokens_to_channel_names
    )
    host, port = config.host, config.port
    return make_server(host, port, application)
148
149
150 1
def start_receive_server(config: HttpConfig) -> None:
    """Create the HTTP server and start serving it in a separate thread.

    If creating the server raises an ``OSError`` (for example because the
    port cannot be opened), an error message is written to stderr and the
    process exits with status 1.
    """
    try:
        server = create_server(config)
    except OSError as e:
        # `errno` is `None` for an `OSError` raised without an error code;
        # formatting that with `:d` would raise `TypeError` inside this
        # handler and mask the original error.
        if e.errno is not None:
            sys.stderr.write(f'Error {e.errno:d}: {e.strerror}\n')
        else:
            sys.stderr.write(f'Error: {e}\n')
        sys.stderr.write(
            f'Probably no permission to open port {config.port}. '
            'Try to specify a port number above 1,024 (or even '
            '4,096) and up to 65,535.\n'
        )
        sys.exit(1)

    thread_name = server.__class__.__name__
    start_thread(server.serve_forever, thread_name)
    # Assumes `server_address` is a (host, port) pair, matching `%s:%d`.
    logger.info('Listening for HTTP requests on %s:%d.', *server.server_address)
166