pincer.core.gateway.Gateway.handle_heartbeat_req()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 5
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 2
dl 0
loc 5
rs 10
c 0
b 0
f 0
cc 1
nop 2
1
# Copyright Pincer 2021-Present
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
# Full MIT License can be found in `LICENSE` at the project root.
3
4
5
from __future__ import annotations
6
7
from asyncio import create_task, Task, ensure_future, sleep
8
from dataclasses import dataclass
9
from datetime import datetime
10
from itertools import repeat, count, chain
11
import logging
12
from platform import system
13
from random import random
14
from typing import TYPE_CHECKING, Any, Dict, Callable, Optional
15
from zlib import decompressobj
16
17
from aiohttp import (
0 ignored issues
show
introduced by
Unable to import 'aiohttp'
Loading history...
18
    ClientSession,
19
    WSMsgType,
20
    ClientConnectorError,
21
    ClientWebSocketResponse,
22
)
23
24
from . import __package__
25
from ..utils.api_object import APIObject
26
from .._config import GatewayConfig
27
from ..core.dispatch import GatewayDispatch
28
from ..exceptions import (
0 ignored issues
show
Unused Code introduced by
Unused InvalidTokenError imported from exceptions
Loading history...
29
    InvalidTokenError,
30
    GatewayConnectionError,
31
    GatewayError,
32
    UnhandledException,
33
)
34
35
if TYPE_CHECKING:
36
    from ..objects.app.intents import Intents
37
38
    Handler = Callable[[GatewayDispatch], None]
39
40
_log = logging.getLogger(__package__)
41
42
ZLIB_SUFFIX = b"\x00\x00\xff\xff"
43
inflator = decompressobj()
44
45
46
@dataclass()
class SessionStartLimit(APIObject):
    """Session start limit info returned from the ``gateway/bot`` endpoint.

    Appears as the ``session_start_limit`` field of :class:`GatewayInfo`.
    Field order below is also the positional order of the generated
    ``__init__``.
    """

    total: int
    remaining: int
    reset_after: int
    max_concurrency: int
54
55
56
@dataclass()
class GatewayInfo(APIObject):
    """Gateway info returned from the ``gateway/bot`` endpoint.

    ``session_start_limit`` holds the nested :class:`SessionStartLimit` data.
    Field order below is also the positional order of the generated
    ``__init__``.
    """

    url: str
    shards: int
    session_start_limit: SessionStartLimit
63
64
65
class Gateway:
0 ignored issues
show
best-practice introduced by
Too many instance attributes (18/7)
Loading history...
66
    """The Gateway handles all interactions with the Discord Websocket API.
67
    This also contains the main event loop, and handles the heartbeat.
68
69
    Running the Gateway will create a connection with the
70
    Discord Websocket API on behalf of the provided token.
71
72
    This token must be a bot token.
73
    (Which can be found on
74
    `<https://discord.com/developers/applications/>`_)
75
76
    Parameters
77
    ----------
78
    token : str.
79
        The token for this bot
80
    intents : :class:`~pincer.objects.app.intents.Intents`
81
        The intents to use. More information can be found at
82
        `<https://discord.com/developers/docs/topics/gateway#gateway-intents>`_.
83
    url : str
84
        The gateway url.
85
    shard : int
86
        The ID of the shard to run.
87
    num_shards : int
88
        Number used to route traffic to the current. This should usually be the total
89
        number of shards that will be run. More information at
90
        `<https://discord.com/developers/docs/topics/gateway#sharding>`_.
91
    """
