Passed
Push — master ( cba5ed...7e35c2 )
by Simon
04:17
created

HyperactiveResults._sort_results_objFunc()   B

Complexity

Conditions 5

Size

Total Lines 29
Code Lines 21

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
| Metric | Value  |
|--------|--------|
| cc     | 5      |
| eloc   | 21     |
| nop    | 2      |
| dl     | 0      |
| loc    | 29     |
| rs     | 8.9093 |
| c      | 0      |
| b      | 0      |
| f      | 0      |
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import os
6
from tqdm import tqdm
7
8
from .optimizers import RandomSearchOptimizer
9
from .run_search import run_search
10
from .print_info import print_info
11
12
from .hyperactive_results import HyperactiveResults
13
14
15
class Hyperactive(HyperactiveResults):
16
    def __init__(
        self,
        verbosity=None,
        distribution=None,
        n_processes="auto",
    ):
        """Create an empty collection of searches to be run later.

        Parameters
        ----------
        verbosity : list of str, False or None
            Output settings copied into every registered process. ``False``
            is stored as ``[]``. ``None`` (the default) selects
            ``["progress_bar", "print_results", "print_times"]``. A new list
            is built on every call, so instances never share (and can never
            mutate) a common default object.
        distribution : dict or None
            Stored and later passed unchanged to ``run_search``. ``None``
            (the default) selects a ``"multiprocessing"`` entry whose
            ``initializer``/``initargs`` are ``tqdm.set_lock`` and
            ``(tqdm.get_lock(),)``. This is presumably meant to share tqdm's
            lock with worker processes; confirm in ``run_search``. The dict
            is built per call instead of once at import time.
        n_processes : optional
            Stored and later passed unchanged to ``run_search``
            (default ``"auto"``).
        """
        super().__init__()

        if verbosity is None:
            verbosity = ["progress_bar", "print_results", "print_times"]
        elif verbosity is False:
            verbosity = []

        if distribution is None:
            distribution = {
                "multiprocessing": {
                    "initializer": tqdm.set_lock,
                    "initargs": (tqdm.get_lock(),),
                }
            }

        self.verbosity = verbosity
        self.distribution = distribution
        self.n_processes = n_processes

        # Not populated by the code visible in this class; kept for compatibility.
        self.search_ids = []
        # nth_process -> per-process settings dict (filled by _add_search_processes).
        self.process_infos = {}
        # Result lookup tables; presumably filled by HyperactiveResults — confirm.
        self.objFunc2results = {}
        self.search_id2results = {}

        # progress board uuid -> progress board (filled by _init_progress_board).
        self.progress_boards = {}

    def _add_search_processes(
43
        self,
44
        random_state,
45
        objective_function,
46
        optimizer,
47
        n_iter,
48
        n_jobs,
49
        max_score,
50
        memory,
51
        memory_warm_start,
52
        search_id,
53
    ):
54
        for _ in range(n_jobs):
55
            nth_process = len(self.process_infos)
56
57
            self.process_infos[nth_process] = {
58
                "random_state": random_state,
59
                "verbosity": self.verbosity,
60
                "nth_process": nth_process,
61
                "objective_function": objective_function,
62
                "optimizer": optimizer,
63
                "n_iter": n_iter,
64
                "max_score": max_score,
65
                "memory": memory,
66
                "memory_warm_start": memory_warm_start,
67
                "search_id": search_id,
68
            }
69
70
    def _default_opt(self, optimizer):
71
        if isinstance(optimizer, str):
72
            if optimizer == "default":
73
                optimizer = RandomSearchOptimizer()
74
        return optimizer
75
76
    def _default_search_id(self, search_id, objective_function):
77
        if not search_id:
78
            search_id = objective_function.__name__
79
        return search_id
80
81
    def _init_progress_board(self, progress_board, search_id, search_space):
82
        data_c = None
83
84
        if progress_board:
85
            data_c = progress_board.init_paths(search_id, search_space)
86
87
            if progress_board.uuid not in self.progress_boards:
88
                self.progress_boards[progress_board.uuid] = progress_board
89
90
        return data_c
91
92
    def add_search(
93
        self,
94
        objective_function,
95
        search_space,
96
        n_iter,
97
        search_id=None,
98
        optimizer="default",
99
        n_jobs=1,
100
        initialize={"grid": 4, "random": 2, "vertices": 4},
101
        max_score=None,
102
        random_state=None,
103
        memory=True,
104
        memory_warm_start=None,
105
        progress_board=None,
106
    ):
107
        optimizer = self._default_opt(optimizer)
108
        search_id = self._default_search_id(search_id, objective_function)
109
        data_c = self._init_progress_board(progress_board, search_id, search_space)
110
111
        optimizer.init(search_space, initialize, data_c)
112
113
        self._add_search_processes(
114
            random_state,
115
            objective_function,
116
            optimizer,
117
            n_iter,
118
            n_jobs,
119
            max_score,
120
            memory,
121
            memory_warm_start,
122
            search_id,
123
        )
124
125
    def run(self, max_time=None):
        """Run all registered search processes and report their results.

        ``max_time`` is written into every process description, and every
        registered progress board has its dashboard opened. The process
        descriptions are then handed to ``run_search`` together with
        ``self.distribution`` and ``self.n_processes``. The returned list is
        kept on ``self.results_list``. For each result, ``print_info`` is
        called with the originating process's verbosity, objective function
        and ``n_iter``, plus the result's best score, parameters, iteration
        and timing data.
        """
        for process_info in self.process_infos.values():
            process_info["max_time"] = max_time

        # Show each registered progress board before the searches start.
        for board in self.progress_boards.values():
            board.open_dashboard()

        self.results_list = run_search(
            self.process_infos, self.distribution, self.n_processes
        )

        for results in self.results_list:
            nth_process = results["nth_process"]
            process_info = self.process_infos[nth_process]

            print_info(
                verbosity=process_info["verbosity"],
                objective_function=process_info["objective_function"],
                best_score=results["best_score"],
                best_para=results["best_para"],
                best_iter=results["best_iter"],
                eval_times=results["eval_times"],
                iter_times=results["iter_times"],
                n_iter=process_info["n_iter"],
            )
152