1
|
|
|
"""Filters for usage in pathfinding.""" |
2
|
|
|
|
3
|
1 |
|
from dataclasses import dataclass |
4
|
1 |
|
from typing import Any, Callable, Iterable, Union |
5
|
|
|
|
6
|
1 |
|
# A single edge: (source node, target node, attribute dict).
# NOTE(review): the str node types are taken from the annotation; confirm they
# match what the graph actually yields.
EdgeData = tuple[str, str, dict]

# Any iterable of edges; EdgeFilter consumes and returns this.
EdgeList = Iterable[EdgeData]

# Extracts the value to compare from an edge. It is called as
# ``extractor(edge, value)`` where ``value`` is the (preprocessed) filter value.
EdgeDataExtractor = Callable[[EdgeData, Any], Any]
10
|
|
|
|
11
|
1 |
|
@dataclass
class UseValIfNone:
    """Wrap an edge extractor so that a ``None`` result falls back to the caller's value.

    ``wrapper(edge, val)`` runs ``processor(edge)``. If that yields ``None``,
    ``val`` is returned; otherwise the processor's result is returned unchanged.
    Falsy results other than ``None`` (e.g. ``0``) are kept.
    """

    # Single-argument extractor applied to the edge; None means "no value".
    processor: Callable[[EdgeData], Any]

    def __call__(self, edge: EdgeData, val: Any):
        extracted = self.processor(edge)
        return val if extracted is None else extracted
20
|
|
|
|
21
|
1 |
|
@dataclass
class UseDefaultIfNone:
    """Wrap an edge extractor so that a ``None`` result is replaced by a fixed default.

    The wrapper accepts the two-argument ``(edge, value)`` extractor call shape
    but ignores the second argument. Whenever ``processor(edge)`` yields
    ``None``, ``default`` is returned instead.
    """

    # Single-argument extractor applied to the edge.
    processor: Callable[[EdgeData], Any]
    # Value substituted when ``processor`` returns None.
    default: Any

    def __call__(self, edge: EdgeData, _):
        extracted = self.processor(edge)
        if extracted is not None:
            return extracted
        return self.default
31
|
|
|
|
32
|
1 |
|
@dataclass
class ProcessEdgeAttribute:
    """Read one attribute from an edge's data dict, optionally transforming it.

    Called as ``extractor(edge)`` where ``edge`` is ``(u, v, data)``. It looks
    up ``data.get(attribute)``, which is ``None`` when the key is missing. If a
    ``processor`` is set, the looked-up value is passed through it.
    """

    # Key looked up in the edge's data dict (``edge[2]``).
    attribute: str
    # Optional transform for the looked-up value. It is applied even when the
    # attribute is missing, so it may receive None.
    processor: Union[Callable[[Any], Any], None] = None

    def __call__(self, edge: EdgeData):
        result = edge[2].get(self.attribute)
        # Compare against None explicitly: a callable object that happens to be
        # falsy (e.g. defines __bool__/__len__) must still be applied.
        if self.processor is not None:
            return self.processor(result)
        return result
42
|
|
|
|
43
|
1 |
|
class EdgeFilter:
    """Keep the edges whose extracted value compares favourably with a query value.

    Args:
        comparator: Called as ``comparator(edge_value, value)``. Edges for which
            it returns a truthy result are kept.
        get_func: Either an extractor called as ``get_func(edge, value)``, or the
            name of an edge attribute. A name is wrapped so the attribute is
            read from the edge's data dict. When the attribute is missing (or
            None), the query value itself is used, so the comparator then sees
            ``(value, value)``.
        preprocessor: Optional callable applied to the query value once per
            call, before any edge is compared.
    """

    def __init__(
        self,
        comparator: Callable[[Any, Any], bool],
        get_func: Union[EdgeDataExtractor, str],
        preprocessor: Union[Callable[[Any], Any], None] = None,
    ):
        self.comparator = comparator

        if isinstance(get_func, str):
            # Attribute name: read it from the edge data, and fall back to the
            # query value when the attribute is absent.
            get_func = UseValIfNone(ProcessEdgeAttribute(get_func))

        self.get_func = get_func

        self.preprocessor = preprocessor

    def __call__(
        self,
        value: Any,
        items: EdgeList,
    ) -> EdgeList:
        """Apply the filter given the items and value.

        Preprocessing happens eagerly. The returned object is a lazy ``filter``
        iterator: edges are extracted and compared only as it is consumed, and
        it can be iterated only once.
        """

        # Preprocess the value. The explicit None check keeps falsy callable
        # objects from being silently skipped.
        if self.preprocessor is not None:
            value = self.preprocessor(value)

        # Run filter
        return filter(
            lambda edge: self.comparator(self.get_func(edge, value), value),
            items,
        )
80
|
|
|
|
81
|
1 |
|
@dataclass
class TypeCheckPreprocessor:
    """Validate a value's type, then optionally transform it.

    Raises:
        TypeError: if the value is not an instance of ``types``.
    """

    # Accepted type, or tuple of accepted types, as passed to ``isinstance``.
    types: Union[type, tuple[type, ...]]
    # Optional transform applied to values that pass the type check.
    preprocessor: Union[Callable[[Any], Any], None] = None

    def __call__(
        self,
        value: Any,
    ):
        if not isinstance(value, self.types):
            raise TypeError(f"Expected type: {self.types}")
        # Explicit None check so that falsy callable objects are still applied.
        if self.preprocessor is not None:
            return self.preprocessor(value)
        return value
95
|
|
|
|
96
|
1 |
|
@dataclass
class TypeDifferentiatedProcessor:
    """Dispatch a value to a processor chosen by the value's type.

    ``preprocessors`` maps a type (or tuple of types, as accepted by
    ``isinstance``) to a processor. Entries are tried in the dict's iteration
    (insertion) order, and the first matching entry wins. A ``None`` processor
    passes the value through unchanged.

    Raises:
        TypeError: if no entry matches the value's type, including when no
            mapping was supplied.
    """

    preprocessors: Union[
        dict[Union[type, tuple[type, ...]], Union[Callable[[Any], Any], None]],
        None,
    ] = None

    def __call__(
        self,
        value: Any,
    ):
        # The field defaults to None. Treat that as an empty mapping so callers
        # get the documented TypeError instead of an AttributeError.
        preprocessors = {} if self.preprocessors is None else self.preprocessors
        for expected_types, processor in preprocessors.items():
            if isinstance(value, expected_types):
                if processor is not None:
                    return processor(value)
                return value
        raise TypeError(f"Expected types: {preprocessors.keys()}")
110
|
|
|
|