Passed
Push — main ( ed7d21...87238c )
by Douglas
01:43
created

PollingRestDownloader.__call__()   B

Complexity

Conditions 7

Size

Total Lines 15
Code Lines 15

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 15
dl 0
loc 15
rs 8
c 0
b 0
f 0
cc 7
nop 1
1
import asyncio
2
import tempfile
3
import time
4
from collections.abc import Awaitable, Callable
5
from dataclasses import dataclass
6
from pathlib import Path
7
from typing import Any, Self
8
9
from httpx import AsyncClient, Response
10
from orjson import orjson
11
12
13
@dataclass(frozen=True, slots=True, kw_only=True)
class PollingRestDownloader:
    """
    Submit a request with ``post``, then poll with ``get`` until a 200 response arrives.

    Calling an instance opens a single ``AsyncClient``, awaits ``post(client)``,
    decodes that response body as JSON, and passes the decoded value to
    ``get(client, data)`` on every polling attempt. The JSON body of the first
    ``get`` response with status 200 is returned.

    For example, the body of a ``post`` callable could be (``self`` here is whatever
    object owns the callable; this class itself has no ``timeout_sec`` field):

    ```python
        return client.post(
            url,
            params=params,
            data=data,
            files=files,
            timeout=self.timeout_sec,
            follow_redirects=True,
        )
    ```
    """

    # Sends the initial request; its JSON body is handed to ``get`` on every poll.
    post: Callable[[AsyncClient], Awaitable[Response]]
    # Polls for the result, given the client and the decoded body of the ``post`` response.
    get: Callable[[AsyncClient, dict[str, str | list[str] | dict[str]]], Awaitable[Response]]
    # Seconds to sleep before the first ``get`` attempt.
    initial_wait_sec: float = 20
    # Seconds to sleep before every later ``get`` attempt.
    subsequent_wait_sec: float = 5
    # Polling budget in seconds, measured from just after the ``post`` response is decoded.
    max_wait_sec: float = 120

    def _get(self: Self, client: AsyncClient, data: dict[str, Any]) -> Awaitable[Response]:
        """Placeholder that always raises; ``__call__`` polls through the ``get`` field instead."""
        raise NotImplementedError()

    async def __call__(self: Self) -> dict:
        """
        Run the post-then-poll cycle and return the decoded JSON of the final response.

        Returns:
            The JSON-decoded body of the first ``get`` response whose status is 200.

        Raises:
            TimeoutError: If more than ``max_wait_sec`` seconds have elapsed when a new
                polling attempt is about to start. The check runs before each sleep, so
                the total time can exceed the budget by one sleep plus one request.
            Exception: Whatever ``raise_for_status()`` raises for the ``post`` response,
                or for a ``get`` response whose status is neither 200 nor 404.
        """
        async with AsyncClient(http2=True) as client:
            submitted = await self.post(client)
            submitted.raise_for_status()
            data = orjson.loads(submitted.text)

            polled: Response | None = None
            t0 = time.monotonic()
            while polled is None or polled.status_code != 200:
                if time.monotonic() - t0 > self.max_wait_sec:
                    raise TimeoutError(f"No 200 response within {self.max_wait_sec} seconds")
                # Test against None explicitly so the choice of delay never depends on
                # how a response object happens to define truthiness.
                delay = self.initial_wait_sec if polled is None else self.subsequent_wait_sec
                await asyncio.sleep(delay)
                polled = await self.get(client, data)
                # 404 means "not there yet": keep polling. Error statuses abort here;
                # other non-error statuses fall through and are polled again.
                if polled.status_code not in {200, 404}:
                    polled.raise_for_status()
        # Decode the body, not the response object itself (orjson only accepts
        # str/bytes-like input).
        return orjson.loads(polled.text)
