Total Complexity | 52 |
Total Lines | 325 |
Duplicated Lines | 17.23 % |
Changes | 0 |
Duplicate code is one of the most pungent code smells. A common rule of thumb is to restructure code once it is duplicated in three or more places.
Common duplication problems and their corresponding solutions are:
Complex classes like gradient_free_optimizers._stopping_conditions often do many different things. To break such a class down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster.
1 | import time |
||
2 | import logging |
||
3 | from abc import ABC, abstractmethod |
||
4 | from dataclasses import dataclass, field |
||
5 | from typing import List, Optional, Dict, Any |
||
6 | import numpy as np |
||
7 | |||
8 | |||
@dataclass
class StoppingContext:
    """
    Encapsulates all relevant data for stopping condition evaluation.
    This creates a clear contract for what data stopping conditions can access.

    Attributes:
        iteration: Iteration counter reported by the optimizer.
        score_current: Score of the most recent evaluation.
        score_best: Best score reported by the optimizer so far.
        score_history: Recorded scores in evaluation order.
        start_time: Timestamp at which optimization started.
        current_time: Timestamp at which this snapshot was taken.
    """

    iteration: int
    score_current: float
    score_best: float
    score_history: List[float]
    start_time: float
    current_time: float

    @property
    def elapsed_time(self) -> float:
        """Time elapsed since optimization started (``current_time - start_time``)."""
        return self.current_time - self.start_time

    @property
    def iterations_since_improvement(self) -> int:
        """
        Number of iterations since the best score in ``score_history`` was found.

        Returns 0 for an empty history. When the maximum occurs more than once,
        the earliest occurrence counts, so repeating the best score is not
        treated as an improvement.
        """
        if not self.score_history:
            return 0

        # np.argmax returns a NumPy integer; cast to a builtin int so the
        # result honours the ``-> int`` annotation and stays JSON-serializable
        # when it is surfaced through debug-info dictionaries.
        best_score_idx = int(np.argmax(self.score_history))
        return len(self.score_history) - best_score_idx - 1
||
36 | |||
37 | |||
class StoppingCondition(ABC):
    """
    Common interface for every stopping criterion.

    Each concrete subclass encodes exactly one criterion. Instances carry a
    ``triggered`` flag and a human-readable ``trigger_reason`` (both cleared by
    ``reset``), plus a logger whose name is derived from the condition name.
    """

    def __init__(self, name: str):
        self.name = name
        # A freshly built condition has not fired and has no recorded reason.
        self.triggered, self.trigger_reason = False, ""
        self.logger = logging.getLogger(f"{__name__}.{self.name}")

    @abstractmethod
    def should_stop(self, context: StoppingContext) -> bool:
        """Return True when this criterion says the optimization must end."""
        ...

    @abstractmethod
    def get_debug_info(self, context: StoppingContext) -> Dict[str, Any]:
        """Return a dictionary describing this criterion's current state."""
        ...

    def reset(self):
        """Clear the fired flag and reason so the condition can be reused."""
        self.triggered, self.trigger_reason = False, ""
||
64 | |||
65 | |||
class TimeExceededCondition(StoppingCondition):
    """
    Stops when maximum time limit is exceeded.

    The limit is compared against ``context.elapsed_time``. A ``max_time`` of
    ``None`` disables the condition entirely.
    """

    def __init__(self, max_time: Optional[float]):
        super().__init__("TimeExceeded")
        self.max_time = max_time

    def should_stop(self, context: StoppingContext) -> bool:
        """Fire (and record/log the reason) once elapsed time is strictly above ``max_time``."""
        if self.max_time is None:
            return False

        elapsed = context.elapsed_time
        if elapsed > self.max_time:
            self.triggered = True
            self.trigger_reason = f"Time limit exceeded: {elapsed:.2f}s > {self.max_time:.2f}s"
            self.logger.info(self.trigger_reason)
            return True
        return False

    def get_debug_info(self, context: StoppingContext) -> Dict[str, Any]:
        """Report the limit, the elapsed time and the remaining time budget."""
        elapsed = context.elapsed_time
        # Compare against None explicitly: a limit of 0 is a valid setting and
        # must still yield a numeric remaining time rather than None.
        time_remaining = (
            self.max_time - elapsed if self.max_time is not None else None
        )
        return {
            "condition": self.name,
            "max_time": self.max_time,
            "elapsed_time": elapsed,
            "time_remaining": time_remaining,
            "triggered": self.triggered,
            "reason": self.trigger_reason,
        }
