| Total Complexity | 219 |
| Total Lines | 828 |
| Duplicated Lines | 1.45 % |
| Changes | 0 | ||
Duplicate code is one of the most pungent code smells. A rule that is often used is to re-structure code once it is duplicated in three or more places.
Common duplication problems and their corresponding solutions are:
Complex classes like pocketutils.misc.klgists often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
| 1 | import typing |
||
|
|
|||
| 2 | import os, io, shutil, gzip, platform, re |
||
| 3 | from enum import Enum |
||
| 4 | |||
| 5 | from pocketutils.core.exceptions import PathIsNotADirError, MissingResourceError |
||
| 6 | |||
| 7 | import json, sys |
||
| 8 | import unicodedata |
||
| 9 | from itertools import chain |
||
| 10 | import signal |
||
| 11 | import operator |
||
| 12 | from datetime import date, datetime |
||
| 13 | from typing import Callable, TypeVar, Iterator, Iterable, Optional, List, Any, Sequence, Mapping, Dict, ItemsView, KeysView, ValuesView, Tuple, Union |
||
| 14 | from hurry.filesize import size as hsize |
||
| 15 | |||
| 16 | from pocketutils.core.exceptions import LookupFailedError, MultipleMatchesError, ParsingError, ResourceError, MissingConfigKeyError, \ |
||
| 17 | NaturalExpectedError, HashValidationError, FileDoesNotExistError |
||
| 18 | |||
| 19 | import logging |
||
| 20 | import subprocess |
||
| 21 | from subprocess import Popen, PIPE |
||
| 22 | from queue import Queue |
||
| 23 | from threading import Thread |
||
| 24 | from deprecated import deprecated |
||
| 25 | |||
| 26 | import datetime |
||
| 27 | #from datetime import datetime |
||
| 28 | from subprocess import Popen, PIPE |
||
| 29 | from infix import shift_infix, mod_infix, sub_infix |
||
| 30 | |||
| 31 | from colorama import Fore |
||
| 32 | import shutil |
||
| 33 | import argparse |
||
| 34 | |||
| 35 | import hashlib |
||
| 36 | import codecs |
||
| 37 | import gzip |
||
| 38 | |||
| 39 | import struct |
||
| 40 | from array import array |
||
| 41 | import numpy as np |
||
| 42 | |||
def blob_to_byte_array(bytes_obj: bytes):
    """Decodes a big-endian blob of signed bytes, shifted into the unsigned range 0..255."""
    return _blob_to_dt(bytes_obj, 'b', 1, np.ubyte) + 128

def blob_to_float_array(bytes_obj: bytes):
    """Decodes a big-endian blob of 32-bit floats."""
    return _blob_to_dt(bytes_obj, 'f', 4, np.float32)

def blob_to_double_array(bytes_obj: bytes):
    """Decodes a big-endian blob of 64-bit floats."""
    return _blob_to_dt(bytes_obj, 'd', 8, np.float64)

def blob_to_short_array(bytes_obj: bytes):
    """Decodes a big-endian blob of signed 16-bit integers."""
    # fixed: was 'H' (unsigned), mismatching the signed np.int16 dtype and
    # failing on negative stored values
    return _blob_to_dt(bytes_obj, 'h', 2, np.int16)

def blob_to_int_array(bytes_obj: bytes):
    """Decodes a big-endian blob of signed 32-bit integers."""
    # fixed: was 'I' (unsigned), mismatching the signed np.int32 dtype
    return _blob_to_dt(bytes_obj, 'i', 4, np.int32)

def blob_to_long_array(bytes_obj: bytes):
    """Decodes a big-endian blob of signed 64-bit integers."""
    # fixed: was 'Q' (unsigned), mismatching the signed np.int64 dtype
    return _blob_to_dt(bytes_obj, 'q', 8, np.int64)

def _blob_to_dt(bytes_obj: bytes, data_type_str: str, data_type_len: int, dtype):
    """Unpacks a homogeneous big-endian binary blob into a numpy array of `dtype`."""
    n_items = int(len(bytes_obj) / data_type_len)
    return np.array(next(iter(struct.iter_unpack('>' + data_type_str * n_items, bytes_obj))), dtype=dtype)
||
| 57 | |||
class FileHasher:
    """Computes, writes, and verifies file hashes stored in sidecar files (file name + extension)."""

    def __init__(self, algorithm: Callable[[], Any]=hashlib.sha1, extension: str='.sha1', buffer_size: int = 16*1024):
        self.algorithm = algorithm      # hashlib-style zero-arg constructor
        self.extension = extension      # suffix of the sidecar hash file
        self.buffer_size = buffer_size  # chunk size for streaming reads

    def hashsum(self, file_name: str) -> str:
        """Returns the hex digest of the file, reading it in chunks."""
        alg = self.algorithm()
        with open(file_name, 'rb') as f:
            for chunk in iter(lambda: f.read(self.buffer_size), b''):
                alg.update(chunk)
        return alg.hexdigest()

    def add_hash(self, file_name: str) -> None:
        """Writes the hex digest of the file to its sidecar hash file."""
        with open(file_name + self.extension, 'w', encoding="utf8") as f:
            s = self.hashsum(file_name)
            f.write(s)

    def check_hash(self, file_name: str) -> bool:
        """Returns True iff a sidecar hash file exists and its recorded hash matches."""
        if not os.path.isfile(file_name + self.extension):
            return False
        return self._expected_hash(file_name) == self.hashsum(file_name)

    def check_and_open(self, file_name: str, *args):
        """Verifies the hash, then opens the file as UTF-8 text."""
        return self._o(file_name, opener=lambda f: codecs.open(f, encoding='utf-8'), *args)

    def check_and_open_gzip(self, file_name: str, *args):
        """Verifies the hash, then opens the file with gzip."""
        return self._o(file_name, opener=gzip.open, *args)

    def _expected_hash(self, file_name: str) -> str:
        # Only the first whitespace-delimited token counts, matching sha1sum-style
        # "<hash>  <name>" lines and tolerating trailing newlines
        with open(file_name + self.extension, 'r', encoding="utf8") as f:
            return f.read().split()[0]

    def _o(self, file_name: str, opener, *args):
        if not os.path.isfile(file_name + self.extension):
            raise FileDoesNotExistError("Hash for file {} does not exist".format(file_name))
        # fixed: previously compared the raw sidecar contents (f.read()), which
        # disagreed with check_hash and failed on a trailing newline or a
        # "<hash>  <name>" line; both paths now parse the first token
        if self._expected_hash(file_name) != self.hashsum(file_name):
            raise HashValidationError("Hash for file {} does not match".format(file_name))
        return opener(file_name, *args)