92
93
    def __init__(
        self,
        token: str,
        *,
        intents: Intents,
        url: str,
        shard: int,
        num_shards: int,
    ) -> None:
        """Store the connection settings and initialise the gateway state.

        No network activity happens here: the ``ClientSession`` is created by
        :meth:`init_session` and the websocket is opened by :meth:`start_loop`.
        """
        # -- Connection settings ----------------------------------------------
        self.token = token
        self.intents = intents
        self.url = url
        self.shard = shard
        self.num_shards = num_shards
        # Sent as the `shard` field when identifying; also the log prefix.
        self.shard_key = [shard, num_shards]

        # -- Opcode handlers --------------------------------------------------
        # opcode -> coroutine that `handle_data` runs in the background for
        # every payload with that opcode. `append_handlers` merges in more.
        self.__dispatch_handlers: Dict[int, Handler] = {
            1: self.handle_heartbeat_req,
            7: self.handle_reconnect,
            9: self.handle_invalid_session,
            10: self.identify_and_handle_hello,
            11: self.handle_heartbeat,
        }

        # -- Fatal close codes ------------------------------------------------
        # `event_loop` raises the mapped error when the socket closes with one
        # of these codes. 4000 and 4009 are deliberately absent: the client
        # reconnects on either, and 4000 is also used for internal disconnects
        # that should lead to a reconnect.
        fatal_close_messages = {
            4001: "Invalid opcode was sent",
            4002: "Invalid payload was sent.",
            4003: "Payload was sent prior to identifying",
            4004: "Token is not valid",
            4005: "Authentication was sent after client already authenticated",
            4007: "Invalid sequence sent when starting new session",
            4008: "Client was rate limited",
            4010: "Invalid shard",
            4011: "Sharding required",
            4012: "Invalid API version",
            4013: "Invalid intents",
            4014: "Disallowed intents",
        }
        self.__close_codes: Dict[int, GatewayError] = {
            code: GatewayError(message)
            for code, message in fatal_close_messages.items()
        }

        # -- Connection state -------------------------------------------------
        # aiohttp session used to open the websocket; set by `init_session`.
        self.__session: Optional[ClientSession] = None
        # The open websocket. aiohttp's `_WSRequestContextManager` is not
        # public, so its public parent `ClientWebSocketResponse` is the type.
        self.__socket: Optional[ClientWebSocketResponse] = None
        # Collects zlib-stream frames until a complete message has arrived.
        self.__buffer = bytearray()
        # True when the next hello should be answered with a resume (opcode 6)
        # instead of an identify (opcode 2).
        self.__should_resume: bool = False
        # Sequence number of the last payload that carried one; sent when
        # resuming and with every heartbeat.
        self.__sequence_number: int = 0

        # -- Heartbeat state --------------------------------------------------
        self.__heartbeat_task: Optional[Task] = None
        # Sleep task between two heartbeats; cancelling it ends the wait.
        self.__wait_for_heartbeat: Optional[Task] = None
        # Milliseconds between heartbeats, taken from the hello payload.
        self.__heartbeat_interval: Optional[int] = None
        # False while the last heartbeat is still waiting for its ack
        # (opcode 11); a missing ack at the next beat triggers a reconnect.
        self.__has_received_ack: bool = True

        # Session id from the ready event (stored via `set_session_id`); used
        # when resuming.
        self.__session_id: Optional[str] = None
175
    def __del__(self):
176
        """Delete method ensures all connections are closed"""
177
        if self.__socket:
178
            create_task(self.__socket.close())
179
        if self.__session:
180
            create_task(self.__session.close())
181
182
    async def init_session(self):
        """|coro|
        Creates the ClientSession. ALWAYS run this function right after
        initializing a Gateway.

        If a previously created session is still open, it is closed first.
        Calling this method again therefore does not leave an unclosed
        session behind.
        """
        previous = self.__session
        if previous is not None and not previous.closed:
            await previous.close()

        self.__session = ClientSession()
188
189
    def append_handlers(self, handlers: Dict[int, Handler]):
190
        """The Client that uses the handler can append their own methods. The gateway
191
        will run those methods when the specified opcode is received.
192
        """
193
        self.__dispatch_handlers = {**self.__dispatch_handlers, **handlers}
194
195
    def set_session_id(self, _id: str):
196
        """Session id is private for consistency"""
197
        self.__session_id = _id
