|
1
|
|
|
import asyncio |
|
2
|
|
|
import tempfile |
|
3
|
|
|
import time |
|
4
|
|
|
from collections.abc import Awaitable, Callable |
|
5
|
|
|
from dataclasses import dataclass |
|
6
|
|
|
from pathlib import Path |
|
7
|
|
|
from typing import Any, Self |
|
8
|
|
|
|
|
9
|
|
|
from httpx import AsyncClient, Response |
|
10
|
|
|
from orjson import orjson |
|
11
|
|
|
|
|
12
|
|
|
|
|
13
|
|
|
@dataclass(frozen=True, slots=True, kw_only=True)
class PollingRestDownloader:
    """Start a job with a POST request, then poll with GET until the result is ready.

    Calling an instance:

    1. invokes ``post`` once and JSON-decodes its response body;
    2. waits ``initial_wait_sec``, then calls ``get`` with the client and that
       decoded body, repeating every ``subsequent_wait_sec``;
    3. returns the JSON-decoded body of the first poll answered with ``200``.

    A ``404`` poll response means "not ready yet" and polling continues. Any
    other status goes through ``raise_for_status()``, so non-2xx statuses raise
    while other 2xx codes (e.g. ``202``) keep polling.

    For example, a ``post`` callback might look like::

        def post(client: AsyncClient) -> Awaitable[Response]:
            return client.post(
                url,
                params=params,
                data=data,
                files=files,
                timeout=timeout_sec,
                follow_redirects=True,
            )

    Attributes:
        post: Issues the request that starts the job.
        get: Issues one poll request, given the client and the decoded POST body.
        initial_wait_sec: Delay before the first poll.
        subsequent_wait_sec: Delay between later polls.
        max_wait_sec: Upper bound on total polling time, measured from the end
            of the POST request; exceeding it raises ``TimeoutError``.
    """

    post: Callable[[AsyncClient], Awaitable[Response]]
    # The second argument is whatever JSON object the POST response decoded to.
    get: Callable[[AsyncClient, dict[str, Any]], Awaitable[Response]]
    initial_wait_sec: float = 20
    subsequent_wait_sec: float = 5
    max_wait_sec: float = 120

    def _get(self: Self, client: AsyncClient, data: dict[str, Any]) -> Awaitable[Response]:
        # NOTE(review): not referenced by this class (polling uses the ``get``
        # field); kept unchanged in case a subclass or caller relies on it.
        raise NotImplementedError()

    async def __call__(self: Self) -> dict:
        """Run the POST request, then poll until a ``200`` response arrives.

        Returns:
            The JSON-decoded body of the first ``200`` poll response.

        Raises:
            httpx.HTTPStatusError: If the POST response, or a poll response whose
                status is neither 200 nor 404, is not a success status.
            TimeoutError: If no ``200`` response arrives within ``max_wait_sec``.
        """
        async with AsyncClient(http2=True) as client:
            response = await self.post(client)
            response.raise_for_status()
            post_payload = orjson.loads(response.text)

            deadline = time.monotonic() + self.max_wait_sec
            delay = self.initial_wait_sec
            while (remaining := deadline - time.monotonic()) > 0:
                # Never sleep past the deadline so max_wait_sec is a real bound.
                await asyncio.sleep(min(delay, remaining))
                response = await self.get(client, post_payload)
                if response.status_code == 200:
                    # orjson.loads needs bytes/str, not the Response object itself.
                    return orjson.loads(response.text)
                if response.status_code != 404:
                    response.raise_for_status()
                delay = self.subsequent_wait_sec
        raise TimeoutError(f"no 200 response within {self.max_wait_sec} s")
|
54
|
|
|
|
|
55
|
|
|
|
|
56
|
|
|
@dataclass(frozen=True, slots=True)
class FileStreamer:
    """Download a URL to a local path by streaming the response body to disk.

    The body is written to a temporary file in the destination directory and is
    moved onto the target path only after it has been fully received. A failed
    or interrupted download therefore never leaves a partial file at the target.
    """

    async def __call__(self: Self, url: str, path: Path) -> None:
        """Stream ``url`` into ``path``, replacing any existing file there.

        Raises:
            httpx.HTTPStatusError: If the server answers with a non-success status.
        """
        # delete=False: the file is moved manually once complete. With
        # delete=True, closing would try to unlink the already-moved name, and
        # an open file cannot be renamed on Windows.
        download_file = tempfile.NamedTemporaryFile(dir=path.parent, delete=False)
        temp_path = Path(download_file.name)
        try:
            with download_file:
                async with (
                    AsyncClient(http2=True) as client,
                    client.stream("GET", url) as response,
                ):
                    # Don't save an error page as if it were the file.
                    response.raise_for_status()
                    async for chunk in response.aiter_bytes():
                        download_file.write(chunk)
            # The file is closed (and flushed) here. Unlike rename(), replace()
            # overwrites an existing target on Windows as well as POSIX.
            temp_path.replace(path)
        except BaseException:
            # Covers cancellation too; drop the partial download and re-raise.
            temp_path.unlink(missing_ok=True)
            raise
|
65
|
|
|
|
|
66
|
|
|
|
|
67
|
|
|
@dataclass(frozen=True, slots=True)
class RichFileStreamer:
    """Download a URL to a local path, showing a ``rich`` progress bar.

    The body is written to a temporary file in the destination directory and is
    moved onto the target path only after it has been fully received. A failed
    or interrupted download therefore never leaves a partial file at the target.
    """

    async def __call__(self: Self, url: str, path: Path) -> None:
        """Stream ``url`` into ``path`` with progress, replacing any existing file.

        Raises:
            httpx.HTTPStatusError: If the server answers with a non-success status.
        """
        # Imported lazily, presumably so `rich` is only required when this
        # streamer is actually used.
        import rich.progress

        # delete=False: the file is moved manually once complete. With
        # delete=True, closing would try to unlink the already-moved name, and
        # an open file cannot be renamed on Windows.
        download_file = tempfile.NamedTemporaryFile(dir=path.parent, delete=False)
        temp_path = Path(download_file.name)
        try:
            with download_file:
                async with (
                    AsyncClient(http2=True) as client,
                    client.stream("GET", url) as response,
                ):
                    # Don't save an error page as if it were the file.
                    response.raise_for_status()
                    # Content-Length is absent for chunked responses; total=None
                    # gives an indeterminate bar instead of a KeyError.
                    content_length = response.headers.get("Content-Length")
                    total = int(content_length) if content_length is not None else None
                    with rich.progress.Progress(
                        "[progress.percentage]{task.percentage:>3.0f}%",
                        rich.progress.BarColumn(bar_width=None),
                        rich.progress.DownloadColumn(),
                        rich.progress.TransferSpeedColumn(),
                    ) as progress:
                        download_task = progress.add_task("Download", total=total)
                        async for chunk in response.aiter_bytes():
                            download_file.write(chunk)
                            # num_bytes_downloaded counts raw (pre-decoding) bytes,
                            # the same quantity Content-Length measures.
                            progress.update(download_task, completed=response.num_bytes_downloaded)
            # The file is closed (and flushed) here. Unlike rename(), replace()
            # overwrites an existing target on Windows as well as POSIX.
            temp_path.replace(path)
        except BaseException:
            # Covers cancellation too; drop the partial download and re-raise.
            temp_path.unlink(missing_ok=True)
            raise
|
87
|
|
|
|
|
88
|
|
|
|
|
89
|
|
|
# Names exported by `from <module> import *`.
__all__ = [
    "PollingRestDownloader",
    "FileStreamer",
    "RichFileStreamer",
]
|
90
|
|
|
|