||
| 96 | |||
def mkdatetime(s: str) -> datetime.datetime:
    """Parses 'YYYY-MM-DD HH:MM:SS' (space or 'T' separated) into a datetime.
    Fixed: a later `import datetime` rebinds the name `datetime` to the module,
    so the class must be accessed as `datetime.datetime`.
    """
    return datetime.datetime.strptime(s.replace(' ', 'T'), "%Y-%m-%dT%H:%M:%S")
||
| 99 | |||
def now() -> datetime.datetime:
    """The current local date and time."""
    return datetime.datetime.now()
||
| 102 | |||
def today() -> datetime.datetime:
    """The current local date and time (datetime.today semantics)."""
    return datetime.datetime.today()
||
| 105 | |||
def mkdate(s: str) -> datetime.datetime:
    """Parses 'YYYY-MM-DD' into a datetime (midnight).
    Fixed: `datetime` is the shadowing module here; use `datetime.datetime.strptime`.
    """
    return datetime.datetime.strptime(s, "%Y-%m-%d")
||
| 108 | |||
def this_year(s: str) -> datetime.datetime:
    """Parses a 4-digit year string into a datetime at Jan 1 midnight.
    Fixed: `datetime` is the shadowing module here; use `datetime.datetime.strptime`.
    """
    return datetime.datetime.strptime(s, "%Y")
||
| 111 | |||
def year_range(year: int) -> Tuple[datetime.datetime, datetime.datetime]:
    """Returns (first, last) datetimes bounding the given calendar year.
    Fixed: `datetime(...)` called the shadowing module (a TypeError); the class
    is `datetime.datetime`. The upper bound keeps the original microsecond=999.
    """
    return (
        datetime.datetime(year, 1, 1, 0, 0, 0, 0),
        datetime.datetime(year, 12, 31, 23, 59, 59, 999)
    )
||
| 117 | |||
@shift_infix
def approxeq(a, b):
    """Infix approximate equality: ``5 <<approxeq>> 5.000000000000001``.
    Uses a relative tolerance of 1e-09 * max(abs(a), abs(b)), which may not be
    appropriate for all inputs (and is never satisfied when both are 0).
    Fixed: the second triple-quoted string was a dead statement, not part of
    the docstring; both are merged into one real docstring.
    """
    return abs(a - b) < 1e-09 * max(abs(a), abs(b))
||
| 123 | |||
class TomlData:
    """A better TOML data structure than a plain dict.
    Usage examples:
        data = TomlData({'x': {'y': {'z': 155}}})
        print(data['x.y.z'])  # prints 155
        data.sub('x.y')       # returns a new TomlData for {'z': 155}
        data.nested_keys()    # returns all keys and sub-keys
    """
    def __init__(self, top_level_item: Dict[str, object]):
        assert top_level_item is not None
        self.top = top_level_item

    def __str__(self) -> str:
        return repr(self)
    def __repr__(self) -> str:
        return "TomlData({})".format(str(self.top))

    def __getitem__(self, key: str) -> Dict[str, object]:
        return self.sub(key).top

    def __contains__(self, key: str) -> bool:
        # fixed: sub() raises MissingConfigKeyError, not AttributeError, so the
        # old `except AttributeError` let missing keys propagate as errors
        try:
            self.sub(key)
            return True
        except MissingConfigKeyError: return False

    def get_str(self, key: str) -> str:
        return str(self.__as(key, str))

    def get_int(self, key: str) -> int:
        # noinspection PyTypeChecker
        return int(self.__as(key, int))

    def get_bool(self, key: str) -> bool:
        # noinspection PyTypeChecker
        return bool(self.__as(key, bool))

    def get_str_list(self, key: str) -> List[str]:
        return self.__as_list(key, str)

    def get_int_list(self, key: str) -> List[int]:
        return self.__as_list(key, int)

    def get_float_list(self, key: str) -> List[float]:
        # fixed: previously validated against int, rejecting actual float values
        return self.__as_list(key, float)

    def get_bool_list(self, key: str) -> List[bool]:
        return self.__as_list(key, bool)

    def get_float(self, key: str) -> float:
        # fixed: previously wrapped the value in int(), truncating every float
        return float(self.__as(key, float))

    def __as_list(self, key: str, clazz):
        def to(v):
            if not isinstance(v, clazz):
                raise TypeError("{}={} is a {}, not {}".format(key, v, type(v), clazz))
            return v  # fixed: missing return made every list come back as [None, ...]
        return [to(v) for v in self[key]]

    def __as(self, key: str, clazz):
        v = self[key]
        if isinstance(v, clazz):
            return v
        else:
            raise TypeError("{}={} is a {}, not {}".format(key, v, type(v), clazz))

    def sub(self, key: str):
        """Returns a new TomlData with its top set to items[1][2]..."""
        items = key.split('.')
        item = self.top
        for i, s in enumerate(items):
            if s not in item:
                # fixed: MissingConfigEntry was never imported (NameError);
                # the exception imported at the top of the file is MissingConfigKeyError
                raise MissingConfigKeyError(
                    "{} is not in the TOML; failed at {}"
                    .format(key, '.'.join(items[:i+1]))
                )
            item = item[s]
        return TomlData(item)

    def items(self) -> ItemsView[str, object]:
        return self.top.items()

    def keys(self) -> KeysView[str]:
        return self.top.keys()

    def values(self) -> ValuesView[object]:
        return self.top.values()

    def nested_keys(self, separator='.') -> Iterator[str]:
        """Yields every nested key path, joined by `separator`."""
        for lst in self.nested_key_lists(self.top):
            yield separator.join(lst)

    def nested_key_lists(self, dictionary: Dict[str, object], prefix=None) -> Iterator[List[str]]:
        """Yields key paths as lists of key components.
        NOTE(review): sub-keys are prepended ([key] + prefix), so components of a
        path come out deepest-first — confirm this ordering is intended.
        """
        prefix = prefix[:] if prefix else []
        if isinstance(dictionary, dict):
            for key, value in dictionary.items():
                if isinstance(value, dict):
                    for result in self.nested_key_lists(value, [key] + prefix): yield result
                else: yield prefix + [key]
        else: yield dictionary
