Passed
Push — master ( 8fb723...24a802 )
by Simon
01:48
created

hyperactive.hyperactive.Hyperactive.run()   A

Complexity

Conditions 2

Size

Total Lines 11
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 7
nop 2
dl 0
loc 11
rs 10
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
6
import multiprocessing as mp
7
from tqdm import tqdm
8
9
from .optimizers import RandomSearchOptimizer
10
from .run_search import run_search
11
12
from .results import Results
13
from .print_results import PrintResults
14
from .search_space import SearchSpace
15
16
17
class Hyperactive:
    """Collect optimization searches and run them together.

    Searches are registered with :meth:`add_search` and executed by
    :meth:`run`. After a run, results can be queried per id with
    :meth:`best_para`, :meth:`best_score` and :meth:`search_data`.
    """

    # Immutable template for the default verbosity. A fresh list is built
    # from it for every instance, so instances never share one list object.
    _DEFAULT_VERBOSITY = ("progress_bar", "print_results", "print_times")

    def __init__(
        self,
        verbosity=None,
        distribution="multiprocessing",
        n_processes="auto",
    ):
        """Configure reporting and distribution of the searches.

        Parameters
        ----------
        verbosity : list of str, False or None
            Output options forwarded to every optimizer and to
            ``PrintResults``. ``False`` disables all options. ``None``
            (default) selects ``["progress_bar", "print_results",
            "print_times"]``.
        distribution : str
            Forwarded unchanged to ``run_search``.
        n_processes : int or str
            Forwarded unchanged to ``run_search``.
        """
        super().__init__()
        if verbosity is None:
            verbosity = list(self._DEFAULT_VERBOSITY)
        elif verbosity is False:
            verbosity = []

        self.verbosity = verbosity
        self.distribution = distribution
        self.n_processes = n_processes

        # Maps a running slot index (0, 1, 2, ...) to the optimizer that is
        # executed in that slot.
        self.opt_pros = {}

    def _create_shared_memory(self, new_opt):
        """Attach a shareable memory object to ``new_opt``.

        This does nothing unless ``new_opt.memory == "share"``. The memory
        object of the first registered optimizer is reused if that optimizer
        meets both conditions:

        - its objective function has the same ``__name__``;
        - its search space has the same length.

        Otherwise a new ``multiprocessing.Manager().dict()`` is created.
        """
        if new_opt.memory != "share":
            return

        obj_func_name = new_opt.objective_function.__name__
        n_dims = len(new_opt.s_space())

        # Stop at the first compatible optimizer. Continuing the loop would
        # let a later, non-matching optimizer replace the match with a fresh
        # dict. It would also spawn one Manager process per mismatch.
        for opt in self.opt_pros.values():
            same_obj_func = opt.objective_function.__name__ == obj_func_name
            same_ss_length = len(opt.s_space()) == n_dims
            if same_obj_func and same_ss_length:
                new_opt.memory = opt.memory  # get same manager.dict
                return

        # No compatible search registered yet: exactly one new shared dict.
        new_opt.memory = mp.Manager().dict()

    @staticmethod
    def _default_opt(optimizer):
        """Return a ``RandomSearchOptimizer`` for the string ``"default"``.

        Any other value is returned unchanged.
        """
        if isinstance(optimizer, str) and optimizer == "default":
            return RandomSearchOptimizer()
        return optimizer

    @staticmethod
    def _default_search_id(search_id, objective_function):
        """Return ``search_id``, or the objective function's name if it is falsy."""
        return search_id if search_id else objective_function.__name__

    @staticmethod
    def check_list(search_space):
        """Print a warning for every search-space value that is not a ``list``.

        Parameters
        ----------
        search_space : dict
            Mapping of dimension name to its candidate values.
        """
        for key, search_dim in search_space.items():
            if not isinstance(search_dim, list):
                error_msg = (
                    "Value in '{}' of search space dictionary must be of "
                    "type list".format(key)
                )
                # NOTE(review): raising ValueError(error_msg) here was
                # disabled in the original code; only a warning is printed.
                print("Warning", error_msg)

    def add_search(
        self,
        objective_function,
        search_space,
        n_iter,
        search_id=None,
        optimizer="default",
        n_jobs=1,
        initialize=None,
        pass_through=None,
        max_score=None,
        early_stopping=None,
        random_state=None,
        memory="share",
        memory_warm_start=None,
    ):
        """Register a search that will be executed by :meth:`run`.

        Parameters
        ----------
        objective_function : callable
            Function to optimize. Its ``__name__`` is the fallback search id
            and is also used to match searches for shared memory.
        search_space : dict
            Mapping of dimension name to candidate values. It is checked by
            :meth:`check_list` and wrapped in ``SearchSpace``.
        n_iter : int
            Forwarded to ``optimizer.setup_search``.
        search_id : optional
            Identifier of the search. Falsy values fall back to
            ``objective_function.__name__``.
        optimizer : object or "default"
            Optimizer instance. ``"default"`` selects
            ``RandomSearchOptimizer()``.
        n_jobs : int
            Number of slots in ``opt_pros`` this optimizer occupies.
            ``-1`` uses ``multiprocessing.cpu_count()``.
        initialize : dict or None
            ``None`` (default) means
            ``{"grid": 4, "random": 2, "vertices": 4}``. A fresh dict is
            built on every call.
        pass_through : dict or None
            ``None`` (default) means a fresh empty dict on every call.
        max_score, early_stopping, random_state, memory_warm_start
            Forwarded unchanged to ``optimizer.setup_search``.
        memory : str or object
            Forwarded to ``optimizer.setup_search``. If ``"share"``, the
            optimizer is additionally given a shared memory object via
            :meth:`_create_shared_memory`.
        """
        # Fresh objects per call, so no dict is shared between searches
        # through a mutable default argument.
        if initialize is None:
            initialize = {"grid": 4, "random": 2, "vertices": 4}
        if pass_through is None:
            pass_through = {}

        self.check_list(search_space)

        optimizer = self._default_opt(optimizer)
        # NOTE(review): search_id is resolved but not forwarded anywhere in
        # this method; confirm whether setup_search should receive it.
        search_id = self._default_search_id(search_id, objective_function)
        s_space = SearchSpace(search_space)

        optimizer.setup_search(
            objective_function,
            s_space,
            n_iter,
            initialize,
            pass_through,
            max_score,
            early_stopping,
            random_state,
            memory,
            memory_warm_start,
            self.verbosity,
        )

        if memory == "share":
            self._create_shared_memory(optimizer)

        if n_jobs == -1:
            n_jobs = mp.cpu_count()

        # NOTE(review): the same optimizer object is stored in every slot.
        # Whether each job gets an independent copy depends on how
        # run_search distributes opt_pros; confirm.
        for _ in range(n_jobs):
            nth_process = len(self.opt_pros)
            self.opt_pros[nth_process] = optimizer

    def _print_info(self):
        """Print the results of every process via ``PrintResults``."""
        print_res = PrintResults(self.opt_pros, self.verbosity)

        for results in self.results_list:
            print_res.print_process(results, results["nth_process"])

    def run(self, max_time=None):
        """Execute all registered searches, then store and print the results.

        Parameters
        ----------
        max_time : optional
            Assigned to ``max_time`` of every registered optimizer before the
            searches start.
        """
        for opt in self.opt_pros.values():
            opt.max_time = max_time

        self.results_list = run_search(
            self.opt_pros, self.distribution, self.n_processes
        )
        self.results_ = Results(self.results_list, self.opt_pros)

        self._print_info()

    def best_para(self, id_):
        """Return the best parameters of search ``id_`` (requires :meth:`run`)."""
        return self.results_.best_para(id_)

    def best_score(self, id_):
        """Return the best score of search ``id_`` (requires :meth:`run`)."""
        return self.results_.best_score(id_)

    def search_data(self, id_):
        """Return the collected search data of ``id_`` (requires :meth:`run`)."""
        return self.results_.search_data(id_)