198
199
    def decompress_msg(self, msg: bytes) -> Optional[str]:
        """Decompress one binary websocket frame.

        Behaviour depends on ``GatewayConfig.compression``:

        * ``"zlib-payload"``: every frame is a complete, independent zlib
          stream, so each gets its own decompressor. A single shared streaming
          decompressor reaches end-of-stream after the first payload and
          yields nothing for later ones.
        * ``"zlib-stream"``: the frames of one websocket connection form a
          single zlib stream. Frames are buffered until the buffer ends with
          ``ZLIB_SUFFIX``. The whole buffer is then decompressed, so messages
          split across several frames stay intact. The zlib context is kept
          per instance and per socket: shards never share one, and a new
          connection starts with a fresh context and an empty buffer.
        * anything else: nothing to decompress; ``None`` is returned.

        Parameters
        ----------
        msg : bytes
            Raw binary frame data.

        Returns
        -------
        Optional[str]
            The decompressed message decoded as UTF-8, or ``None`` if no
            complete message is available yet.
        """
        if GatewayConfig.compression == "zlib-payload":
            return decompressobj().decompress(msg).decode("utf-8")

        if GatewayConfig.compression == "zlib-stream":
            # `getattr` needs the name-mangled attribute name spelled out.
            state = getattr(self, "_Gateway__inflator_state", None)
            if state is None or state[0] is not self.__socket:
                # First frame, or first frame of a new connection: start a
                # fresh zlib context and drop any partial message left over
                # from the previous socket.
                state = (self.__socket, decompressobj())
                self.__inflator_state = state
                self.__buffer = bytearray()

            self.__buffer.extend(msg)

            if len(self.__buffer) < 4 or self.__buffer[-4:] != ZLIB_SUFFIX:
                return None

            data = state[1].decompress(self.__buffer)
            self.__buffer = bytearray()
            return data.decode("utf-8")

        return None
214
215
    async def start_loop(self):
        """|coro|
        Instantiate the dispatcher, this will create a connection to the
        Discord websocket API on behalf of the client whose token has
        been passed.

        Connecting is tried once and then retried up to
        ``GatewayConfig.MAX_RETRIES`` times, 15 seconds apart. If every
        attempt fails, :class:`GatewayConnectionError` is raised. If
        :meth:`init_session` has not been awaited yet, it is called first.
        """
        if self.__session is None:
            # Create the session lazily instead of failing on `None`.
            await self.init_session()

        for attempt in count():
            try:
                self.__socket = await self.__session.ws_connect(
                    GatewayConfig.make_uri(self.url)
                )
                break
            except ClientConnectorError as e:
                # `attempt` is zero-based: attempt 0 is the initial try, so
                # all MAX_RETRIES retries are used up once it reaches
                # MAX_RETRIES.
                if attempt >= GatewayConfig.MAX_RETRIES:
                    raise GatewayConnectionError from e

                _log.warning(
                    "%s Could not open websocket with Discord."
                    " Retrying in 15 seconds...",
                    self.shard_key,
                )
                await sleep(15)

        _log.debug("%s Starting event loop...", self.shard_key)
        await self.event_loop()
240
241
    async def event_loop(self):
        """|coro|
        Read frames from the websocket until it stops yielding messages.

        Text frames go straight to :meth:`handle_data`. Binary frames go
        through :meth:`decompress_msg` first and are only handled once a
        complete message comes out. An error frame raises
        :class:`GatewayError`.

        Once the socket is done, its close code decides what happens next. A
        code in the fatal close-code table raises the mapped error. Any other
        code schedules a new :meth:`start_loop`, which reconnects.
        """
        async for msg in self.__socket:
            if msg.type == WSMsgType.TEXT:
                await self.handle_data(msg.data)
                continue

            if msg.type == WSMsgType.BINARY:
                # Incomplete transport-compression messages come back as None.
                decompressed = self.decompress_msg(msg.data)
                if decompressed:
                    await self.handle_data(decompressed)
                continue

            if msg.type == WSMsgType.ERROR:
                raise GatewayError from self.__socket.exception()

        # Close codes in the table are unrecoverable; everything else (for
        # example an unintended disconnect) means the client should reconnect.
        fatal_error = self.__close_codes.get(self.__socket.close_code)
        if fatal_error:
            raise fatal_error

        _log.debug(
            "%s Disconnected from Gateway due without any errors. Reconnecting.",
            self.shard_key,
        )
        # Scheduling instead of awaiting keeps reconnects from nesting ever
        # deeper in the call stack.
        ensure_future(self.start_loop())
271
272
    async def handle_data(self, data: str):
        """|coro|
        Method is run when a payload is received from the gateway.
        The message is expected to already have been decompressed.
        Handling the opcode is forked to the background, so they aren't blocking.

        Parameters
        ----------
        data : str
            Text of a single gateway payload, as accepted by
            ``GatewayDispatch.from_string``.

        Raises
        ------
        UnhandledException
            No handler is registered for the payload's opcode.
        """
        payload = GatewayDispatch.from_string(data)

        # A handler registered under opcode -1 sees every payload.
        catch_all = self.__dispatch_handlers.get(-1)
        if catch_all:
            ensure_future(catch_all(payload))

        _log.debug(
            "%s %s GatewayDispatch with opcode %s received",
            self.shard_key,
            datetime.now(),
            payload.op,
        )

        # Many events are sent with a `null` sequence. This sequence should not
        # be tracked.
        if payload.seq is not None:
            self.__sequence_number = payload.seq
            _log.debug(
                "%s Set sequence number to %s", self.shard_key, payload.seq
            )

        handler = self.__dispatch_handlers.get(payload.op)

        if handler is None:
            raise UnhandledException(
                f"Opcode {payload.op} does not have a handler"
            )

        ensure_future(handler(payload))