||
| 227 | |||
| 228 | |||
def git_commit_hash(git_repo_dir: str='.') -> str:
    """Gets the hex of the most recent Git commit hash in git_repo_dir."""
    proc = Popen(['git', 'rev-parse', 'HEAD'], stdout=PIPE, cwd=git_repo_dir)
    out, _ = proc.communicate()
    code = proc.wait()
    if code != 0:
        raise ValueError("Got nonzero exit code {} from git rev-parse".format(code))
    return out.decode('utf-8').rstrip()
||
| 236 | |||
@deprecated(reason="Use klgists.common.flexible_logger instead.")
def init_logger(
        log_path: Optional[str]=None,
        format_str: str='%(asctime)s %(levelname)-8s: %(message)s',
        to_std: bool=True,
        child_logger_name: Optional[str]=None,
        std_level = logging.INFO,
        file_level = logging.DEBUG
):
    """Initializes a logger that can write to a log file and/or stdout.
    :param log_path: Optional file to log to (parent dirs are created if needed)
    :param format_str: logging.Formatter format string applied to all handlers
    :param to_std: Also attach a stream handler
    :param child_logger_name: Configure this named logger instead of the root
    :param std_level: Level for the stream handler
    :param file_level: Level for the file handler
    """
    if log_path is not None:
        parent = os.path.dirname(log_path)
        # fixed: os.mkdir failed when more than one ancestor was missing, and
        # crashed on a bare filename (dirname == ''); makedirs + guard handles both
        if parent and not os.path.exists(parent):
            os.makedirs(parent)

    # getLogger(None) returns the root logger, so both cases collapse to one call
    logger = logging.getLogger(child_logger_name)
    logger.setLevel(logging.NOTSET)

    formatter = logging.Formatter(format_str)

    if log_path is not None:
        handler = logging.FileHandler(log_path, encoding='utf-8')
        handler.setLevel(file_level)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    if to_std:
        handler = logging.StreamHandler()
        handler.setLevel(std_level)
        handler.setFormatter(formatter)
        logger.addHandler(handler)
||
| 270 | |||
| 271 | import datetime |
||
def format_time(time: datetime.datetime) -> str:
    """Standard timestamp format. Ex: 2016-05-02_22_35_56."""
    return time.strftime("%Y-%m-%d_%H-%M-%S")
||
| 275 | |||
def timestamp() -> str:
    """Standard timestamp of time now. Ex: 2016-05-02_22_35_56."""
    current = datetime.datetime.now()
    return format_time(current)
||
| 279 | |||
def timestamp_path(path: str) -> str:
    """Standard way to label a file path with a timestamp."""
    return "{}-{}".format(path, timestamp())
||
| 283 | |||
| 284 | |||
def nice_time(n_ms: int) -> str:
    """Human-readable duration from milliseconds, e.g. '1d, 2h, 3m, 4s'.
    Durations under one day omit the day component.
    """
    anchor = datetime.datetime(1, 1, 1) + datetime.timedelta(milliseconds=n_ms)
    if n_ms < 1000 * 60 * 60 * 24:
        return "{}h, {}m, {}s".format(anchor.hour, anchor.minute, anchor.second)
    else:
        # fixed: the anchor date starts at day 1, so the raw .day overstated the
        # number of whole days by one (24h used to print as '2d, 0h, 0m, 0s')
        return "{}d, {}h, {}m, {}s".format(anchor.day - 1, anchor.hour, anchor.minute, anchor.second)
||
| 291 | |||
| 292 | |||
def parse_local_iso_datetime(z: str) -> datetime.datetime:
    """Parses a local ISO-8601 timestamp with microseconds and no timezone."""
    return datetime.datetime.strptime(z, '%Y-%m-%dT%H:%M:%S.%f')
||
| 295 | |||
| 296 | |||
# Module-level logger used by the subprocess helpers below
logger = logging.getLogger(__name__)
||
| 298 | |||
class PipeType(Enum):
    # Identifies which stream a line came from in the subprocess reader threads below
    STDOUT = 1
    STDERR = 2
