| Total Complexity | 52 |
| Total Lines | 325 |
| Duplicated Lines | 17.23 % |
| Changes | 0 | ||
Duplicate code is one of the most pungent code smells. A rule that is often used is to re-structure code once it is duplicated in three or more places.
Common duplication problems and their corresponding solutions are:
Complex classes like gradient_free_optimizers._stopping_conditions often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
| 1 | import time |
||
| 2 | import logging |
||
| 3 | from abc import ABC, abstractmethod |
||
| 4 | from dataclasses import dataclass, field |
||
| 5 | from typing import List, Optional, Dict, Any |
||
| 6 | import numpy as np |
||
| 7 | |||
| 8 | |||
@dataclass
class StoppingContext:
    """
    Encapsulates all relevant data for stopping condition evaluation.

    This creates a clear contract for what data stopping conditions can access.
    The dataclass only stores the values it is given; the two properties
    below derive elapsed time and staleness from those stored values.
    """

    iteration: int
    score_current: float
    score_best: float
    score_history: List[float]
    start_time: float
    current_time: float

    @property
    def elapsed_time(self) -> float:
        """Return ``current_time - start_time``."""
        return self.current_time - self.start_time

    @property
    def iterations_since_improvement(self) -> int:
        """Return the number of history entries after the first maximum.

        An empty ``score_history`` yields 0. When the maximum value occurs
        more than once, the first occurrence is used, so later ties do not
        count as an improvement.
        """
        if not self.score_history:
            return 0

        # np.argmax yields a NumPy integer; cast it so the result is a
        # builtin int, matching the annotation and staying JSON-serializable
        # when it ends up in debug dictionaries.
        best_score_idx = int(np.argmax(self.score_history))
        return len(self.score_history) - best_score_idx - 1
||
| 36 | |||
| 37 | |||
class StoppingCondition(ABC):
    """
    Shared interface and bookkeeping for every stopping condition.

    A subclass implements exactly one stopping criterion through
    ``should_stop`` and exposes diagnostics through ``get_debug_info``.
    """

    def __init__(self, name: str):
        # The logger is namespaced per condition: "<module name>.<name>".
        self.logger = logging.getLogger(f"{__name__}.{name}")
        self.name = name
        self.trigger_reason = ""
        self.triggered = False

    @abstractmethod
    def should_stop(self, context: StoppingContext) -> bool:
        """Tell whether the optimization should stop according to this condition."""

    @abstractmethod
    def get_debug_info(self, context: StoppingContext) -> Dict[str, Any]:
        """Return a dictionary of diagnostic details for this condition."""

    def reset(self):
        """Clear the triggered flag and the stored trigger reason."""
        self.triggered, self.trigger_reason = False, ""
||
| 64 | |||
| 65 | |||
class TimeExceededCondition(StoppingCondition):
    """Stops when maximum time limit is exceeded.

    ``max_time`` is compared against ``context.elapsed_time``; a value of
    ``None`` disables the condition.
    """

    def __init__(self, max_time: Optional[float]):
        super().__init__("TimeExceeded")
        self.max_time = max_time

    def should_stop(self, context: StoppingContext) -> bool:
        """Return True once ``context.elapsed_time`` is strictly greater than ``max_time``.

        On a hit the condition is marked as triggered and the reason is
        stored and logged at INFO level.
        """
        if self.max_time is None:
            return False

        if context.elapsed_time > self.max_time:
            self.triggered = True
            self.trigger_reason = f"Time limit exceeded: {context.elapsed_time:.2f}s > {self.max_time:.2f}s"
            self.logger.info(self.trigger_reason)
            return True
        return False

    def get_debug_info(self, context: StoppingContext) -> Dict[str, Any]:
        """Return the configured limit, elapsed time, remaining time and trigger state.

        ``time_remaining`` is ``None`` only when no limit is configured.
        """
        # Compare against None explicitly: a limit of 0 is falsy but is
        # still a configured limit and must report a remaining time.
        if self.max_time is not None:
            time_remaining = self.max_time - context.elapsed_time
        else:
            time_remaining = None

        return {
            "condition": self.name,
            "max_time": self.max_time,
            "elapsed_time": context.elapsed_time,
            "time_remaining": time_remaining,
            "triggered": self.triggered,
            "reason": self.trigger_reason,
        }