||
95 | |||
96 | |||
class ScoreExceededCondition(StoppingCondition):
    """
    Stops when target score is reached or exceeded.

    The target is compared against ``context.score_best``. A ``max_score`` of
    ``None`` disables the condition entirely.
    """

    def __init__(self, max_score: Optional[float]):
        super().__init__("ScoreExceeded")
        self.max_score = max_score

    def should_stop(self, context: StoppingContext) -> bool:
        """Fire (and record/log the reason) once the best score is >= ``max_score``."""
        if self.max_score is None:
            return False

        best = context.score_best
        if best >= self.max_score:
            self.triggered = True
            self.trigger_reason = f"Target score reached: {best:.6f} >= {self.max_score:.6f}"
            self.logger.info(self.trigger_reason)
            return True
        return False

    def get_debug_info(self, context: StoppingContext) -> Dict[str, Any]:
        """Report the target, the current best score and the remaining gap."""
        best = context.score_best
        # Compare against None explicitly: a target score of 0 is common and
        # must still produce a numeric gap rather than None.
        score_gap = self.max_score - best if self.max_score is not None else None
        return {
            "condition": self.name,
            "max_score": self.max_score,
            "current_best_score": best,
            "score_gap": score_gap,
            "triggered": self.triggered,
            "reason": self.trigger_reason,
        }
||
126 | |||
127 | |||
class NoImprovementCondition(StoppingCondition):
    """
    Stops when no improvement is observed for a specified number of iterations.

    Nothing is checked until ``score_history`` holds more than
    ``n_iter_no_change`` entries. After that, the criteria are checked in order:

    1. Staleness: the best entry of ``score_history`` is at least
       ``n_iter_no_change`` iterations old.
    2. Absolute tolerance (if ``tol_abs`` is set): ``|score_best - prev_best|``
       is below ``tol_abs``, where ``prev_best`` is the best score recorded
       before the trailing ``n_iter_no_change`` entries.
    3. Relative tolerance (if ``tol_rel`` is set and ``prev_best != 0``): the
       improvement over ``prev_best``, in percent of ``|prev_best|``, is below
       ``tol_rel``.
    """

    def __init__(
        self,
        n_iter_no_change: int,
        tol_abs: Optional[float] = None,
        tol_rel: Optional[float] = None,
    ):
        super().__init__("NoImprovement")
        self.n_iter_no_change = n_iter_no_change
        self.tol_abs = tol_abs
        self.tol_rel = tol_rel

    def _max_score_before_window(self, context: StoppingContext) -> Optional[float]:
        """Best score recorded before the trailing ``n_iter_no_change`` entries, or None if there is none."""
        first_n = len(context.score_history) - self.n_iter_no_change
        if first_n <= 0:
            # History is not longer than the window: no "before" scores exist.
            return None
        scores_before = context.score_history[:first_n]
        return max(scores_before) if scores_before else None

    def _record_trigger(self, reason: str) -> bool:
        """Mark this condition as fired with ``reason``, log it, and return True."""
        self.triggered = True
        self.trigger_reason = reason
        self.logger.info(reason)
        return True

    def should_stop(self, context: StoppingContext) -> bool:
        """Return True when the staleness or a tolerance criterion fires."""
        if len(context.score_history) <= self.n_iter_no_change:
            return False

        iterations_stale = context.iterations_since_improvement
        if iterations_stale >= self.n_iter_no_change:
            return self._record_trigger(
                f"No improvement for {iterations_stale} iterations"
            )

        # Tolerance-based early stopping
        max_score_before = self._max_score_before_window(context)
        if max_score_before is None:
            return False
        current_best = context.score_best

        # Absolute tolerance check
        if self.tol_abs is not None:
            improvement = abs(current_best - max_score_before)
            if improvement < self.tol_abs:
                return self._record_trigger(
                    f"Improvement below absolute tolerance: {improvement:.6f} < {self.tol_abs:.6f}"
                )

        # Relative tolerance check; skipped when the baseline is 0 because the
        # percentage would divide by zero.
        if self.tol_rel is not None and max_score_before != 0:
            improvement_pct = (
                (current_best - max_score_before) / abs(max_score_before)
            ) * 100
            if improvement_pct < self.tol_rel:
                return self._record_trigger(
                    f"Improvement below relative tolerance: {improvement_pct:.2f}% < {self.tol_rel:.2f}%"
                )

        return False

    def get_debug_info(self, context: StoppingContext) -> Dict[str, Any]:
        """Report settings, staleness and (when computable) the window improvement."""
        iterations_stale = context.iterations_since_improvement

        debug_info = {
            "condition": self.name,
            "n_iter_no_change": self.n_iter_no_change,
            "iterations_since_improvement": iterations_stale,
            "tol_abs": self.tol_abs,
            "tol_rel": self.tol_rel,
            "triggered": self.triggered,
            "reason": self.trigger_reason,
        }

        max_score_before = self._max_score_before_window(context)
        if max_score_before is not None:
            # Signed improvement (should_stop compares its absolute value).
            improvement = context.score_best - max_score_before
            debug_info["improvement_abs"] = improvement
            if max_score_before != 0:
                debug_info["improvement_rel_pct"] = (
                    improvement / abs(max_score_before)
                ) * 100

        return debug_info