||
| 302 | |||
| 303 | def _disp(out, ell, name): |
||
| 304 | out = out.strip() |
||
| 305 | if '\n' in out: |
||
| 306 | ell(name + ":\n<<=====\n" + out + '\n=====>>') |
||
| 307 | elif len(out) > 0: |
||
| 308 | ell(name + ": <<===== " + out + " =====>>") |
||
| 309 | else: |
||
| 310 | ell(name + ": <no output>") |
||
| 311 | |||
| 312 | |||
def _log(out, err, ell):
    """Reports both captured streams through `ell`."""
    for text, stream_name in ((out, "stdout"), (err, "stderr")):
        _disp(text, ell, stream_name)
||
| 316 | |||
| 317 | |||
def smart_log_callback(source, line, prefix: str = '') -> None:
    """Routes a raw process-output line to the logger level named by its 'LEVEL:' prefix.
    Lines without a recognized prefix are logged at DEBUG.
    """
    line = line.decode('utf-8')
    dispatch = (
        ('FATAL:', logger.fatal),
        ('ERROR:', logger.error),
        ('WARNING:', logger.warning),
        ('INFO:', logger.info),
        ('DEBUG:', logger.debug),
    )
    for marker, emit in dispatch:
        if line.startswith(marker):
            emit(prefix + line)
            return
    logger.debug(prefix + line)
||
| 332 | |||
| 333 | |||
| 334 | def _reader(pipe_type, pipe, queue): |
||
| 335 | try: |
||
| 336 | with pipe: |
||
| 337 | for line in iter(pipe.readline, b''): |
||
| 338 | queue.put((pipe_type, line)) |
||
| 339 | finally: |
||
| 340 | queue.put(None) |
||
| 341 | |||
def stream_cmd_call(cmd: List[str], stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell_cmd: str=None, cwd: Optional[str] = None, timeout_secs: Optional[float] = None, log_callback: Callable[[PipeType, bytes], None] = None, bufsize: Optional[int] = None) -> None:
    """Calls an external command, streams its output to `log_callback` line by line,
    and throws a ResourceError for nonzero exit codes. Returns None.
    The user can optionally provide a shell to run the command with, e.g. "powershell.exe"
    """
    # NOTE(review): the `stdout` and `stderr` parameters are accepted but never
    # used — the Popen call below always pipes both streams. Confirm intent.
    # NOTE(review): the default bufsize=None is passed straight to Popen, which
    # expects an int — confirm callers always supply a bufsize.
    if log_callback is None:
        log_callback = smart_log_callback
    cmd = [str(p) for p in cmd]
    if shell_cmd:
        cmd = [shell_cmd] + cmd
    logger.debug("Streaming '{}'".format(' '.join(cmd)))

    p = subprocess.Popen(cmd, stdout=PIPE, stderr=PIPE, cwd=cwd, bufsize=bufsize)
    try:
        q = Queue()
        # One reader thread per stream; each enqueues a final None sentinel (see _reader)
        Thread(target=_reader, args=[PipeType.STDOUT, p.stdout, q]).start()
        Thread(target=_reader, args=[PipeType.STDERR, p.stderr, q]).start()
        # Consume until both sentinels are seen (one per reader thread)
        for _ in range(2):
            for source, line in iter(q.get, None):
                log_callback(source, line)
        exit_code = p.wait(timeout=timeout_secs)
    finally:
        p.kill()
    # Reached only when no exception escaped the try, so exit_code is bound here
    if exit_code != 0:
        raise ResourceError("Got nonzero exit code {} from '{}'".format(exit_code, ' '.join(cmd)), cmd, exit_code, '<<unknown>>', '<<unknown>>')
||
| 367 | |||
| 368 | |||
def wrap_cmd_call(cmd: List[str], stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell_cmd: str=None, cwd: Optional[str] = None, timeout_secs: Optional[float] = None) -> (str, str):
    """Calls an external command, waits, and throws a ResourceError for nonzero exit codes.
    Returns (stdout, stderr), both decoded as UTF-8.
    The user can optionally provide a shell to run the command with, e.g. "powershell.exe"
    """
    cmd = [str(p) for p in cmd]
    if shell_cmd:
        cmd = [shell_cmd] + cmd
    logger.debug("Calling '{}'".format(' '.join(cmd)))
    p = subprocess.Popen(cmd, stdout=stdout, stderr=stderr, cwd=cwd)
    out, err, exit_code = None, None, None
    try:
        # communicate() already waits for the process; the wait() below just
        # retrieves the exit code
        (out, err) = p.communicate(timeout=timeout_secs)
        # NOTE(review): if callers pass stdout/stderr other than PIPE,
        # communicate() returns None for that stream and decode() will fail
        out = out.decode('utf-8')
        err = err.decode('utf-8')
        exit_code = p.wait(timeout=timeout_secs)
    except Exception as e:
        # Log whatever output was captured before re-raising
        _log(out, err, logger.warning)
        raise e
    finally:
        # Harmless if the process already exited
        p.kill()
    if exit_code != 0:
        _log(out, err, logger.warning)
        raise ResourceError("Got nonzero exit code {} from '{}'".format(exit_code, ' '.join(cmd)), cmd, exit_code, out, err)
    _log(out, err, logger.debug)
    return out, err
