Passed
Push — master (41483f...5037dc) by Simon, 01:18

hyperactive.hyperactive (rating: A)

Complexity
    Total Complexity    10

Size/Duplication
    Total Lines         119
    Duplicated Lines    0%

Importance
    Changes             0

Metric   Value
wmc      10
eloc     85
dl       0
loc      119
rs       10
c        0
b        0
f        0
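
The Total Complexity and wmc of 10 are consistent with the per-block complexities in the tables below: 1 + 1 + 1 + 2 + 4 for the five methods plus 1 for stop_warnings(). For a local cross-check of the size and complexity figures, here is a minimal sketch using radon; radon is an assumption (the report does not name its analysis tool), the file path is inferred from the module name, and radon's loc/lloc counts only approximate the loc/eloc values above.

# Hypothetical cross-check of the metrics above with radon.
# Assumptions: radon is installed (pip install radon), the module lives at
# hyperactive/hyperactive.py, and radon is not necessarily the tool that
# produced this report, so the numbers may differ slightly.
from radon.raw import analyze
from radon.complexity import cc_visit

with open("hyperactive/hyperactive.py") as f:
    source = f.read()

raw = analyze(source)
print("loc:", raw.loc)    # total lines, compare with the reported 119
print("lloc:", raw.lloc)  # logical lines, roughly the reported eloc of 85

# cyclomatic complexity per top-level block; class blocks expose their methods
for block in cc_visit(source):
    print(block.name, block.complexity)
    for method in getattr(block, "methods", []):
        print("   ", method.name, method.complexity)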

5 Methods

Rating   Name                            Duplication   Size   Complexity
A        Hyperactive.get_eval_time()     0             2      1
A        Hyperactive.save_report()       0             2      1
A        Hyperactive.get_total_time()    0             2      1
A        Hyperactive.__init__()          0             20     2
A        Hyperactive.search()            0             50     4

1 Function

Rating   Name              Duplication   Size   Complexity
A        stop_warnings()   0             8      1
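
The five methods and one function above are the entire surface of the module listed below. A minimal usage sketch follows, assuming the top-level package re-exports Hyperactive and that search_config maps an objective function to a grid of candidate parameter values; both assumptions come from the project's examples rather than from this file.

# Hypothetical usage of the Hyperactive class shown in the listing below.
# Assumptions: `from hyperactive import Hyperactive` works (package re-export)
# and search_config maps an objective function to a parameter grid; neither
# detail is defined in this file.
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import cross_val_score

from hyperactive import Hyperactive

X, y = load_iris(return_X_y=True)


def model(para, X, y):
    # objective: mean cross-validation accuracy for one parameter combination
    gbc = GradientBoostingClassifier(
        n_estimators=para["n_estimators"], max_depth=para["max_depth"]
    )
    return cross_val_score(gbc, X, y, cv=3).mean()


search_config = {
    model: {"n_estimators": range(10, 200, 10), "max_depth": range(2, 12)}
}

opt = Hyperactive(X, y)  # warnings=False by default, see stop_warnings() below
opt.search(search_config, n_iter=30, optimizer="HillClimbing")
print(opt.get_total_time(), opt.get_eval_time())

The optimizer string must be one of the keys of optimizer_dict set up in __init__() below.
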
# Author: Simon Blanke
# Email: [email protected]
# License: MIT License

import time
import warnings

from .main_args import MainArgs
from .opt_args import Arguments

from . import (
    HillClimbingOptimizer,
    StochasticHillClimbingOptimizer,
    TabuOptimizer,
    RandomSearchOptimizer,
    RandomRestartHillClimbingOptimizer,
    RandomAnnealingOptimizer,
    SimulatedAnnealingOptimizer,
    StochasticTunnelingOptimizer,
    ParallelTemperingOptimizer,
    ParticleSwarmOptimizer,
    EvolutionStrategyOptimizer,
    BayesianOptimizer,
)


def stop_warnings():
    # because sklearn warnings are annoying when they appear 100 times
    def warn(*args, **kwargs):
        pass

    import warnings

    warnings.warn = warn


class Hyperactive:
    def __init__(self, X, y, memory=True, random_state=1, verbosity=3, warnings=False):
        self.X = X
        self._main_args_ = MainArgs(X, y, memory, random_state, verbosity)

        if not warnings:
            stop_warnings()

        # maps the optimizer names accepted by search() to their classes
        self.optimizer_dict = {
            "HillClimbing": HillClimbingOptimizer,
            "StochasticHillClimbing": StochasticHillClimbingOptimizer,
            "TabuSearch": TabuOptimizer,
            "RandomSearch": RandomSearchOptimizer,
            "RandomRestartHillClimbing": RandomRestartHillClimbingOptimizer,
            "RandomAnnealing": RandomAnnealingOptimizer,
            "SimulatedAnnealing": SimulatedAnnealingOptimizer,
            "StochasticTunneling": StochasticTunnelingOptimizer,
            "ParallelTempering": ParallelTemperingOptimizer,
            "ParticleSwarm": ParticleSwarmOptimizer,
            "EvolutionStrategy": EvolutionStrategyOptimizer,
            "Bayesian": BayesianOptimizer,
        }

    def search(
        self,
        search_config,
        n_iter=10,
        max_time=None,
        optimizer="RandomSearch",
        n_jobs=1,
        init_config=None,
    ):
        start_time = time.time()

        self._main_args_.search_args(
            search_config, max_time, n_iter, optimizer, n_jobs, init_config
        )
        self._opt_args_ = Arguments(self._main_args_.opt_para)
        optimizer_class = self.optimizer_dict[self._main_args_.optimizer]

        # run through Ray only if it is installed and already initialized by the caller
        try:
            import ray

            if ray.is_initialized():
                ray_ = True
            else:
                ray_ = False
        except ImportError:
            warnings.warn("failed to import ray", ImportWarning)
            ray_ = False

        if ray_:
            # one remote optimizer actor per job
            optimizer_class = ray.remote(optimizer_class)
            opts = [
                optimizer_class.remote(self._main_args_, self._opt_args_)
                for job in range(self._main_args_.n_jobs)
            ]
            searches = [
                opt.search.remote(job, ray_=ray_) for job, opt in enumerate(opts)
            ]
            ray.get(searches)
        else:
            self._optimizer_ = optimizer_class(self._main_args_, self._opt_args_)
            self._optimizer_.search()

        # collect results from the optimizer
        self.results_params = self._optimizer_.results_params
        self.results_models = self._optimizer_.results_models

        self.pos_list = self._optimizer_.pos_list
        self.score_list = self._optimizer_.score_list

        self.total_time = time.time() - start_time

    def get_total_time(self):
        return self.total_time

    def get_eval_time(self):
        return self._optimizer_.eval_time

    def save_report(self):
        pass
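
One side effect of the constructor is easy to miss: with the default warnings=False it calls stop_warnings(), which replaces warnings.warn for the whole process, not only inside Hyperactive. A short sketch, reusing the hypothetical X and y from the earlier usage example:

# Sketch of the process-wide warnings side effect of Hyperactive.__init__().
# X and y are the hypothetical data from the earlier usage sketch.
import warnings

Hyperactive(X, y, warnings=True)
warnings.warn("still emitted")     # warnings.warn is untouched

Hyperactive(X, y)                  # default warnings=False calls stop_warnings()
warnings.warn("silently dropped")  # warnings.warn is now a no-op for the whole process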