||
| 95 | |||
| 96 | |||
class ScoreExceededCondition(StoppingCondition):
    """Stops when target score is reached or exceeded.

    ``max_score`` is compared against ``context.score_best``; a value of
    ``None`` disables the condition.
    """

    def __init__(self, max_score: Optional[float]):
        super().__init__("ScoreExceeded")
        self.max_score = max_score

    def should_stop(self, context: StoppingContext) -> bool:
        """Return True once ``context.score_best`` is greater than or equal to ``max_score``.

        On a hit the condition is marked as triggered and the reason is
        stored and logged at INFO level.
        """
        if self.max_score is None:
            return False

        if context.score_best >= self.max_score:
            self.triggered = True
            self.trigger_reason = f"Target score reached: {context.score_best:.6f} >= {self.max_score:.6f}"
            self.logger.info(self.trigger_reason)
            return True
        return False

    def get_debug_info(self, context: StoppingContext) -> Dict[str, Any]:
        """Return the target score, current best score, remaining gap and trigger state.

        ``score_gap`` is ``None`` only when no target score is configured.
        """
        # Compare against None explicitly: a target score of 0 is falsy but
        # is still a configured target and must report a gap.
        if self.max_score is not None:
            score_gap = self.max_score - context.score_best
        else:
            score_gap = None

        return {
            "condition": self.name,
            "max_score": self.max_score,
            "current_best_score": context.score_best,
            "score_gap": score_gap,
            "triggered": self.triggered,
            "reason": self.trigger_reason,
        }
||
| 126 | |||
| 127 | |||
class NoImprovementCondition(StoppingCondition):
    """Stops when no improvement is observed for a specified number of iterations.

    The score history is split into the last ``n_iter_no_change`` entries
    (the window) and everything before it. The condition fires when the
    first maximum of the history lies at least ``n_iter_no_change`` entries
    back, or when the best score differs from the pre-window maximum by
    less than the optional absolute / relative (percent) tolerances.
    """

    def __init__(
        self,
        n_iter_no_change: int,
        tol_abs: Optional[float] = None,
        tol_rel: Optional[float] = None,
    ):
        super().__init__("NoImprovement")
        self.n_iter_no_change = n_iter_no_change
        self.tol_abs = tol_abs
        self.tol_rel = tol_rel

    def _trigger(self, reason: str) -> bool:
        """Mark the condition as triggered, store and log ``reason``, return True."""
        self.triggered = True
        self.trigger_reason = reason
        self.logger.info(self.trigger_reason)
        return True

    def _max_score_before(self, score_history: List[float]) -> Optional[float]:
        """Return the maximum of the entries preceding the window, or None if there are none."""
        if len(score_history) <= self.n_iter_no_change:
            return None

        first_n = len(score_history) - self.n_iter_no_change
        scores_before = score_history[:first_n]
        if not scores_before:
            return None
        return max(scores_before)

    def should_stop(self, context: StoppingContext) -> bool:
        """Return True when the history shows no sufficient improvement.

        Never fires while the history has ``n_iter_no_change`` entries or
        fewer. The checks run in order: stale-iteration count, absolute
        tolerance, relative tolerance.
        """
        if len(context.score_history) <= self.n_iter_no_change:
            return False

        iterations_stale = context.iterations_since_improvement
        if iterations_stale >= self.n_iter_no_change:
            return self._trigger(
                f"No improvement for {iterations_stale} iterations"
            )

        # Check tolerance-based early stopping
        max_score_before = self._max_score_before(context.score_history)
        if max_score_before is None:
            return False

        current_best = context.score_best

        # Absolute tolerance check
        if self.tol_abs is not None:
            improvement = abs(current_best - max_score_before)
            if improvement < self.tol_abs:
                return self._trigger(
                    f"Improvement below absolute tolerance: {improvement:.6f} < {self.tol_abs:.6f}"
                )

        # Relative tolerance check; skipped when the reference score is 0
        # because the percentage would require a division by zero.
        if self.tol_rel is not None and max_score_before != 0:
            improvement_pct = (
                (current_best - max_score_before) / abs(max_score_before)
            ) * 100
            if improvement_pct < self.tol_rel:
                return self._trigger(
                    f"Improvement below relative tolerance: {improvement_pct:.2f}% < {self.tol_rel:.2f}%"
                )

        return False

    def get_debug_info(self, context: StoppingContext) -> Dict[str, Any]:
        """Return the configuration, staleness count and trigger state.

        When the history is longer than the window, ``improvement_abs`` is
        added, and ``improvement_rel_pct`` as well if the pre-window maximum
        is non-zero.
        """
        debug_info = {
            "condition": self.name,
            "n_iter_no_change": self.n_iter_no_change,
            "iterations_since_improvement": context.iterations_since_improvement,
            "tol_abs": self.tol_abs,
            "tol_rel": self.tol_rel,
            "triggered": self.triggered,
            "reason": self.trigger_reason,
        }

        max_score_before = self._max_score_before(context.score_history)
        if max_score_before is not None:
            improvement = context.score_best - max_score_before
            debug_info["improvement_abs"] = improvement
            if max_score_before != 0:
                debug_info["improvement_rel_pct"] = (
                    improvement / abs(max_score_before)
                ) * 100

        return debug_info
