|
1
|
|
|
from abc import ABC |
|
2
|
|
|
from typing import Callable, Protocol, Dict, Any |
|
3
|
|
|
from types import MethodType |
|
4
|
|
|
from .termination_condition.termination_condition_interface import TerminationConditionInterface |
|
5
|
|
|
|
|
6
|
|
|
from .utils.memoize import ObjectsPool |
|
7
|
|
|
|
|
8
|
|
|
|
|
9
|
|
|
class MetricsCapable(Protocol):
    """Structural type for any object exposing a ``metrics`` dict keyed by string."""

    metrics: Dict[str, Any]
|
11
|
|
|
|
|
12
|
|
|
# Only the instance-level factory is exported as this module's public API.
__all__ = ["TerminationConditionAdapterFactory"]
|
13
|
|
|
|
|
14
|
|
|
|
|
15
|
|
|
class AbstractTerminationConditionAdapter(ABC):
    """Base class bridging a termination condition and a metrics-producing subject.

    Concrete subclasses (e.g. those built by ``TerminationConditionAdapterType``)
    must provide two class attributes:

    * ``adapter_type`` -- key selecting this adapter's entry in ``mapping``;
    * ``mapping`` -- dict of adapter type -> ``{'key_name': ..., 'state': ...}``,
      where ``'key_name'`` names the metric to read from the subject and
      ``'state'`` is the initial value of ``runtime_state``.

    Instances expose:

    * ``termination_condition`` -- the wrapped condition object;
    * ``runtime_state`` -- the latest metric value read (initially ``'state'``);
    * ``update(subject)`` -- bound method refreshing ``runtime_state``;
    * ``satisfied`` -- ``termination_condition.satisfied(runtime_state)``.
    """

    termination_condition: TerminationConditionInterface
    update: Callable[[MetricsCapable], None]

    def __new__(cls, termination_condition, *args, **kwargs):
        """Create an adapter wrapping ``termination_condition``.

        Extra positional/keyword arguments are accepted (so callers may pass
        them through a factory) but are otherwise ignored.

        Raises:
            AttributeError: if ``cls`` has no ``adapter_type`` / ``mapping``
                (e.g. when this base class is instantiated directly).
            KeyError: if ``cls.adapter_type`` is not a key of ``cls.mapping``.
        """
        # ``object.__new__`` raises TypeError on any extra argument once
        # ``__new__`` is overridden, so the extras must not be forwarded to it.
        instance = super().__new__(cls)
        instance.termination_condition = termination_condition
        # Bind the type-specific update function to this instance as a method.
        instance.update = MethodType(cls._update_callback(cls.adapter_type), instance)
        instance.runtime_state = cls._initial_state_callback(cls.adapter_type)()
        return instance

    def __init__(self, *args, **kwargs):
        """Accept the constructor arguments; all initialisation happens in ``__new__``."""

    # In the two helpers below the parameter ``type`` shadows the builtin; the
    # name is kept so existing keyword callers keep working.
    @classmethod
    def _update_callback(cls, type: str):
        """Return an unbound ``update(self, subject, ...)`` function for adapter ``type``.

        The returned function sets ``self.runtime_state`` to
        ``subject.state.metrics[cls.mapping[type]['key_name']]``. The key name
        is resolved on every call, not when the callback is created.
        """
        # NOTE(review): this reads ``subject.state.metrics`` while the ``update``
        # annotation names ``MetricsCapable`` (which declares ``metrics``
        # directly) -- confirm which shape callers actually pass.
        def update(self, subject, *args, **kwargs) -> None:
            metrics = subject.state.metrics
            self.runtime_state = metrics[cls.mapping[type]['key_name']]

        return update

    @classmethod
    def _initial_state_callback(cls, type: str):
        """Return a zero-argument callable yielding ``cls.mapping[type]['state']``."""
        def get_initial_state():
            return cls.mapping[type]['state']

        return get_initial_state

    @property
    def satisfied(self):
        """Return ``termination_condition.satisfied(runtime_state)``."""
        return self.termination_condition.satisfied(self.runtime_state)
|
47
|
|
|
|
|
48
|
|
|
|
|
49
|
|
|
# Define Metaclass
class TerminationConditionAdapterType(type):
    """Class maker for concrete termination-condition adapters.

    ``TerminationConditionAdapterType(adapter_type)`` returns a new subclass of
    ``AbstractTerminationConditionAdapter`` named ``'TerminationConditionAdapter'``
    that carries two class attributes:

    * ``adapter_type`` -- the first positional argument;
    * ``mapping`` -- a fresh dict describing every supported adapter type.

    The class is built with plain ``type(...)`` rather than
    ``super().__new__(mcs, ...)``. Because this class derives from ``type``
    rather than ``ABCMeta``, the latter would raise a metaclass conflict with
    ``ABC``. As a result, the returned class is an instance of ``ABCMeta``,
    not of this class, and ``__init__`` is never invoked on it.
    Extra positional and keyword arguments are ignored.
    """

    def __new__(mcs, *args, **kwargs):
        requested_type = args[0]

        # Outer keys: adapter types client code may request.
        # 'key_name': metric name to read from the subject's metrics dict.
        # 'state':    initial value of the adapter's ``runtime_state``.
        supported = {
            'max-iterations': {'key_name': 'iterations', 'state': 0},
            'convergence': {'key_name': 'cost', 'state': float('inf')},
            'time-limit': {'key_name': 'duration', 'state': 0},
        }

        namespace = {
            'adapter_type': requested_type,
            'mapping': supported,
        }
        bases = (AbstractTerminationConditionAdapter,)
        return type('TerminationConditionAdapter', bases, namespace)
|
66
|
|
|
|
|
67
|
|
|
# TODO: investigate whether the metaclass behaviour above can be replicated with __init__ instead of __new__.
|
68
|
|
|
|
|
69
|
|
|
class TerminationConditionAdapterClassFactory:
    """Proxy to the adapter 'class maker' that hands out memoized classes.

    Classes come from ``TerminationConditionAdapterType`` via an
    ``ObjectsPool``. Caching behaviour is whatever ``ObjectsPool.get_object``
    provides.
    """

    classes_pool = ObjectsPool.new_empty(TerminationConditionAdapterType)

    @classmethod
    def create(cls, adapter_type: str):
        """Return the adapter class registered in the pool for ``adapter_type``."""
        pool = cls.classes_pool
        return pool.get_object(adapter_type)
|
76
|
|
|
|
|
77
|
|
|
|
|
78
|
|
|
class TerminationConditionAdapterFactory:
    """Public entry point that instantiates a termination-condition adapter by type name."""

    @classmethod
    def create(cls, adapter_type: str, *args, **kwargs):
        """Build an adapter instance for ``adapter_type``.

        Args:
            adapter_type: adapter-type key used to obtain the adapter class from
                ``TerminationConditionAdapterClassFactory``.
            *args, **kwargs: forwarded unchanged to the adapter class
                constructor (presumably starting with the termination
                condition object -- confirm against the adapter base class).

        Returns:
            A new instance of the adapter class for ``adapter_type``.
        """
        adapter_class = TerminationConditionAdapterClassFactory.create(adapter_type)
        adapter = adapter_class(*args, **kwargs)
        return adapter
|
84
|
|
|
|