54
55
56
@dataclass(frozen=True, slots=True)
class FileStreamer:
    """Download a URL to a file, staging the bytes in a temporary file next to the target."""

    async def __call__(self: Self, url: str, path: Path) -> None:
        """
        Stream ``url`` with an HTTP GET and move the finished download to ``path``.

        The body is written to a temporary file created in ``path.parent`` and is
        renamed to ``path`` only after the whole response has been written and the
        file has been closed. If anything fails, the temporary file is removed.

        Args:
            url: Address to download.
            path: Final location of the downloaded file.

        Raises:
            Exception: Whatever ``raise_for_status()`` raises for an error status,
                plus any transport or filesystem error, re-raised after cleanup.
        """
        # delete=False: the file is renamed on success, so the automatic unlink on
        # close would look for a name that no longer exists.
        download_file = tempfile.NamedTemporaryFile(dir=path.parent, delete=False)
        temp_path = Path(download_file.name)
        try:
            with download_file:
                async with (
                    AsyncClient(http2=True) as client,
                    client.stream("GET", url) as response,
                ):
                    # Do not save an error page as if it were the requested file.
                    response.raise_for_status()
                    # A response streamed by an async client must be consumed with
                    # the async iterator.
                    async for chunk in response.aiter_bytes():
                        download_file.write(chunk)
                        # The write is blocking; give other tasks a turn between chunks.
                        await asyncio.sleep(0)
            # The file is closed (and therefore flushed) before it is moved into place.
            temp_path.rename(path)
        except BaseException:
            # Covers cancellation too; never leave a partial download behind.
            temp_path.unlink(missing_ok=True)
            raise
65
66
67
@dataclass(frozen=True, slots=True)
class RichFileStreamer:
    """Download a URL to a file while showing a ``rich`` progress bar."""

    async def __call__(self: Self, url: str, path: Path) -> None:
        """
        Stream ``url`` with an HTTP GET, display progress, and move the result to ``path``.

        The body is written to a temporary file created in ``path.parent`` and is
        renamed to ``path`` only after the whole response has been written and the
        file has been closed. If anything fails, the temporary file is removed.

        Args:
            url: Address to download.
            path: Final location of the downloaded file.

        Raises:
            Exception: Whatever ``raise_for_status()`` raises for an error status,
                plus any transport or filesystem error, re-raised after cleanup.
        """
        # Imported here so that ``rich`` is only needed when this streamer is used.
        import rich.progress

        # delete=False: the file is renamed on success, so the automatic unlink on
        # close would look for a name that no longer exists.
        download_file = tempfile.NamedTemporaryFile(dir=path.parent, delete=False)
        temp_path = Path(download_file.name)
        try:
            with download_file:
                async with (
                    AsyncClient(http2=True) as client,
                    client.stream("GET", url) as response,
                ):
                    # Do not save an error page as if it were the requested file.
                    response.raise_for_status()
                    # Content-Length may be absent (e.g. chunked transfer); a total of
                    # None gives an indeterminate bar instead of a KeyError.
                    content_length = response.headers.get("Content-Length")
                    total = int(content_length) if content_length is not None else None
                    with rich.progress.Progress(
                        "[progress.percentage]{task.percentage:>3.0f}%",
                        rich.progress.BarColumn(bar_width=None),
                        rich.progress.DownloadColumn(),
                        rich.progress.TransferSpeedColumn(),
                    ) as progress:
                        download_task = progress.add_task("Download", total=total)
                        # A response streamed by an async client must be consumed
                        # with the async iterator.
                        async for chunk in response.aiter_bytes():
                            download_file.write(chunk)
                            progress.update(download_task, completed=response.num_bytes_downloaded)
                            # The write is blocking; give other tasks a turn between chunks.
                            await asyncio.sleep(0)
            # The file is closed (and therefore flushed) before it is moved into place.
            temp_path.rename(path)
        except BaseException:
            # Covers cancellation too; never leave a partial download behind.
            temp_path.unlink(missing_ok=True)
            raise
87
88
89
# Names exported by ``from <module> import *``.
__all__ = [
    "PollingRestDownloader",
    "FileStreamer",
    "RichFileStreamer",
]
90