||
| 395 | |||
def look(obj: object, attrs: str) -> any:
    """Follows a dotted attribute path (a string, or an iterable of names),
    returning None if any step is missing.
    """
    if isinstance(attrs, Iterable) and not isinstance(attrs, str):
        attrs = '.'.join(attrs)
    getter = operator.attrgetter(attrs)
    try:
        return getter(obj)
    except AttributeError:
        return None
||
| 401 | |||
def flatmap(func, *iterable):
    """Applies func over the iterables, then lazily flattens the results one level."""
    mapped = map(func, *iterable)
    return chain.from_iterable(mapped)
||
| 404 | |||
def flatten(*iterable):
    """Concatenates the given iterables into a single list."""
    return [item for sub in iterable for item in sub]
||
| 407 | |||
class DevNull:
    """A file-like sink that silently discards everything written to it."""
    def write(self, msg): pass
||
| 410 | |||
# Short aliases for common os.path functions
pjoin = os.path.join
pexists = os.path.exists
pdir = os.path.isdir  # NOTE(review): duplicate of pis_dir below
pfile = os.path.isfile
pis_dir = os.path.isdir
fsize = os.path.getsize
||
def pardir(path: str, depth: int=1):
    """Strips the final path component plus `depth - 1` additional parents
    (i.e. applies os.path.dirname `depth + 1` times).
    """
    # range(depth + 1) is exactly equivalent to the original range(-1, depth)
    for _ in range(depth + 1):
        path = os.path.dirname(path)
    return path
||
def grandpardir(path: str):
    """Equivalent to pardir(path, depth=2)."""
    return pardir(path, 2)
||
| 423 | |||
| 424 | |||
T = TypeVar('T')
def try_index_of(element: T, list_element: List[T]) -> Optional[int]:
    """Returns the index of `element` within `list_element`, or None if absent.
    (Annotations corrected: the first argument is the item, the second the list,
    and the return value is an index, not an element.)
    """
    try:
        index_element = list_element.index(element)
        return index_element
    except ValueError:
        return None
||
| 432 | |||
def decorator(cls):
    """Identity function: marks a class as intended for use as a decorator without changing it."""
    return cls
||
| 435 | |||
| 436 | |||
def exists(keep_predicate: Callable[[T], bool], seq: Iterable[T]) -> bool:
    """Efficient existential quantifier for a filter() predicate.
    Returns true iff keep_predicate is true for one or more elements
    (short-circuits on the first match).
    """
    return any(keep_predicate(e) for e in seq)
||
| 443 | |||
| 444 | |||
def zip_strict(*args):
    """Same as zip(), but raises an IndexError if the lengths don't match."""
    iters = [iter(axis) for axis in args]
    n_elements = 0
    failures = []
    # Pull one value from every iterator per pass until at least one is exhausted
    while len(failures) == 0:
        n_elements += 1
        values = []
        failures = []
        for axis, iterator in enumerate(iters):
            try:
                values.append(next(iterator))
            except StopIteration:
                failures.append(axis)
        # Only yield complete rows (every iterator produced a value)
        if len(failures) == 0:
            yield tuple(values)
    # All iterators ending together (len(failures) == len(iters)) is success;
    # anything in between means mismatched lengths
    if len(failures) == 1:
        raise IndexError("Too few elements ({}) along axis {}".format(n_elements, failures[0]))
    elif len(failures) < len(iters):
        raise IndexError("Too few elements ({}) along axes {}".format(n_elements, failures))
||
| 465 | |||
| 466 | |||
def only(sequence: Iterable[Any]) -> Any:
    """
    Returns either the SINGLE (ONLY) UNIQUE ITEM in the sequence or raises an exception.
    Each item must have __hash__ defined on it.
    :param sequence: A list of any items (untyped)
    :return: The first item the sequence.
    :raises: LookupFailedError If the sequence is empty
    :raises: MultipleMatchesError If there is more than one unique item.
    """
    st = set(sequence)
    if len(st) > 1:
        # fixed typo in the error message: 'then' -> 'than'
        raise MultipleMatchesError("More than 1 item in {}".format(sequence))
    if len(st) == 0:
        raise LookupFailedError("Empty sequence")
    return next(iter(st))
||
| 482 | |||
| 483 | |||
def read_lines_file(path: str, ignore_comments: bool = False) -> Sequence[str]:
    """
    Returns a list of lines in a file, potentially ignoring comments.
    :param path: Read the file at this local path
    :param ignore_comments: Ignore lines beginning with #, excluding whitespace
    :return: The lines, with surrounding whitespace stripped
    """
    # fixed: the original called readline() exactly once and so returned at
    # most the first line of the file; iterate over every line instead
    lines = []
    with open(path) as f:
        for raw in f:
            line = raw.strip()
            if ignore_comments and line.startswith('#'):
                continue
            lines.append(line)
    return lines
||
| 497 | |||
| 498 | |||
def read_properties_file(path: str) -> Mapping[str, str]:
    """
    Reads a .properties file, which is a list of lines with key=value pairs (with an equals sign).
    Lines beginning with # are ignored.
    Each line must contain exactly 1 equals sign.
    :param path: Read the file at this local path
    :return: A dict mapping keys to values, both with surrounding whitespace stripped
    """
    dct = {}
    for i, line in enumerate(read_lines_file(path, ignore_comments=False)):
        if line.startswith('#'):
            continue
        if line.count('=') != 1:
            raise ParsingError("Bad line {} in {}".format(i+1, path))
        key, value = line.split('=')
        dct[key.strip()] = value.strip()
    return dct
