|
1
|
|
|
import attr |
|
2
|
|
|
|
|
3
|
|
|
from artificial_artwork.utils.subclass_registry import SubclassRegistry |
|
4
|
|
|
from .termination_condition_interface import TerminationConditionInterface |
|
5
|
|
|
|
|
6
|
|
|
# TODO: learn how to define an abstract class that implements the generic
# interface, and then have the termination conditions inherit from that abstract class.
|
8
|
|
|
# T = TypeVar('T') |
|
9
|
|
|
|
|
10
|
|
|
# class AbstractTerminationCondition(TerminationConditionInterface, Generic[T]): pass |
|
11
|
|
|
|
|
12
|
|
|
|
|
13
|
|
|
class TerminationFactory(metaclass=SubclassRegistry):
    """Registry class for termination conditions.

    The class body is intentionally empty: everything it offers comes from
    the SubclassRegistry metaclass. Concrete conditions in this module
    attach themselves to it under a string key through the
    ``TerminationFactory.register_as_subclass(<key>)`` decorator.
    """
|
15
|
|
|
|
|
16
|
|
|
|
|
17
|
|
|
@attr.s
@TerminationFactory.register_as_subclass('max-iterations')
class MaxIterations(TerminationConditionInterface[int]):
    """Termination condition driven by an iteration count.

    Registered in TerminationFactory under the key 'max-iterations'.
    """

    # Iteration threshold; the condition holds once it is reached or exceeded.
    max_iterations: int = attr.ib()

    def satisfied(self, iterations: int) -> bool:
        """Tell whether the given iteration count has reached the threshold.

        Args:
            iterations: the iteration count to check against the threshold.

        Returns:
            True when ``max_iterations`` is less than or equal to
            ``iterations``, False otherwise.
        """
        threshold = self.max_iterations
        limit_reached = threshold <= iterations
        return limit_reached
|
24
|
|
|
|
|
25
|
|
|
@attr.s
@TerminationFactory.register_as_subclass('convergence')
class Convergence(TerminationConditionInterface[float]):
    """Termination condition driven by the most recent loss improvement.

    Registered in TerminationFactory under the key 'convergence'.
    """

    # Improvement threshold; anything strictly below it satisfies the condition.
    min_improvement: float = attr.ib()

    def satisfied(self, last_loss_improvement: float) -> bool:
        """Tell whether the latest improvement fell below the threshold.

        Args:
            last_loss_improvement: the improvement value to check.

        Returns:
            True when ``last_loss_improvement`` is strictly less than
            ``min_improvement``, False otherwise (including when equal).
        """
        threshold = self.min_improvement
        improvement_too_small = last_loss_improvement < threshold
        return improvement_too_small
|
32
|
|
|
|
|
33
|
|
|
@attr.s
@TerminationFactory.register_as_subclass('time-limit')
class TimeLimit(TerminationConditionInterface[float]):
    """Termination condition driven by elapsed duration.

    Registered in TerminationFactory under the key 'time-limit'.
    """

    # Duration threshold, expressed in whatever unit callers use for `duration`.
    time_limit: float = attr.ib()

    def satisfied(self, duration: float) -> bool:
        """Tell whether the given duration has reached the time limit.

        Args:
            duration: the elapsed duration to check against the limit.

        Returns:
            True when ``time_limit`` is less than or equal to ``duration``,
            False otherwise.
        """
        threshold = self.time_limit
        limit_reached = threshold <= duration
        return limit_reached
|
40
|
|
|
|
|
41
|
|
|
|
|
42
|
|
|
class TerminationConditionFacility:
    """Entry point for building termination conditions by their string key."""

    # Registry that construction requests are delegated to.
    class_registry: SubclassRegistry = TerminationFactory

    @classmethod
    def create(cls, termination_condition_type: str, *args, **kwargs) -> TerminationConditionInterface:
        """Build a termination condition through the class registry.

        Args:
            termination_condition_type: key identifying the condition to
                build (this module registers 'max-iterations',
                'convergence' and 'time-limit').
            *args: positional arguments passed through to the registry.
            **kwargs: keyword arguments passed through to the registry.

        Returns:
            Whatever ``class_registry.create`` returns for the given key
            and arguments.
        """
        registry = cls.class_registry
        return registry.create(termination_condition_type, *args, **kwargs)
|
48
|
|
|
|