Passed
Push — master ( 7596de...7d2c3d )
by Simon
04:29
created

hyperactive.hyperactive.Hyperactive.run()   B

Complexity

Conditions 5

Size

Total Lines 27
Code Lines 20

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 5
eloc 20
nop 3
dl 0
loc 27
rs 8.9332
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import os
6
import multiprocessing
7
from tqdm import tqdm
8
9
from .optimizers import RandomSearchOptimizer
10
from .run_search import run_search
11
from .print_info import print_info
12
13
from .hyperactive_results import HyperactiveResults
14
15
16
class Hyperactive(HyperactiveResults):
    """Register optimization searches and execute them via ``run_search``.

    Searches are queued with :meth:`add_search`; each queued job is stored as
    one entry in ``self.process_infos``. :meth:`run` hands all entries to
    ``run_search`` and prints a summary for every returned result.
    """

    def __init__(
        self,
        verbosity=None,
        distribution=None,
        n_processes="auto",
    ):
        """Store the run configuration and create empty bookkeeping state.

        Parameters
        ----------
        verbosity : list, False or None
            Passed on to every search process and to ``print_info``.
            ``None`` (the default) selects
            ``["progress_bar", "print_results", "print_times"]``;
            ``False`` is converted to an empty list.
        distribution : object or None
            Forwarded unchanged to ``run_search``. ``None`` (the default)
            selects a ``{"multiprocessing": {...}}`` dict that uses
            ``tqdm.set_lock`` as initializer with the current tqdm lock.
        n_processes : object
            Forwarded unchanged to ``run_search`` (default ``"auto"``).
        """
        super().__init__()

        # Defaults are built per instance instead of in the signature so that
        # no list/dict object is shared between separate Hyperactive objects.
        if verbosity is None:
            verbosity = ["progress_bar", "print_results", "print_times"]
        elif verbosity is False:
            verbosity = []

        if distribution is None:
            distribution = {
                "multiprocessing": {
                    "initializer": tqdm.set_lock,
                    "initargs": (tqdm.get_lock(),),
                }
            }

        self.verbosity = verbosity
        self.distribution = distribution
        self.n_processes = n_processes

        self.search_ids = []
        self.process_infos = {}

        self.progress_boards = {}

    def _add_search_processes(
        self,
        random_state,
        objective_function,
        optimizer,
        n_iter,
        n_jobs,
        max_score,
        memory,
        memory_warm_start,
        search_id,
    ):
        """Append ``n_jobs`` process descriptions to ``self.process_infos``.

        ``n_jobs == -1`` is replaced by ``multiprocessing.cpu_count()``.
        Every created entry references the same ``optimizer`` object and the
        same ``random_state`` value; only ``nth_process`` differs.
        """
        if n_jobs == -1:
            n_jobs = multiprocessing.cpu_count()

        for _ in range(n_jobs):
            # The running number of already registered processes serves as
            # the key, so ids stay unique across several add_search calls.
            nth_process = len(self.process_infos)

            self.process_infos[nth_process] = {
                "random_state": random_state,
                "verbosity": self.verbosity,
                "nth_process": nth_process,
                "objective_function": objective_function,
                "optimizer": optimizer,
                "n_iter": n_iter,
                "max_score": max_score,
                "memory": memory,
                "memory_warm_start": memory_warm_start,
                "search_id": search_id,
            }

    def _default_opt(self, optimizer):
        """Return a ``RandomSearchOptimizer`` for the string ``"default"``.

        Any other value (including other strings) is returned unchanged.
        """
        if isinstance(optimizer, str) and optimizer == "default":
            return RandomSearchOptimizer()
        return optimizer

    def _default_search_id(self, search_id, objective_function):
        """Return ``search_id``, or the function's ``__name__`` if it is falsy."""
        if not search_id:
            search_id = objective_function.__name__
        return search_id

    def _init_progress_board(self, progress_board, search_id, search_space):
        """Initialise the progress board for a search, if one was given.

        Returns the value of ``progress_board.init_paths(...)`` or ``None``
        when no progress board is used. Each board is stored once in
        ``self.progress_boards``, keyed by its ``uuid``.
        """
        if not progress_board:
            return None

        data_c = progress_board.init_paths(search_id, search_space)

        if progress_board.uuid not in self.progress_boards:
            self.progress_boards[progress_board.uuid] = progress_board

        return data_c

    @staticmethod
    def check_list(search_space):
        """Print a warning for every search-space value that is not a list.

        Nothing is raised and nothing is returned; offending entries are only
        reported on stdout.
        """
        for key, search_dim in search_space.items():
            if isinstance(search_dim, list):
                continue

            error_msg = (
                "Value in '{}' of search space dictionary must be of type list".format(
                    key
                )
            )
            # NOTE: the stricter `raise ValueError(error_msg)` is currently
            # disabled in this code base; only a warning is printed.
            print("Warning", error_msg)

    def add_search(
        self,
        objective_function,
        search_space,
        n_iter,
        search_id=None,
        optimizer="default",
        n_jobs=1,
        initialize=None,
        max_score=None,
        random_state=None,
        memory=True,
        memory_warm_start=None,
        progress_board=None,
    ):
        """Queue a search that is executed by the next call to :meth:`run`.

        Parameters
        ----------
        objective_function : callable
            Stored in each process description; its ``__name__`` is used as
            ``search_id`` when none is given.
        search_space : dict
            Checked with :meth:`check_list` and passed to ``optimizer.init``.
        n_iter : object
            Stored in each process description and later given to
            ``print_info``.
        search_id : str or None
            Identifier of the search; falls back to the function name.
        optimizer : object or "default"
            ``"default"`` selects a new ``RandomSearchOptimizer``.
        n_jobs : int
            Number of process descriptions to create; ``-1`` means
            ``multiprocessing.cpu_count()``.
        initialize : dict or None
            Passed to ``optimizer.init``. ``None`` (the default) selects a
            fresh ``{"grid": 4, "random": 2, "vertices": 4}`` dict.
        max_score, random_state, memory, memory_warm_start
            Stored unchanged in each process description.
        progress_board : object or None
            Optional board; see :meth:`_init_progress_board`.
        """
        # Built per call so a dict mutated downstream cannot leak into later
        # searches through a shared default object.
        if initialize is None:
            initialize = {"grid": 4, "random": 2, "vertices": 4}

        optimizer = self._default_opt(optimizer)
        search_id = self._default_search_id(search_id, objective_function)
        progress_collector = self._init_progress_board(
            progress_board, search_id, search_space
        )

        self.check_list(search_space)

        optimizer.init(search_space, initialize, progress_collector)

        self._add_search_processes(
            random_state,
            objective_function,
            optimizer,
            n_iter,
            n_jobs,
            max_score,
            memory,
            memory_warm_start,
            search_id,
        )

    def run(self, max_time=None, _test_st_backend=False):
        """Run all queued searches and print a summary for each result.

        ``max_time`` is written into every process description before the
        searches start. Unless ``_test_st_backend`` is truthy, the dashboard
        of every registered progress board is opened first. The value
        returned by ``run_search`` is kept in ``self.results_list``.
        """
        for process_info in self.process_infos.values():
            process_info["max_time"] = max_time

        # open progress board
        if not _test_st_backend:
            for progress_board in self.progress_boards.values():
                progress_board.open_dashboard()

        self.results_list = run_search(
            self.process_infos, self.distribution, self.n_processes
        )

        for results in self.results_list:
            # Each result names the process it belongs to; use that to look
            # up the settings the process was started with.
            process_info = self.process_infos[results["nth_process"]]

            print_info(
                verbosity=process_info["verbosity"],
                objective_function=process_info["objective_function"],
                best_score=results["best_score"],
                best_para=results["best_para"],
                best_iter=results["best_iter"],
                eval_times=results["eval_times"],
                iter_times=results["iter_times"],
                n_iter=process_info["n_iter"],
            )
174