||
| 516 | |||
| 517 | |||
class Comparable:
    """Mixin deriving all rich comparisons from a subclass-provided __lt__.
    Credit to Alex Martelli: https://stackoverflow.com/questions/1061283/lt-instead-of-cmp
    """

    def __eq__(self, other):
        # Equal iff neither orders before the other
        return not (self < other) and not (other < self)

    def __ne__(self, other):
        return (self < other) or (other < self)

    def __gt__(self, other):
        return other < self

    def __ge__(self, other):
        return not (self < other)

    def __le__(self, other):
        return not (other < self)
||
| 535 | |||
| 536 | |||
def json_serial(obj):
    """JSON serializer for objects not serializable by default json code.
    From jgbarah at https://stackoverflow.com/questions/11875770/how-to-overcome-datetime-datetime-not-json-serializable
    """
    # fixed: the name `datetime` is rebound to the module by a later
    # `import datetime`, so isinstance(obj, (datetime, date)) raised
    # "isinstance() arg 2 must be a type"; use the classes explicitly
    if isinstance(obj, (datetime.datetime, datetime.date)):
        return obj.isoformat()
    try:
        import peewee
        if isinstance(obj, peewee.Field):
            return type(obj).__name__
    except ImportError: pass
    raise TypeError("Type %s not serializable" % type(obj))
||
| 549 | |||
def pretty_dict(dct: dict) -> str:
    """Returns a pretty-printed dict, complete with indentation. Will fail on non-JSON-serializable datatypes."""
    return json.dumps(dct, indent=4, sort_keys=True, default=json_serial)
||
| 553 | |||
def pp_dict(dct: dict) -> None:
    """Pretty-prints a dict to stdout."""
    rendered = pretty_dict(dct)
    print(rendered)
||
| 557 | |||
def pp_size(obj: object) -> None:
    """Prints to stdout a human-readable string of the memory usage of arbitrary Python objects. Ex: 8M for 8 megabytes.
    Note: sys.getsizeof is shallow — nested contents are not counted.
    """
    print(hsize(sys.getsizeof(obj)))
||
| 561 | |||
def sanitize_str(value: str) -> str:
    """Removes Unicode control (Cc) characters EXCEPT for tabs (\t), newlines (\n only), line separators (U+2028) and paragraph separators (U+2029)."""
    # fixed: the old condition ANDed the two tests, which *removed* the very
    # characters the docstring promises to keep; keep a character if it is one
    # of the whitelisted separators OR not a control character at all
    keep = {'\t', '\n', '\u2028', '\u2029'}
    return "".join(ch for ch in value if ch in keep or unicodedata.category(ch) != 'Cc')
||
| 565 | |||
def escape_for_properties(value: Any) -> str:
    """Makes a value safe for a one-line .properties entry: newlines become U+2028, then the result is sanitized."""
    one_line = str(value).replace('\n', '\u2028')
    return sanitize_str(one_line)
||
| 568 | |||
def escape_for_tsv(value: Any) -> str:
    """Makes a value safe for a TSV cell: newlines become U+2028, tabs become spaces, then the result is sanitized."""
    flattened = str(value).replace('\n', '\u2028').replace('\t', ' ')
    return sanitize_str(flattened)
||
| 571 | |||
class Timeout:
    """Context manager that raises TimeoutError if its body runs longer than `seconds`.
    Implemented via SIGALRM, so it is Unix-only and must run in the main thread.
    """
    def __init__(self, seconds: int = 10, error_message='Timeout'):
        self.seconds = seconds
        self.error_message = error_message
    def handle_timeout(self, signum, frame):
        # Signal handler installed by __enter__
        raise TimeoutError(self.error_message)
    def __enter__(self):
        signal.signal(signal.SIGALRM, self.handle_timeout)
        signal.alarm(self.seconds)
    def __exit__(self, type, value, traceback):
        # Cancel any pending alarm.
        # NOTE(review): the previous SIGALRM handler is not restored — confirm acceptable.
        signal.alarm(0)
||
| 583 | |||
| 584 | |||
class OverwriteChoice(Enum):
    """Policy options for what to do when a target already exists."""
    FAIL = 1
    WARN = 2
    IGNORE = 3
    OVERWRITE = 4
||
| 590 | |||
def fix_path(path: str) -> str:
    """Normalizes a path for use with external tools (notably ffmpeg).
    - Rejects paths containing '%' (ffmpeg treats it as a pattern).
    - Expands a leading '~/' (or a bare '~') to $HOME, which Python won't do.
    - Strips genuine './' current-dir segments, which ffmpeg silently fails on.
    """
    # ffmpeg won't recognize './' and will simply not write images!
    # and Python doesn't recognize ~
    if '%' in path: raise ValueError(
        'For technical limitations (regarding ffmpeg), local paths cannot contain a percent sign (%), but "{}" does'.format(path)
    )
    if path == '~': return os.environ['HOME']  # prevent out of bounds
    if path.startswith('~/'):
        # fixed: the old startswith('~') + path[2:] mangled names like '~foo'
        path = os.path.join(os.environ['HOME'], path[2:])
    # fixed: a blanket replace('./', '') corrupted names like 'a./b';
    # only remove actual current-directory segments
    while path.startswith('./'):
        path = path[2:]
    return path.replace('/./', '/')
||
| 601 | |||
def fix_path_platform_dependent(path: str) -> str:
    """Modifies path strings to work with Python and external tools.
    Replaces a beginning '~' with the HOME environment variable.
    Also accepts either / or \ (but not both) as a path separator in windows.
    """
    path = fix_path(path)
    # if windows, allow either / or \, but not both
    if platform.system() == 'Windows':
        bits = re.split('[/\\\\]', path)
        # NOTE(review): the replace below turns a drive 'C:' into 'C:\' after
        # joining, but would also alter any later colon in the path — confirm intended
        return pjoin(*bits).replace(":", ":\\")
    else:
        return path
