1
|
|
|
import asyncio |
2
|
|
|
import tempfile |
3
|
|
|
import time |
4
|
|
|
from collections.abc import Awaitable, Callable |
5
|
|
|
from dataclasses import dataclass |
6
|
|
|
from pathlib import Path |
7
|
|
|
from typing import Any, Self |
8
|
|
|
|
9
|
|
|
from httpx import AsyncClient, Response |
10
|
|
|
from orjson import orjson |
11
|
|
|
|
12
|
|
|
|
13
|
|
|
@dataclass(frozen=True, slots=True, kw_only=True)
class PollingRestDownloader:
    """Submit a request with ``post``, then poll with ``get`` until it answers HTTP 200.

    ``post`` is called once with an open ``AsyncClient``; its response must be
    successful and carry a JSON body. That decoded body is passed to every
    ``get`` call. Polling stops when ``get`` returns status 200, and that
    response's JSON body is returned.

    A 404 from ``get`` is treated as "not ready yet". Any other non-200 status
    is handed to ``response.raise_for_status()``. Statuses for which that call
    does not raise keep the loop polling.

    For example, a ``post`` callable might be::

        def post(client: AsyncClient) -> Awaitable[Response]:
            return client.post(
                url,
                params=params,
                data=data,
                files=files,
                timeout=timeout_sec,
                follow_redirects=True,
            )
    """

    # Sends the initial (submission) request using the shared client.
    post: Callable[[AsyncClient], Awaitable[Response]]
    # Sends one poll request; receives the client and the decoded JSON body of
    # the ``post`` response.
    get: Callable[[AsyncClient, dict[str, Any]], Awaitable[Response]]
    # Delay before the first poll, in seconds.
    initial_wait_sec: float = 20
    # Delay before each later poll, in seconds.
    subsequent_wait_sec: float = 5
    # Elapsed-time budget checked before each poll, in seconds.
    max_wait_sec: float = 120

    def _get(self: Self, client: AsyncClient, data: dict[str, Any]) -> Awaitable[Response]:
        """Unimplemented placeholder; ``__call__`` uses the ``get`` field instead."""
        raise NotImplementedError()

    async def __call__(self: Self) -> dict:
        """Run the submit-then-poll cycle and return the decoded JSON result.

        Raises:
            TimeoutError: if the elapsed time exceeds ``max_wait_sec`` before a
                200 response arrives. The check runs before each sleep, so the
                total time can overshoot by one sleep plus one request.
            httpx.HTTPStatusError: from ``raise_for_status()`` on the submission
                response or on a poll response whose status is not 200/404.
        """
        async with AsyncClient(http2=True) as client:
            submitted = await self.post(client)
            submitted.raise_for_status()
            # orjson accepts bytes directly; no need to decode to text first.
            data = orjson.loads(submitted.content)

            response: Response | None = None
            t0 = time.monotonic()
            while response is None or response.status_code != 200:
                if time.monotonic() - t0 > self.max_wait_sec:
                    raise TimeoutError(
                        f"no HTTP 200 response within {self.max_wait_sec} s of polling"
                    )
                # Explicit None check: the first poll uses the initial delay,
                # every later one the subsequent delay.
                await asyncio.sleep(
                    self.initial_wait_sec if response is None else self.subsequent_wait_sec
                )
                response = await self.get(client, data)
                # 404 means "not ready yet"; any other non-200 goes through
                # raise_for_status().
                if response.status_code not in {200, 404}:
                    response.raise_for_status()
            # Decode the body bytes; passing the Response object itself to
            # orjson.loads raises TypeError.
            return orjson.loads(response.content)
54
|
|
|
|
55
|
|
|
|
56
|
|
|
@dataclass(frozen=True, slots=True)
class FileStreamer:
    """Download a URL to a file, staging it in a temporary file beside the target."""

    async def __call__(self: Self, url: str, path: Path) -> None:
        """Stream ``url`` with an HTTP GET and write the body to ``path``.

        The body is written to a temporary file in ``path.parent``. That file is
        moved onto ``path`` only after the download finished and the file was
        closed. On any error the temporary file is removed and ``path`` is left
        untouched.

        Raises:
            httpx.HTTPStatusError: via ``raise_for_status()`` on an error status,
                so an error page is never saved as the downloaded file.
        """
        # delete=False: we move the file ourselves, so the context manager must
        # not try to unlink it afterwards.
        tmp = tempfile.NamedTemporaryFile(dir=path.parent, delete=False)
        tmp_path = Path(tmp.name)
        try:
            with tmp:
                async with AsyncClient(http2=True) as client, client.stream("GET", url) as response:
                    response.raise_for_status()
                    # An AsyncClient stream must be consumed with the async iterator.
                    async for chunk in response.aiter_bytes():
                        tmp.write(chunk)
            # The file is closed (and flushed) at this point. replace()
            # overwrites an existing target on every platform.
            tmp_path.replace(path)
        finally:
            # No-op after a successful replace(); removes the partial file otherwise.
            tmp_path.unlink(missing_ok=True)
65
|
|
|
|
66
|
|
|
|
67
|
|
|
@dataclass(frozen=True, slots=True)
class RichFileStreamer:
    """Download a URL to a file while rendering a ``rich`` progress bar."""

    async def __call__(self: Self, url: str, path: Path) -> None:
        """Stream ``url`` with an HTTP GET into ``path``, showing download progress.

        The body is written to a temporary file in ``path.parent``. That file is
        moved onto ``path`` only after the download finished and the file was
        closed. On any error the temporary file is removed and ``path`` is left
        untouched.

        When the server sends no ``Content-Length`` header (e.g. a chunked
        response), the progress bar is shown without a known total instead of
        failing.

        Raises:
            httpx.HTTPStatusError: via ``raise_for_status()`` on an error status.
        """
        # Imported here rather than at module level; presumably so ``rich`` is
        # only required when this class is actually used.
        import rich.progress

        # delete=False: we move the file ourselves, so the context manager must
        # not try to unlink it afterwards.
        tmp = tempfile.NamedTemporaryFile(dir=path.parent, delete=False)
        tmp_path = Path(tmp.name)
        try:
            with tmp:
                async with AsyncClient(http2=True) as client, client.stream("GET", url) as response:
                    response.raise_for_status()
                    content_length = response.headers.get("Content-Length")
                    total = int(content_length) if content_length is not None else None
                    with rich.progress.Progress(
                        "[progress.percentage]{task.percentage:>3.0f}%",
                        rich.progress.BarColumn(bar_width=None),
                        rich.progress.DownloadColumn(),
                        rich.progress.TransferSpeedColumn(),
                    ) as progress:
                        download_task = progress.add_task("Download", total=total)
                        # An AsyncClient stream must be consumed with the async iterator.
                        async for chunk in response.aiter_bytes():
                            tmp.write(chunk)
                            # num_bytes_downloaded counts raw wire bytes, which is
                            # the same unit as Content-Length.
                            progress.update(download_task, completed=response.num_bytes_downloaded)
            # The file is closed (and flushed) at this point. replace()
            # overwrites an existing target on every platform.
            tmp_path.replace(path)
        finally:
            # No-op after a successful replace(); removes the partial file otherwise.
            tmp_path.unlink(missing_ok=True)
87
|
|
|
|
88
|
|
|
|
89
|
|
|
# Public API of this module.
__all__ = [
    "PollingRestDownloader",
    "FileStreamer",
    "RichFileStreamer",
]
90
|
|
|
|