Passed
Push — master ( 1a4396...e9038c )
by Simon
01:47
created

hyperactive.hyperactive.try_ray_import()   A

Complexity

Conditions 3

Size

Total Lines 13
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 10
dl 0
loc 13
rs 9.9
c 0
b 0
f 0
cc 3
nop 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import time
6
import warnings
7
8
from .main_args import MainArgs
9
from .opt_args import Arguments
10
11
from . import (
12
    HillClimbingOptimizer,
13
    StochasticHillClimbingOptimizer,
14
    TabuOptimizer,
15
    RandomSearchOptimizer,
16
    RandomRestartHillClimbingOptimizer,
17
    RandomAnnealingOptimizer,
18
    SimulatedAnnealingOptimizer,
19
    StochasticTunnelingOptimizer,
20
    ParallelTemperingOptimizer,
21
    ParticleSwarmOptimizer,
22
    EvolutionStrategyOptimizer,
23
    BayesianOptimizer,
24
)
25
26
27
def stop_warnings():
    """Globally silence all warnings for the rest of the process.

    Replaces ``warnings.warn`` with a no-op so repeated sklearn warnings
    (which would otherwise print hundreds of times during a search) are
    suppressed. Note this is a process-wide, irreversible monkey-patch.
    """
    import warnings

    def _silent(*args, **kwargs):
        # Intentionally ignore every warning call.
        return None

    warnings.warn = _silent
35
36
37
def try_ray_import():
    """Try to import ray and report whether it is already initialized.

    Returns
    -------
    tuple
        ``(ray_module, ray_initialized)``: the imported ``ray`` module and
        ``True`` if ``ray.init()`` has already been called, else ``False``.
        If ray is not installed, emits an ``ImportWarning`` and returns
        ``(None, False)``.

    Bug fixed: the original version executed ``return ray, rayInit`` after
    catching ``ImportError``, but ``ray`` was never bound in that path, so
    the function raised ``NameError`` instead of falling back gracefully.
    """
    try:
        import ray
    except ImportError:
        warnings.warn("failed to import ray", ImportWarning)
        # Return None for the module; callers only use it when the
        # second element is truthy, so this stays backward-compatible.
        return None, False

    return ray, ray.is_initialized()
50
51
52
class Hyperactive:
    """Main entry point: configures the data/arguments and runs an optimizer.

    Parameters
    ----------
    X, y : array-like
        Training data passed through to the objective evaluation.
    memory : str, default "long"
        Memory mode forwarded to ``MainArgs``.
    random_state : int, default 1
        Seed forwarded to ``MainArgs``.
    verbosity : int, default 3
        Verbosity level forwarded to ``MainArgs``.
    warnings : bool, default False
        If False, all Python warnings are silenced process-wide.
    """

    def __init__(
        self, X, y, memory="long", random_state=1, verbosity=3, warnings=False
    ):
        self.X = X
        self._main_args_ = MainArgs(X, y, memory, random_state, verbosity)

        # Suppress noisy (e.g. sklearn) warnings unless explicitly enabled.
        if not warnings:
            stop_warnings()

        # Maps the user-facing optimizer name to its implementation class.
        self.optimizer_dict = {
            "HillClimbing": HillClimbingOptimizer,
            "StochasticHillClimbing": StochasticHillClimbingOptimizer,
            "TabuSearch": TabuOptimizer,
            "RandomSearch": RandomSearchOptimizer,
            "RandomRestartHillClimbing": RandomRestartHillClimbingOptimizer,
            "RandomAnnealing": RandomAnnealingOptimizer,
            "SimulatedAnnealing": SimulatedAnnealingOptimizer,
            "StochasticTunneling": StochasticTunnelingOptimizer,
            "ParallelTempering": ParallelTemperingOptimizer,
            "ParticleSwarm": ParticleSwarmOptimizer,
            "EvolutionStrategy": EvolutionStrategyOptimizer,
            "Bayesian": BayesianOptimizer,
        }

    def search(
        self,
        search_config,
        n_iter=10,
        max_time=None,
        optimizer="RandomSearch",
        n_jobs=1,
        init_config=None,
    ):
        """Run the selected optimizer over ``search_config``.

        Populates ``results_params``, ``results_models``, ``pos_list``,
        ``score_list``, ``eval_time`` and ``total_time`` on the instance.
        If a ray cluster is already initialized, the jobs are distributed
        as ray actors; otherwise the search runs in-process.
        """
        t_start = time.time()

        self._main_args_.search_args(
            search_config, max_time, n_iter, optimizer, n_jobs, init_config
        )
        self._opt_args_ = Arguments(self._main_args_.opt_para)
        opt_class = self.optimizer_dict[self._main_args_.optimizer]

        ray, ray_active = try_ray_import()

        if ray_active:
            # Distribute one actor per job on the existing ray cluster.
            remote_cls = ray.remote(opt_class)
            actors = [
                remote_cls.remote(self._main_args_, self._opt_args_)
                for _ in range(self._main_args_.n_jobs)
            ]
            futures = [
                actor.search.remote(idx, rayInit=ray_active)
                for idx, actor in enumerate(actors)
            ]
            # Only the first job's results are kept (matches prior behavior).
            results = ray.get(futures)[0]
            ray.shutdown()
        else:
            self._optimizer_ = opt_class(self._main_args_, self._opt_args_)
            results = self._optimizer_.search()

        (
            self.results_params,
            self.results_models,
            self.pos_list,
            self.score_list,
            self.eval_time,
        ) = results

        self.total_time = time.time() - t_start
120