308
309
    async def handle_heartbeat_req(self, payload: GatewayDispatch):  # pylint: disable=unused-argument
        """|coro|
        Opcode 1 - Instantly send a heartbeat.

        The heartbeat is sent directly. It is deliberately not triggered by
        cancelling the heartbeat loop's pending sleep task: awaiting a
        cancelled task re-raises ``CancelledError`` in the awaiting coroutine,
        which would end the heartbeat loop instead of making it beat early.
        The regular heartbeat schedule is left untouched.

        Parameters
        ----------
        payload : GatewayDispatch
            The opcode 1 payload; its content is not needed.
        """
        await self.send(str(GatewayDispatch(1, data=self.__sequence_number)))
        _log.debug("%s sent requested heartbeat", self.shard_key)
314
315
    async def handle_reconnect(self, payload: GatewayDispatch):  # pylint: disable=unused-argument
        """|coro|
        Opcode 7 - Reconnect and resume immediately.

        The next hello is answered with a resume, and the socket is closed.
        It is closed with code 4000 instead of aiohttp's default 1000, because
        Discord invalidates the session on a 1000/1001 close and the resume
        would then fail. 4000 is not a fatal close code, so ``event_loop``
        reconnects after the close.
        """
        _log.debug(
            "%s Requested to reconnect to Discord. Closing session and attempting to"
            " resume...",
            self.shard_key,
        )

        self.__should_resume = True
        await self.__socket.close(code=4000)
327
328
    async def handle_invalid_session(self, payload: GatewayDispatch):
        """|coro|
        Opcode 9 - Invalid connection
        Attempt to relog. This is probably because the session was already invalidated
        when we tried to reconnect.

        The payload's data decides whether the reconnect that follows resumes
        (truthy) or identifies from scratch. The heartbeat is stopped and the
        socket is closed with code 4000. aiohttp's default of 1000 is avoided
        because Discord invalidates the session on a 1000/1001 close, which
        would defeat a resumable reconnect. 4000 is not a fatal close code, so
        ``event_loop`` reconnects afterwards.
        """
        _log.debug(
            "%s Invalid session, attempting to %s...",
            self.shard_key,
            "reconnect" if payload.data else "relog",
        )

        self.__should_resume = payload.data
        self.stop_heartbeat()
        await self.__socket.close(code=4000)
344
345
    async def identify_and_handle_hello(self, payload: GatewayDispatch):
        """|coro|
        Opcode 10 - Hello there general kenobi
        Runs when we connect to the gateway for the first time and every time after.
        If the client thinks it should reconnect, the opcode 6 resume payload is sent
        instead of the opcode 2 identify payload. A new session is only started after
        a reconnect if opcode 9 is received.

        Successful reconnects are handled in the `resumed` middleware.

        After either payload, the heartbeat interval announced by this hello
        is stored and the heartbeat loop is (re)started. A resumed connection
        therefore keeps heartbeating even if the loop was stopped before the
        reconnect, for example after a missed ack.
        """
        if self.__should_resume:
            _log.debug("%s Resuming connection with Discord", self.shard_key)

            await self.send(
                str(
                    GatewayDispatch(
                        6,
                        {
                            "token": self.token,
                            "session_id": self.__session_id,
                            "seq": self.__sequence_number,
                        },
                    )
                )
            )
        else:
            await self.send(
                str(
                    GatewayDispatch(
                        2,
                        {
                            "token": self.token,
                            "intents": self.intents,
                            "properties": {
                                "$os": system(),
                                "$browser": __package__,
                                "$device": __package__,
                            },
                            "compress": GatewayConfig.compressed(),
                            "shard": self.shard_key,
                        },
                    )
                )
            )

        self.__heartbeat_interval = payload.data["heartbeat_interval"]
        # No heartbeat has been sent on this connection yet, so an ack still
        # outstanding from the previous connection must not trigger a
        # reconnect.
        self.__has_received_ack = True
        self.start_heartbeat()