||
| 614 | |||
| 615 | |||
| 616 | # NTFS doesn't allow these, so let's be safe |
||
| 617 | # Also exclude control characters |
||
| 618 | # 127 is the DEL char |
||
| 619 | _bad_chars = {'/', ':', '<', '>', '"', "'", '\\', '|', '?', '*', chr(127), *{chr(i) for i in range(0, 32)}} |
||
| 620 | assert ' ' not in _bad_chars |
||
| 621 | def _sanitize_bit(p: str) -> str: |
||
| 622 | for b in _bad_chars: p = p.replace(b, '-') |
||
| 623 | return p |
||
def pjoin_sanitized_rel(*pieces: Iterable[any]) -> str:
    """Builds a relative path from a hierarchy, sanitizing each component.
    Replaces /, :, <, >, ", ', \\, |, ?, *, <DEL>, and control characters 0-32
    with a hyphen-minus (-). Each input must refer to a single directory or
    file (no path separators); use pjoin_sanitized_abs for absolute paths.
    """
    cleaned = [_sanitize_bit(str(piece)) for piece in pieces]
    return pjoin(*cleaned)
||
def pjoin_sanitized_abs(*pieces: Iterable[any]) -> str:
    """Same as pjoin_sanitized_rel, but rooted at os.sep (the root directory)."""
    return pjoin(os.sep, pjoin_sanitized_rel(*pieces))
||
| 633 | |||
| 634 | |||
def make_dirs(output_dir: str):
    """Creates output_dir (with any missing parents) if it doesn't exist.
    Raises PathIsNotADirError if the path exists but is not a directory.
    """
    if os.path.isdir(output_dir):
        return  # already there; nothing to do
    if os.path.exists(output_dir):
        raise PathIsNotADirError("{} already exists and is not a directory".format(output_dir))
    os.makedirs(output_dir)
||
| 643 | |||
| 644 | |||
def remake_dirs(output_dir: str):
    """Creates output_dir, first deleting it if it already exists.
    Raises PathIsNotADirError if the existing path is not a directory.
    """
    if os.path.exists(output_dir):
        if os.path.isdir(output_dir):
            shutil.rmtree(output_dir)
        else:
            raise PathIsNotADirError("{} already exists and is not a directory".format(output_dir))
    make_dirs(output_dir)
||
| 654 | |||
| 655 | |||
| 656 | |||
def lines(file_name: str, known_encoding='utf-8') -> Iterator[str]:
    """Lazily read a text file or gzipped text file, decode, and strip any newline character (\n or \r).
    If the file name ends with '.gz' or '.gzip', assumes the file is Gzipped.
    Arguments:
        known_encoding: Applied only when decoding gzip
    """
    if file_name.endswith(('.gz', '.gzip')):
        handle = io.TextIOWrapper(gzip.open(file_name, 'r'), encoding=known_encoding)
    else:
        handle = open(file_name, 'r')
    with handle as f:
        for line in f:
            yield line.rstrip('\n\r')
||
| 669 | |||
| 670 | import dill |
||
| 671 | |||
def pkl(data, path: str):
    """Serializes `data` to `path` using dill (a binary pickle format)."""
    with open(path, 'wb') as f:
        dill.dump(data, f)
||
| 675 | |||
def unpkl(path: str):
    """Deserializes a dill pickle from `path`.
    WARNING: unpickling executes arbitrary code — only use on trusted files.
    """
    with open(path, 'rb') as f:
        return dill.load(f)
||
| 679 | |||
| 680 | |||
def file_from_env_var(var: str) -> str:
    """
    Just returns the path of a file specified in an environment variable, checking that it's a file.
    Will raise a MissingResourceError error if not set or not a file.
    :param var: The environment variable name, not including the $
    """
    try:
        raw = os.environ[var]
    except KeyError:
        raise MissingResourceError('Environment variable ${} is not set'.format(var))
    config_file_path = fix_path(raw)
    if not pexists(config_file_path):
        raise MissingResourceError("{} file {} does not exist".format(var, config_file_path))
    if not pfile(config_file_path):
        raise MissingResourceError("{} file {} is not a file".format(var, config_file_path))
    return config_file_path
||
| 695 | |||
| 696 | |||
def is_proper_file(path: str) -> bool:
    """True iff the basename is nonempty and does not start with '.', '~', or '_'."""
    name = os.path.basename(path)
    return bool(name) and not name.startswith(('.', '~', '_'))
||
| 700 | |||
| 701 | |||
def scantree(path: str, follow_symlinks: bool=False) -> Iterator[str]:
    """List the full path of every file not beginning with '.', '~', or '_' in a directory, recursively.
    .. deprecated Use scan_for_proper_files, which has a better name
    """
    for entry in os.scandir(path):
        if entry.is_dir(follow_symlinks=follow_symlinks):
            # fixed: the recursive call dropped follow_symlinks, silently
            # reverting to the default for every subdirectory
            yield from scantree(entry.path, follow_symlinks)
        elif is_proper_file(entry.path):
            yield entry.path
||
| 711 | |||
| 712 | scan_for_proper_files = scantree |
||
| 713 | |||
| 714 | |||
def scan_for_files(path: str, follow_symlinks: bool = False) -> Iterator[str]:
    """
    Using a generator, list all files in a directory or one of its subdirectories.
    Useful for iterating over files in a directory recursively if there are thousands of file.
    Warning: If there are looping symlinks, follow_symlinks will return an infinite generator.

    :param path: The directory to scan
    :param follow_symlinks: Whether to descend into directories that are symlinks
    """
    for entry in os.scandir(path):
        if entry.is_dir(follow_symlinks=follow_symlinks):
            # Bug fix: pass follow_symlinks through; the recursive call
            # previously dropped it, so only the top level honored the flag.
            yield from scan_for_files(entry.path, follow_symlinks)
        else:
            yield entry.path