||
212 | |||
213 | |||
class CompositeStoppingCondition(StoppingCondition):
    """
    Combines multiple stopping conditions with OR logic.

    Sub-conditions are evaluated in list order. Evaluation ends at the first
    one that fires, so later sub-conditions are not consulted on that call.
    """

    def __init__(self, conditions: List[StoppingCondition]):
        super().__init__("Composite")
        self.conditions = conditions

    def should_stop(self, context: StoppingContext) -> bool:
        """Return True as soon as any sub-condition fires, recording which one."""
        fired = next(
            (cond for cond in self.conditions if cond.should_stop(context)),
            None,
        )
        if fired is None:
            return False

        self.triggered = True
        self.trigger_reason = f"Stopped by {fired.name}: {fired.trigger_reason}"
        self.logger.info(self.trigger_reason)
        return True

    def get_debug_info(self, context: StoppingContext) -> Dict[str, Any]:
        """Report this composite's state plus every sub-condition's debug info."""
        children = [cond.get_debug_info(context) for cond in self.conditions]
        return {
            "condition": self.name,
            "triggered": self.triggered,
            "reason": self.trigger_reason,
            "sub_conditions": children,
        }

    def reset(self):
        """Reset this composite and, recursively, all of its sub-conditions."""
        super().reset()
        for sub in self.conditions:
            sub.reset()
||
246 | |||
247 | |||
class OptimizationStopper:
    """
    Main class for managing optimization stopping conditions.
    Provides a clean interface and comprehensive debugging capabilities.

    Args:
        start_time: Start timestamp. Elapsed time is measured against
            ``time.time()``, so this should come from ``time.time()`` too.
        max_time: Optional time limit; adds a ``TimeExceededCondition``.
        max_score: Optional target score; adds a ``ScoreExceededCondition``.
        early_stopping: Optional dict. If it has a non-None ``"n_iter_no_change"``,
            a ``NoImprovementCondition`` is added using that value plus the
            optional ``"tol_abs"`` / ``"tol_rel"`` entries.
    """

    def __init__(
        self,
        start_time: float,
        max_time: Optional[float] = None,
        max_score: Optional[float] = None,
        early_stopping: Optional[Dict[str, Any]] = None,
    ):
        self.start_time = start_time
        self.conditions: List[StoppingCondition] = []
        self.score_history: List[float] = []
        self.score_best = -np.inf
        self.iteration = 0
        self.logger = logging.getLogger(f"{__name__}.OptimizationStopper")

        # Build stopping conditions
        if max_time is not None:
            self.conditions.append(TimeExceededCondition(max_time))

        if max_score is not None:
            self.conditions.append(ScoreExceededCondition(max_score))

        if early_stopping is not None:
            n_iter = early_stopping.get("n_iter_no_change")
            if n_iter is not None:
                self.conditions.append(
                    NoImprovementCondition(
                        n_iter_no_change=n_iter,
                        tol_abs=early_stopping.get("tol_abs"),
                        tol_rel=early_stopping.get("tol_rel"),
                    )
                )

        self.composite_condition = CompositeStoppingCondition(self.conditions)

    def update(self, score_current: float, score_best: float, iteration: int):
        """Update the stopper with new optimization state."""
        self.score_history.append(score_current)
        self.score_best = score_best
        self.iteration = iteration

    def _build_context(self) -> StoppingContext:
        """Snapshot the current state into a ``StoppingContext`` stamped with ``time.time()``."""
        return StoppingContext(
            iteration=self.iteration,
            # -inf stands in for "no score recorded yet".
            score_current=self.score_history[-1] if self.score_history else -np.inf,
            score_best=self.score_best,
            # Passed by reference, not copied.
            score_history=self.score_history,
            start_time=self.start_time,
            current_time=time.time(),
        )

    def should_stop(self) -> bool:
        """Check if optimization should stop."""
        return self.composite_condition.should_stop(self._build_context())

    def get_debug_info(self) -> Dict[str, Any]:
        """Get comprehensive debugging information about stopping conditions."""
        return self.composite_condition.get_debug_info(self._build_context())

    def get_stop_reason(self) -> str:
        """Get a human-readable reason for why optimization stopped."""
        if self.composite_condition.triggered:
            return self.composite_condition.trigger_reason
        return "Optimization not stopped by stopper"
||
325 |