Passed
Push — master ( 051bec...b89548 )
by Simon
01:31
created

HyperactiveResults._get_one_result()   A

Complexity

Conditions 4

Size

Total Lines 12
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 8
nop 3
dl 0
loc 12
rs 10
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
6
import numpy as np
7
import pandas as pd
8
from tqdm import tqdm
9
10
from .optimizers import RandomSearchOptimizer
11
from .run_search import run_search
12
from .print_info import print_info
13
14
15
class HyperactiveResults:
    """Look up the outcome of finished searches.

    Results can be requested either by search id (a string) or by
    objective function.  The methods read ``self.results_list`` and
    ``self.process_infos`` and cache what they collect in
    ``self.search_id2results`` and ``self.objFunc2results``.  This class
    does not create any of those attributes itself; they must be set on
    the instance before the accessors are used.
    """

    def __init__(self, *args, **kwargs):
        # Intentionally empty: arguments are accepted and ignored.
        pass

    def _collect_results(self, info_key, wanted):
        """Merge the results of every process whose info matches.

        Parameters
        ----------
        info_key : str
            Key inside a ``process_infos`` entry to compare
            (``"objective_function"`` or ``"search_id"``).
        wanted : object
            Value the entry must be equal to for the process to count.

        Returns
        -------
        dict or None
            ``{"best_para", "best_score", "search_data"}`` built from all
            matching processes, or ``None`` when no process matched.
        """
        best_score = -np.inf
        best_para = None
        frames = []

        for results_ in self.results_list:
            process_info = self.process_infos[results_["nth_process"]]
            if process_info[info_key] != wanted:
                continue

            if results_["best_score"] > best_score:
                best_score = results_["best_score"]
                best_para = results_["best_para"]

            frames.append(results_["results"])

        if not frames:
            return None

        return {
            "best_para": best_para,
            "best_score": best_score,
            # assumes every "results" entry is a pandas object that
            # pd.concat accepts -- TODO confirm against run_search
            "search_data": pd.concat(frames),
        }

    def _sort_results_objFunc(self, objective_function):
        """Cache the merged results of all processes of one objective function.

        When no process used ``objective_function`` an entry is still
        stored, with ``best_para=None``, ``best_score=-inf`` and
        ``search_data=None``.
        """
        collected = self._collect_results(
            "objective_function", objective_function
        )
        if collected is None:
            collected = {
                "best_para": None,
                "best_score": -np.inf,
                "search_data": None,
            }

        self.objFunc2results[objective_function] = collected

    def _sort_results_search_id(self, search_id):
        """Cache the merged results of all processes sharing ``search_id``.

        Several processes can carry the same search id, so their results
        are combined (best score wins, search data is concatenated)
        instead of letting the last process overwrite the others.
        Nothing is stored when no process has this id.
        """
        collected = self._collect_results("search_id", search_id)
        if collected is not None:
            self.search_id2results[search_id] = collected

    def _get_one_result(self, id_, result_name):
        """Return one field of the cached results for ``id_``.

        Parameters
        ----------
        id_ : str or object
            A string is treated as a search id, anything else as an
            objective function.
        result_name : str
            ``"best_para"``, ``"best_score"`` or ``"search_data"``.

        Raises
        ------
        KeyError
            If ``id_`` is a string and no process has that search id.
        """
        if isinstance(id_, str):
            if id_ not in self.search_id2results:
                self._sort_results_search_id(id_)

            if id_ not in self.search_id2results:
                raise KeyError(
                    "No results found for search_id {!r}".format(id_)
                )

            return self.search_id2results[id_][result_name]

        if id_ not in self.objFunc2results:
            self._sort_results_objFunc(id_)

        return self.objFunc2results[id_][result_name]

    def best_para(self, id_):
        """Return the best parameters found for a search id or objective function."""
        return self._get_one_result(id_, "best_para")

    def best_score(self, id_):
        """Return the best score found for a search id or objective function."""
        return self._get_one_result(id_, "best_score")

    def results(self, id_):
        """Return the collected search data for a search id or objective function."""
        return self._get_one_result(id_, "search_data")
