Passed
Push to master (b4b259...5acf24) by Simon, created 05:13

Hyperactive._add_search_processes()   A

Complexity

Conditions 3

Size

Total Lines 29
Code Lines 26

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric   Value
cc       3
eloc     26
nop      10
dl       0
loc      29
rs       9.256
c        0
b        0
f        0

How to fix: Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent when you need more or different data.

There are several approaches to avoiding long parameter lists; one common option is to bundle related arguments into a parameter object, as sketched below.
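In this case the flagged method takes ten parameters (the nop metric above), so a parameter object is a natural fit. The sketch below is illustrative only: SearchConfig and the stripped-down class are hypothetical stand-ins, not part of Hyperactive's API.

import multiprocessing
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class SearchConfig:
    # Hypothetical parameter object bundling the per-search arguments that
    # _add_search_processes currently receives individually.
    objective_function: Callable
    optimizer: Any
    n_iter: int
    search_id: str
    n_jobs: int = 1
    max_score: Optional[float] = None
    random_state: Optional[int] = None
    memory: bool = True
    memory_warm_start: Any = None


class HyperactiveSketch:
    # Minimal stand-in for the real class, only to show the narrower signature.
    def __init__(self):
        self.verbosity = []
        self.process_infos = {}

    def _add_search_processes(self, config: SearchConfig):
        n_jobs = (
            multiprocessing.cpu_count() if config.n_jobs == -1 else config.n_jobs
        )
        for _ in range(n_jobs):
            nth_process = len(self.process_infos)
            info = dict(vars(config))  # dataclass fields as a shallow dict copy
            info.pop("n_jobs")         # per-search setting, not stored per process
            info["nth_process"] = nth_process
            info["verbosity"] = self.verbosity
            self.process_infos[nth_process] = info

With this layout, add_search would build one SearchConfig and pass it along, which also keeps the keys of process_infos defined in a single place.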

# Author: Simon Blanke
# Email: [email protected]
# License: MIT License


import multiprocessing
from tqdm import tqdm

from .optimizers import RandomSearchOptimizer
from .run_search import run_search
from .print_info import print_info

from .hyperactive_results import HyperactiveResults


class Hyperactive(HyperactiveResults):
    def __init__(
        self,
        verbosity=["progress_bar", "print_results", "print_times"],
        distribution={
            "multiprocessing": {
                "initializer": tqdm.set_lock,
                "initargs": (tqdm.get_lock(),),
            }
        },
        n_processes="auto",
    ):
        super().__init__()
        if verbosity is False:
            verbosity = []

        self.verbosity = verbosity
        self.distribution = distribution
        self.n_processes = n_processes

        self.search_ids = []
        self.process_infos = {}

        self.progress_boards = {}

    def _add_search_processes(
        self,
        random_state,
        objective_function,
        optimizer,
        n_iter,
        n_jobs,
        max_score,
        memory,
        memory_warm_start,
        search_id,
    ):
        if n_jobs == -1:
            n_jobs = multiprocessing.cpu_count()

        for _ in range(n_jobs):
            nth_process = len(self.process_infos)

            self.process_infos[nth_process] = {
                "random_state": random_state,
                "verbosity": self.verbosity,
                "nth_process": nth_process,
                "objective_function": objective_function,
                "optimizer": optimizer,
                "n_iter": n_iter,
                "max_score": max_score,
                "memory": memory,
                "memory_warm_start": memory_warm_start,
                "search_id": search_id,
            }

    def _default_opt(self, optimizer):
        if isinstance(optimizer, str):
            if optimizer == "default":
                optimizer = RandomSearchOptimizer()
        return optimizer

    def _default_search_id(self, search_id, objective_function):
        if not search_id:
            search_id = objective_function.__name__
        return search_id

    def _init_progress_board(self, progress_board, search_id, search_space):
        data_c = None

        if progress_board:
            data_c = progress_board.init_paths(search_id, search_space)

            if progress_board.uuid not in self.progress_boards:
                self.progress_boards[progress_board.uuid] = progress_board

        return data_c

    @staticmethod
    def check_list(search_space):
        for key in search_space.keys():
            search_dim = search_space[key]

            error_msg = (
                "Value in '{}' of search space dictionary must be of type list".format(
                    key
                )
            )

            if not isinstance(search_dim, list):
                print("Warning", error_msg)
                # raise ValueError(error_msg)

    def add_search(
        self,
        objective_function,
        search_space,
        n_iter,
        search_id=None,
        optimizer="default",
        n_jobs=1,
        initialize={"grid": 4, "random": 2, "vertices": 4},
        max_score=None,
        random_state=None,
        memory=True,
        memory_warm_start=None,
        progress_board=None,
    ):
        optimizer = self._default_opt(optimizer)
        search_id = self._default_search_id(search_id, objective_function)
        progress_collector = self._init_progress_board(
            progress_board, search_id, search_space
        )

        self.check_list(search_space)

        optimizer.init(search_space, initialize, progress_collector)

        self._add_search_processes(
            random_state,
            objective_function,
            optimizer,
            n_iter,
            n_jobs,
            max_score,
            memory,
            memory_warm_start,
            search_id,
        )

    def _print_info(self):
        for results in self.results_list:
            nth_process = results["nth_process"]

            print_info(
                verbosity=self.process_infos[nth_process]["verbosity"],
                objective_function=self.process_infos[nth_process][
                    "objective_function"
                ],
                best_score=results["best_score"],
                best_para=results["best_para"],
                best_iter=results["best_iter"],
                eval_times=results["eval_times"],
                iter_times=results["iter_times"],
                n_iter=self.process_infos[nth_process]["n_iter"],
            )

    def run(self, max_time=None, _test_st_backend=False):
        for nth_process in self.process_infos.keys():
            self.process_infos[nth_process]["max_time"] = max_time

        # open progress board
        if not _test_st_backend:
            for progress_board in self.progress_boards.values():
                progress_board.open_dashboard()

        self.results_list = run_search(
            self.process_infos, self.distribution, self.n_processes
        )

        # delete lock files
        if not _test_st_backend:
            for progress_board in self.progress_boards.values():
                for search_id in progress_board.search_ids:
                    progress_board._io_.remove_lock(search_id)

        self._print_info()
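For context, this is roughly how the class above is driven. The call sequence follows the add_search and run signatures shown in the listing; the objective function and search space are made-up, and the opt["x"] parameter-access pattern follows Hyperactive's usual usage.

from hyperactive import Hyperactive


def parabola(opt):
    # Toy objective: called once per iteration with the sampled
    # parameters accessible by name.
    return -(opt["x"] ** 2)


search_space = {"x": list(range(-10, 11))}

hyper = Hyperactive()
hyper.add_search(parabola, search_space, n_iter=50)
hyper.run()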