Passed
Push — main ( ed7d21...87238c )
by Douglas
01:43
created

pocketutils.core.smartio.SmartIo.write()   C

Complexity

Conditions 9

Size

Total Lines 29
Code Lines 27

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 27
dl 0
loc 29
rs 6.6666
c 0
b 0
f 0
cc 9
nop 7
1
"""
2
Compression-aware reading and writing of files.
3
"""
4
from __future__ import annotations
5
6
import bz2
7
import gzip
8
import lzma
9
import os
10
from dataclasses import dataclass
11
from datetime import datetime
12
from pathlib import Path, PurePath
13
from typing import TYPE_CHECKING, Any, Self, TypeVar
14
15
from pocketutils.core.exceptions import WritePermissionsError, XFileExistsError
16
17
if TYPE_CHECKING:
18
    from collections.abc import Callable, Mapping
19
20
PathLike = str | PurePath
21
T = TypeVar("T")
22
23
24
@dataclass(frozen=True, slots=True)
25
class Compression:
26
    name: str
27
    suffixes: list[str]
28
    compress: Callable[[bytes], bytes]
0 ignored issues
show
introduced by
The variable Callable does not seem to be defined in case TYPE_CHECKING on line 17 is False. Are you sure this can never be the case?
Loading history...
29
    decompress: Callable[[bytes], bytes]
30
31
    def compress_file(self: Self, source: PurePath | str, dest: PurePath | str | None = None) -> None:
32
        source = Path(source)
33
        dest = source.parent / (source.name + self.suffixes[0]) if dest is None else Path(dest)
34
        data = self.compress(source.read_bytes())
35
        dest.write_bytes(data)
36
37
    def decompress_file(self: Self, source: PurePath | str, dest: PurePath | str | None = None) -> None:
38
        source = Path(source)
39
        dest = source.with_suffix("") if dest is None else Path(dest)
40
        data = self.decompress(source.read_bytes())
41
        dest.write_bytes(data)
42
43
44
def identity(x: T) -> T:
    """
    Returns the argument unchanged.
    """
    return x
46
47
48
@dataclass(frozen=True, slots=True)
49
class CompressionSet:
50
    mapping: dict[str, Compression]
51
52
    @classmethod
53
    def empty(cls: type[Self]) -> Self:
54
        return CompressionSet({"": Compression("", [], identity, identity)})
55
56
    def __add__(self: Self, fmt: Compression):
57
        new = {fmt.name: fmt} | {s: fmt for s in fmt.suffixes}
58
        already = {v for k, v in self.mapping.items() if k in new}
59
        if len(already) > 1 or len(already) == 1 and already != {fmt}:
60
            msg = f"Keys from {fmt} already mapped to {already}"
61
            raise ValueError(msg)
62
        return CompressionSet(self.mapping | new)
63
64
    def __sub__(self: Self, fmt: Compression) -> CompressionSet:
65
        return CompressionSet(
66
            {k: v for k, v in self.mapping.items() if k != fmt.name and k not in fmt.suffixes},
67
        )
68
69
    def __or__(self: Self, fmt: CompressionSet) -> CompressionSet:
70
        return CompressionSet(self.mapping | fmt.mapping)
71
72
    def __getitem__(self: Self, t: Compression | str) -> Compression:
73
        """
74
        Returns a FileFormat from a name (e.g. "gz" or "gzip").
75
        Case-insensitive.
76
77
        Example:
78
            `Compression.of("gzip").suffix  # ".gz"`
79
        """
80
        if isinstance(t, Compression):
81
            return t
82
        return self.mapping[t]
83
84
    def guess(self: Self, path: PathLike) -> Compression:
85
        if "." not in path.name:
86
            return self[""]
87
        try:
88
            return self[path.suffix]
89
        except KeyError:
90
            return self[""]
91
92
93
def _get_compressions() -> CompressionSet:
    """
    Builds the set of all supported compression formats.

    The third-party codecs are imported here rather than at module level,
    so they are only loaded when this function is called.
    """
    import brotli
    import lz4.frame
    import snappy
    import zstandard

    formats = [
        Compression("gzip", [".gz", ".gzip"], gzip.compress, gzip.decompress),
        Compression("brotli", [".br", ".brotli"], brotli.compress, brotli.decompress),
        Compression("zstandard", [".zst", ".zstd"], zstandard.compress, zstandard.decompress),
        Compression("lz4", [".lz4"], lz4.frame.compress, lz4.frame.decompress),
        Compression("snappy", [".snappy"], snappy.compress, snappy.decompress),
        Compression("bzip2", [".bz2", ".bzip2"], bz2.compress, bz2.decompress),
        Compression("xz", [".xz"], lzma.compress, lzma.decompress),
        Compression("lzma", [".lzma"], lzma.compress, lzma.decompress),
    ]
    # Start from the no-op format and register the formats in the order listed.
    registry = CompressionSet.empty()
    for entry in formats:
        registry = registry + entry
    return registry
110
111
112
@dataclass(frozen=True, slots=True)
113
class SmartIo:
114
    __COMPRESSIONS = None
115
116
    @classmethod
117
    def mapping(cls: type[Self]) -> Mapping[str, Compression]:
118
        return cls.compressions().mapping
119
120
    @classmethod
121
    def compressions(cls: type[Self]) -> CompressionSet:
122
        if cls.__COMPRESSIONS is None:
123
            _COMPRESSIONS = _get_compressions()
124
        return cls.__COMPRESSIONS
125
126
    @classmethod
127
    def write(
128
        cls: type[Self],
129
        data: Any,
130
        path: PathLike,
131
        *,
132
        atomic: bool = False,
133
        mkdirs: bool = False,
134
        exist_ok: bool = False,
135
    ) -> None:
136
        path = Path(path)
137
        compressed = cls.compressions().guess(path).compress(data)
138
        if path.exists() and not path.is_file():
139
            msg = f"Path {path} is not a file"
140
            raise WritePermissionsError(msg, path=path)
141
        if path.exists() and not exist_ok:
142
            msg = f"Path {path} exists"
143
            raise XFileExistsError(msg, path=path)
144
        if path.exists() and not os.access(path, os.W_OK):
145
            msg = f"Cannot write to {path}"
146
            raise WritePermissionsError(msg, path=path)
147
        if mkdirs:
148
            path.parent.mkdir(parents=True, exist_ok=True)
149
        if atomic:
150
            tmp = cls.tmp_path(path)
151
            path.write_bytes(compressed)
152
            tmp.rename(path)
153
        else:
154
            path.write_bytes(compressed)
155
156
    @classmethod
157
    def read_text(cls: type[Self], path: PathLike, encoding: str = "utf-8") -> str:
158
        """
159
        Similar to :meth:`read_bytes`, but then converts to UTF-8.
160
        """
161
        return cls.read_bytes(path).decode(encoding=encoding)
162
163
    @classmethod
164
    def read_bytes(cls: type[Self], path: PathLike) -> bytes:
165
        """
166
        Reads, decompressing according to the filename suffix.
167
        """
168
        data = Path(path).read_bytes()
169
        return cls.compressions().guess(path).decompress(data)
170
171
    @classmethod
172
    def tmp_path(cls: type[Self], path: PathLike, extra: str = "tmp") -> Path:
173
        now = datetime.now().isoformat(timespec="microsecond")
174
        now = now.replace(":", "").replace("-", "")
175
        path = Path(path)
176
        suffix = "".join(path.suffixes)
177
        return path.parent / f".part_{extra}.{now}{suffix}"
178
179
180
__all__ = ["Compression", "CompressionSet", "SmartIo"]
181