1
|
|
|
"""Filters for usage in pathfinding.""" |
2
|
|
|
|
3
|
1 |
|
from dataclasses import dataclass |
4
|
1 |
|
from typing import Any, Callable, Iterable, Union |
5
|
|
|
|
6
|
1 |
|
# An edge as a triple: two endpoints (annotated as str) and the edge's
# attribute dict. ProcessEdgeAttribute reads ``edge[2]`` and reports
# ``edge[0]``/``edge[1]`` in its error messages.
EdgeData = tuple[str, str, dict]
# Any iterable of edges; EdgeFilter takes one and returns a filtered iterable.
EdgeList = Iterable[EdgeData]

# Callable ``(edge, value) -> quantity``: EdgeFilter calls it with each edge and
# the (preprocessed) query value, then compares the result against that value.
EdgeDataExtractor = Callable[[EdgeData, Any], Any]
10
|
|
|
|
11
|
1 |
|
@dataclass
class UseValIfNone:
    """Wrap an edge processor so a ``None`` result falls back to the given value.

    Calling an instance as ``(edge, val)`` evaluates ``processor(edge)``.
    Whenever that evaluates to ``None``, ``val`` is returned in its place.
    """

    processor: Callable[[EdgeData], Any]

    def __call__(self, edge: EdgeData, val: Any):
        extracted = self.processor(edge)
        return val if extracted is None else extracted
20
|
|
|
|
21
|
1 |
|
@dataclass
class UseDefaultIfNone:
    """Wrap an edge processor so a ``None`` result is replaced by a fixed default.

    The second call argument is accepted only so instances match the
    ``(edge, value)`` extractor call shape; it is ignored.
    """

    processor: Callable[[EdgeData], Any]
    default: Any

    def __call__(self, edge: EdgeData, _):
        extracted = self.processor(edge)
        if extracted is not None:
            return extracted
        return self.default
31
|
|
|
|
32
|
1 |
|
@dataclass
class ProcessEdgeAttribute:
    """Read one attribute from an edge's data dict, optionally transforming it.

    The raw value is ``edge[2].get(attribute)``, so a missing attribute reads
    as ``None``. If ``processor`` is set, the raw value is passed through it.
    This happens even when the value is ``None``. A ``TypeError`` raised by the
    processor is re-raised with the edge endpoints and attribute name attached.
    """

    attribute: str
    processor: Callable[[Any], Any] = None

    def __call__(self, edge: EdgeData):
        raw = edge[2].get(self.attribute)
        if not self.processor:
            return raw
        try:
            return self.processor(raw)
        except TypeError as err:
            raise TypeError(f'({edge[0]}, {edge[1]}) has invalid metadata "{self.attribute}" with value {raw!r}: {err}') from err
45
|
|
|
|
46
|
1 |
|
class EdgeFilter:
    """Filter edges by comparing a per-edge quantity against a query value.

    Args:
        comparator: Called as ``comparator(edge_quantity, value)``; edges for
            which it returns a truthy result are kept.
        get_func: Either an extractor ``(edge, value) -> quantity`` or the name
            of an edge attribute. A name is wrapped as
            ``UseValIfNone(ProcessEdgeAttribute(name))``. An edge whose
            attribute is missing (or ``None``) therefore yields the query value
            itself and is compared as ``comparator(value, value)``.
        preprocessor: Optional callable applied once to the query value before
            filtering, e.g. to validate or convert it.
    """

    def __init__(
        self,
        comparator: Callable[[Any, Any], bool],
        get_func: Union[EdgeDataExtractor, str],
        preprocessor: Callable[[Any], Any] = None
    ):
        self.comparator = comparator
        # A bare string is shorthand for "read this edge attribute".
        self.get_func = (
            UseValIfNone(ProcessEdgeAttribute(get_func))
            if isinstance(get_func, str)
            else get_func
        )
        self.preprocessor = preprocessor

    def __call__(
        self,
        value: Any,
        items: EdgeList
    ) -> EdgeList:
        """Apply the filter given the items and value.

        The preprocessor (if any) runs immediately. The returned ``filter``
        object is lazy: ``get_func`` and ``comparator`` only run as it is
        iterated.
        """
        if self.preprocessor:
            value = self.preprocessor(value)

        def keep(edge):
            return self.comparator(self.get_func(edge, value), value)

        return filter(keep, items)
83
|
|
|
|
84
|
1 |
|
@dataclass
class TypeCheckPreprocessor:
    """Preprocessor that returns its input unchanged after checking its type.

    ``types`` is anything ``isinstance`` accepts as its second argument: a
    single type or a tuple of types.
    """

    # ``tuple[type, ...]`` (any length); ``tuple[type]`` would mean a 1-tuple.
    types: Union[type, tuple[type, ...]]

    def __call__(
        self,
        value: Any
    ):
        """Return *value* unchanged if it is an instance of ``types``.

        Raises:
            TypeError: if *value* is not an instance of ``types``.
        """
        if not isinstance(value, self.types):
            raise TypeError(f"Expected type: {self.types}")
        return value
95
|
|
|
|
96
|
1 |
|
@dataclass
class TypeDifferentiatedProcessor:
    """Dispatch a value to a processor chosen by the value's type.

    ``preprocessors`` maps an ``isinstance`` type spec (a type or a tuple of
    types) to a callable. Entries are tried in dict insertion order and the
    first matching spec wins. A falsy processor (e.g. ``None``) for the
    matching spec means "accept the value unchanged".
    """

    # Default stays ``None`` for backward compatibility. It is treated as an
    # empty mapping at call time, so the call raises the documented TypeError
    # rather than an AttributeError from ``None.items()``.
    preprocessors: Union[dict[Union[type, tuple[type, ...]], Callable[[Any], Any]], None] = None

    def __call__(
        self,
        value: Any
    ):
        """Process *value* with the first processor whose type spec matches.

        Returns:
            ``processor(value)`` for the first matching spec, or *value*
            itself when that spec's processor is falsy.

        Raises:
            TypeError: if no spec matches *value*. This is always the case
                when ``preprocessors`` is ``None`` or empty.
        """
        preprocessors = self.preprocessors or {}
        for expected_types, processor in preprocessors.items():
            if isinstance(value, expected_types):
                if processor:
                    return processor(value)
                return value
        raise TypeError(f"Expected types: {list(preprocessors.keys())}")
110
|
|
|
|