||
| 726 | |||
| 727 | |||
def walk_until(some_dir, until: Callable[[str], bool]) -> Iterator[typing.Tuple[str, str, str]]:
    """Walk but stop recursing after 'until' occurs.
    Returns files and directories in the same manner as os.walk
    """
    some_dir = some_dir.rstrip(os.path.sep)
    assert os.path.isdir(some_dir)
    for root, subdirs, filenames in os.walk(some_dir):
        yield root, subdirs, filenames
        if until(root):
            # Clearing the list in place tells os.walk not to recurse further.
            subdirs.clear()
||
| 738 | |||
| 739 | |||
def walk_until_level(some_dir, level: Optional[int] = None) -> Iterator[typing.Tuple[str, str, str]]:
    """
    Walk up to a maximum recursion depth.
    Returns files and directories in the same manner as os.walk
    Taken partly from https://stackoverflow.com/questions/7159607/list-directories-with-a-specified-depth-in-python
    :param some_dir:
    :param level: Maximum recursion depth, starting at 0; None means unlimited
    """
    some_dir = some_dir.rstrip(os.path.sep)
    assert os.path.isdir(some_dir)
    num_sep = some_dir.count(os.path.sep)
    for root, dirs, files in os.walk(some_dir):
        yield root, dirs, files
        num_sep_this = root.count(os.path.sep)
        # Bug fix: the original condition was `level is None or ...`, which pruned
        # ALL recursion when level was None (the default) — the opposite of the
        # documented "no limit" semantics.
        if level is not None and num_sep + level <= num_sep_this:
            del dirs[:]
||
| 756 | |||
| 757 | |||
class SubcommandHandler:
    """A convenient wrapper for a program that uses command-line subcommands.
    Calls any public method that belongs to the target
    :param parser: Should contain a description and help text, but should NOT contain any arguments.
    :param target: An object (or type) that contains a method for each subcommand; a dash (-) in the argument is converted to an underscore.
    :param temp_dir: A temporary directory
    :param error_handler: Called logging any exception except for KeyboardInterrupt or SystemExit (exceptions in here are ignored)
    :param cancel_handler: Called after logging a KeyboardInterrupt or SystemExit (exceptions in here are ignored)
    """

    def __init__(self,
            parser: argparse.ArgumentParser, target: Any,
            temp_dir: Optional[str] = None,
            error_handler: Callable[[BaseException], None] = lambda e: None,
            cancel_handler: Callable[[Union[KeyboardInterrupt, SystemExit]], None] = lambda e: None
    ) -> None:
        parser.add_argument('subcommand', help='Subcommand to run')
        self.parser = parser
        self.target = target
        self.temp_dir = temp_dir
        self.error_handler = error_handler
        self.cancel_handler = cancel_handler

    def run(self, args: List[str]) -> None:
        """Parse the subcommand from ``args`` and dispatch to the matching method on ``target``.

        Creates ``temp_dir`` (if set) before dispatch and removes it afterwards.
        Re-raises anything raised by the subcommand after logging and invoking the
        appropriate handler.

        :param args: argv-style list; args[1] is the subcommand name
        """
        full_args = self.parser.parse_args(args[1:2])
        subcommand = full_args.subcommand.replace('-', '_')
        # Bug fix: the original guard used `and`, which let underscore-prefixed
        # (private) names slip past and then fail inside getattr; reject a
        # subcommand when it is missing OR private.
        if not hasattr(self.target, subcommand) or subcommand.startswith('_'):
            print(Fore.RED + 'Unrecognized subcommand {}'.format(subcommand))
            self.parser.print_help()
            return
        # clever; from Chase Seibert: https://chase-seibert.github.io/blog/2014/03/21/python-multilevel-argparse.html
        # use dispatch pattern to invoke method with same name
        try:
            if self.temp_dir is not None:
                # Start from a clean temp dir; refuse to clobber a non-directory.
                if pexists(self.temp_dir) and pdir(self.temp_dir): shutil.rmtree(self.temp_dir)
                elif pexists(self.temp_dir): raise PathIsNotADirError(self.temp_dir)
                remake_dirs(self.temp_dir)
                logger.debug("Created temp dir at {}".format(self.temp_dir))
            getattr(self.target, subcommand)()
        except NaturalExpectedError:
            pass  # ignore totally
        except KeyboardInterrupt as e:
            try:
                logger.fatal("Received cancellation signal", exc_info=True)
                self.cancel_handler(e)
            except BaseException: pass
            raise e
        except SystemExit as e:
            try:
                logger.fatal("Received system exit signal", exc_info=True)
                self.cancel_handler(e)
            except BaseException: pass
            raise e
        except BaseException as e:
            try:
                logger.fatal("{} failed!".format(self.parser.prog), exc_info=True)
                self.error_handler(e)
            except BaseException: pass
            raise e
        finally:
            if self.temp_dir is not None and pexists(self.temp_dir):
                shutil.rmtree(self.temp_dir)
                # Bug fix: log AFTER the deletion actually succeeds, not before.
                logger.debug("Deleted temp dir at {}".format(self.temp_dir))
                try:
                    os.remove(self.temp_dir)
                except IOError: pass
||
| 828 | |||
| 830 |