||
| 212 | |||
| 213 | |||
class CompositeStoppingCondition(StoppingCondition):
    """Aggregates several stopping conditions; stops as soon as any one of them does."""

    def __init__(self, conditions: List[StoppingCondition]):
        super().__init__("Composite")
        self.conditions = conditions

    def should_stop(self, context: StoppingContext) -> bool:
        """Evaluate the members in order and report the first one that wants to stop."""
        # next() over a generator keeps the short-circuit: members after the
        # first hit are never evaluated.
        hit = next(
            (member for member in self.conditions if member.should_stop(context)),
            None,
        )
        if hit is None:
            return False

        self.triggered = True
        self.trigger_reason = f"Stopped by {hit.name}: {hit.trigger_reason}"
        self.logger.info(self.trigger_reason)
        return True

    def get_debug_info(self, context: StoppingContext) -> Dict[str, Any]:
        """Return this condition's state plus the debug info of every member."""
        info: Dict[str, Any] = {
            "condition": self.name,
            "triggered": self.triggered,
            "reason": self.trigger_reason,
        }
        member_infos = []
        for member in self.conditions:
            member_infos.append(member.get_debug_info(context))
        info["sub_conditions"] = member_infos
        return info

    def reset(self):
        """Reset this condition and then every member condition."""
        super().reset()
        for member in self.conditions:
            member.reset()
||
| 246 | |||
| 247 | |||
class OptimizationStopper:
    """
    Main class for managing optimization stopping conditions.

    Builds one stopping condition per configured limit, combines them in a
    ``CompositeStoppingCondition`` and evaluates them against the state
    recorded through ``update``.
    """

    def __init__(
        self,
        start_time: float,
        max_time: Optional[float] = None,
        max_score: Optional[float] = None,
        early_stopping: Optional[Dict[str, Any]] = None,
    ):
        """Set up the stopper.

        Args:
            start_time: Reference timestamp; elapsed time is measured as
                ``time.time()`` minus this value.
            max_time: Time limit, or None for no time-based stopping.
            max_score: Target score, or None for no score-based stopping.
            early_stopping: Optional dict read for the keys
                ``"n_iter_no_change"``, ``"tol_abs"`` and ``"tol_rel"``.
                No early-stopping condition is created when
                ``"n_iter_no_change"`` is missing or None.
        """
        self.start_time = start_time
        self.conditions: List[StoppingCondition] = []
        self.score_history: List[float] = []
        self.score_best = -np.inf
        self.iteration = 0
        self.logger = logging.getLogger(f"{__name__}.OptimizationStopper")

        # Build stopping conditions
        if max_time is not None:
            self.conditions.append(TimeExceededCondition(max_time))

        if max_score is not None:
            self.conditions.append(ScoreExceededCondition(max_score))

        if early_stopping is not None:
            n_iter = early_stopping.get("n_iter_no_change")
            if n_iter is not None:
                self.conditions.append(
                    NoImprovementCondition(
                        n_iter_no_change=n_iter,
                        tol_abs=early_stopping.get("tol_abs"),
                        tol_rel=early_stopping.get("tol_rel"),
                    )
                )

        self.composite_condition = CompositeStoppingCondition(self.conditions)

    def _build_context(self) -> StoppingContext:
        """Create a StoppingContext from the recorded state and the current time."""
        return StoppingContext(
            iteration=self.iteration,
            # Before the first update there is no score yet; use -inf.
            score_current=self.score_history[-1] if self.score_history else -np.inf,
            score_best=self.score_best,
            score_history=self.score_history,
            start_time=self.start_time,
            current_time=time.time(),
        )

    def update(self, score_current: float, score_best: float, iteration: int):
        """Record a new score in the history and store the best score and iteration."""
        self.score_history.append(score_current)
        self.score_best = score_best
        self.iteration = iteration

    def should_stop(self) -> bool:
        """Check if optimization should stop."""
        return self.composite_condition.should_stop(self._build_context())

    def get_debug_info(self) -> Dict[str, Any]:
        """Get comprehensive debugging information about stopping conditions."""
        return self.composite_condition.get_debug_info(self._build_context())

    def get_stop_reason(self) -> str:
        """Get a human-readable reason for why optimization stopped."""
        if self.composite_condition.triggered:
            return self.composite_condition.trigger_reason
        return "Optimization not stopped by stopper"
||
| 325 |