89
90
91
class Hyperactive(HyperactiveResults):
    """Collect search definitions, run them and expose their results.

    Searches are registered with :meth:`add_search`, executed together by
    :meth:`run`, and queried afterwards through the accessors inherited
    from ``HyperactiveResults``.
    """

    def __init__(
        self,
        verbosity=None,
        distribution=None,
        n_processes="auto",
    ):
        """Store the run configuration and create empty bookkeeping containers.

        Parameters
        ----------
        verbosity : list or False, optional
            Stored per process and handed to ``print_info`` after a run.
            ``None`` selects ``["progress_bar", "print_results",
            "print_times"]``; ``False`` selects an empty list.
        distribution : optional
            Passed unchanged to ``run_search``.  ``None`` selects a
            ``"multiprocessing"`` configuration that installs the tqdm
            lock via ``tqdm.set_lock``.
        n_processes : optional
            Passed unchanged to ``run_search``.
        """
        super().__init__()

        # Defaults are built per instance so that no mutable object is
        # shared between instances through the function signature.
        if verbosity is None:
            verbosity = ["progress_bar", "print_results", "print_times"]
        elif verbosity is False:
            verbosity = []

        if distribution is None:
            distribution = {
                "multiprocessing": {
                    "initializer": tqdm.set_lock,
                    "initargs": (tqdm.get_lock(),),
                }
            }

        self.verbosity = verbosity
        self.distribution = distribution
        self.n_processes = n_processes

        self.search_ids = []
        self.process_infos = {}
        # Empty until run() is called, so result lookups made before a
        # run do not fail on a missing attribute.
        self.results_list = []
        self.objFunc2results = {}
        self.search_id2results = {}

    def _add_search_processes(
        self,
        random_state,
        objective_function,
        search_space,
        optimizer,
        n_iter,
        n_jobs,
        max_score,
        memory,
        memory_warm_start,
        search_id,
    ):
        """Register ``n_jobs`` process entries for one search.

        Each entry is stored in ``self.process_infos`` under the next
        free integer index.  All entries of one call reference the same
        argument objects (including ``optimizer`` and ``random_state``).
        """
        for _ in range(n_jobs):
            nth_process = len(self.process_infos)

            self.process_infos[nth_process] = {
                "random_state": random_state,
                "verbosity": self.verbosity,
                "nth_process": nth_process,
                "objective_function": objective_function,
                "search_space": search_space,
                "optimizer": optimizer,
                "n_iter": n_iter,
                "max_score": max_score,
                "memory": memory,
                "memory_warm_start": memory_warm_start,
                "search_id": search_id,
            }

    def add_search(
        self,
        objective_function,
        search_space,
        n_iter,
        search_id=None,
        optimizer="default",
        n_jobs=1,
        initialize=None,
        max_score=None,
        random_state=None,
        memory=True,
        memory_warm_start=None,
    ):
        """Register a search that will be executed by :meth:`run`.

        Parameters
        ----------
        objective_function, search_space, n_iter, max_score, random_state,
        memory, memory_warm_start
            Stored unchanged in the process entries of this search.
        search_id : str, optional
            Identifier for later result lookup.  When omitted, the number
            of searches registered so far is used, as a string.
        optimizer : object or "default", optional
            Optimizer whose ``init(search_space, initialize)`` is called
            here.  ``"default"`` creates a ``RandomSearchOptimizer``.
        n_jobs : int, optional
            Number of process entries to register for this search.
        initialize : dict, optional
            Passed to ``optimizer.init``.  ``None`` selects
            ``{"grid": 4, "random": 2, "vertices": 4}``.

        Raises
        ------
        ValueError
            If ``optimizer`` is a string other than ``"default"``.
        """
        if initialize is None:
            initialize = {"grid": 4, "random": 2, "vertices": 4}

        if isinstance(optimizer, str):
            if optimizer != "default":
                # Previously fell through to ``str.init`` and failed with
                # an unrelated AttributeError.
                raise ValueError(
                    "Unknown optimizer {!r}; pass 'default' or an "
                    "optimizer object".format(optimizer)
                )
            optimizer = RandomSearchOptimizer()
        optimizer.init(search_space, initialize)

        if search_id is None:
            search_id = str(len(self.search_ids))
        self.search_ids.append(search_id)

        self._add_search_processes(
            random_state,
            objective_function,
            search_space,
            optimizer,
            n_iter,
            n_jobs,
            max_score,
            memory,
            memory_warm_start,
            search_id,
        )

    def run(self, max_time=None):
        """Execute all registered searches and print a summary for each process.

        ``max_time`` is written into every process entry before the
        entries are handed to ``run_search``.  The returned list replaces
        ``self.results_list``.
        """
        for process_info in self.process_infos.values():
            process_info["max_time"] = max_time

        self.results_list = run_search(
            self.process_infos, self.distribution, self.n_processes
        )

        # Cached lookups belong to the previous results_list; drop them so
        # the accessors rebuild from the results of this run.
        self.objFunc2results.clear()
        self.search_id2results.clear()

        for results in self.results_list:
            process_info = self.process_infos[results["nth_process"]]

            print_info(
                verbosity=process_info["verbosity"],
                objective_function=process_info["objective_function"],
                best_score=results["best_score"],
                best_para=results["best_para"],
                best_iter=results["best_iter"],
                eval_times=results["eval_times"],
                iter_times=results["iter_times"],
                n_iter=process_info["n_iter"],
            )
208