393
394
    async def handle_heartbeat(self, payload: GatewayDispatch):
0 ignored issues
show
Unused Code introduced by
The argument payload seems to be unused.
Loading history...
395
        """|coro|
396
        Opcode 11 - Heartbeat
397
        Track that the heartbeat has been received using shared state (Rustaceans would
398
        be very mad)
399
        """
400
        self.__has_received_ack = True
401
402
    async def send(self, payload: str):
        """|coro|
        Send a string payload over the websocket.

        The payload is logged first. In the log, the token is cut to its first
        10 characters and the session id to its first 4. Empty values are not
        substituted, because replacing an empty string would insert the mask
        between every character. Nothing is sent when there is no socket yet
        or it is already closed.
        """
        safe_payload = payload

        if self.token:
            safe_payload = safe_payload.replace(
                self.token, "%s..." % self.token[:10]
            )

        if self.__session_id:
            safe_payload = safe_payload.replace(
                self.__session_id, "%s..." % self.__session_id[:4]
            )

        _log.debug("%s Sending payload: %s", self.shard_key, safe_payload)

        if self.__socket is None or self.__socket.closed:
            _log.debug(
                "%s Socket is closing. Payload not sent.", self.shard_key
            )
            return

        await self.__socket.send_str(payload)
423
424
    def start_heartbeat(self):
425
        """
426
        Starts the heartbeat if it is not already running.
427
        """
428
        if not self.__heartbeat_task or self.__heartbeat_task.cancelled():
429
            self.__heartbeat_task = ensure_future(self.__heartbeat_loop())
430
431
    def stop_heartbeat(self):
0 ignored issues
show
introduced by
Missing function or method docstring
Loading history...
432
        self.__heartbeat_task.cancel()
433
434
    def send_next_heartbeat(self):
435
        """
436
        It is expected to always be waiting for a heartbeat. By canceling that task,
437
        a heartbeat can be sent.
438
        """
439
        self.__wait_for_heartbeat.cancel()
440
441
    async def __heartbeat_loop(self):
        """|coro|
        The heartbeat is responsible for keeping the connection to Discord alive.

        Jitter is only random for the first heartbeat. It should be 1 every other
        heartbeat.

        The wait between heartbeats runs as a separate task, stored so that
        :meth:`send_next_heartbeat` can cancel it to beat early. Cancelling
        that task only ends the wait. Cancelling this loop's own task (e.g.
        via :meth:`stop_heartbeat`) still stops the loop.
        """
        from asyncio import CancelledError, wait

        _log.debug("%s Starting heartbeat loop...", self.shard_key)

        # When waiting for first heartbeat, there hasn't been an ack received yet.
        # Set to true so the ack received check doesn't incorrectly fail.
        self.__has_received_ack = True

        for jitter in chain((random(),), repeat(1)):
            duration = self.__heartbeat_interval * jitter

            _log.debug(
                "%s %s sending heartbeat in %sms",
                self.shard_key,
                datetime.now(),
                duration,
            )

            waiter = create_task(sleep(duration / 1000))
            self.__wait_for_heartbeat = waiter
            try:
                # `wait` returns once `waiter` is finished, whether it slept
                # to the end or was cancelled to beat early. Awaiting `waiter`
                # directly would re-raise that cancellation here and end the
                # loop.
                await wait({waiter})
            except CancelledError:
                # This loop itself is being cancelled; don't leave the sleep
                # task running behind it.
                waiter.cancel()
                raise

            if not self.__has_received_ack:
                # Close code is specified to be anything that is not 1000 in the docs.
                _log.debug(
                    "%s %s ack not received. Attempting to reconnect."
                    " Closing socket with close code 4000. \U0001f480",
                    self.shard_key,
                    datetime.now(),
                )
                # Mark the session as resumable before closing, so the
                # reconnect triggered by the close already sees it.
                self.__should_resume = True
                await self.__socket.close(code=4000)
                # A new loop is started in the background while this one is stopped.
                self.stop_heartbeat()
                return

            self.__has_received_ack = False
            await self.send(
                str(GatewayDispatch(1, data=self.__sequence_number))
            )
            _log.debug("%s sent heartbeat", self.shard_key)
488