Passed
Push — master ( 06915f...21ec56 )
by Simon
01:13
created

SearchProcess.__init__()   A

Complexity

Conditions 1

Size

Total Lines 34
Code Lines 30

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 30
nop 13
dl 0
loc 34
rs 9.16
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent as soon as you need more or different data.

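There are several approaches to avoid long parameter lists, for example grouping related parameters into a single parameter object, or passing a whole object instead of several of its individual fields.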

1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import time
6
import random
7
import numpy as np
8
import pandas as pd
9
10
from importlib import import_module
11
12
13
# Maps the short optimizer name accepted by SearchProcess to the name of the
# attribute that SearchProcess.__init__ looks up (via getattr) on the
# "gradient_free_optimizers" module.
optimizer_dict = {
    "HillClimbing": "HillClimbingOptimizer",
    "StochasticHillClimbing": "StochasticHillClimbingOptimizer",
    "TabuSearch": "TabuOptimizer",
    "RandomSearch": "RandomSearchOptimizer",
    "RandomRestartHillClimbing": "RandomRestartHillClimbingOptimizer",
    "RandomAnnealing": "RandomAnnealingOptimizer",
    "SimulatedAnnealing": "SimulatedAnnealingOptimizer",
    "StochasticTunneling": "StochasticTunnelingOptimizer",
    "ParallelTempering": "ParallelTemperingOptimizer",
    "ParticleSwarm": "ParticleSwarmOptimizer",
    "EvolutionStrategy": "EvolutionStrategyOptimizer",
    "Bayesian": "BayesianOptimizer",
    "TPE": "TreeStructuredParzenEstimators",
    "DecisionTree": "DecisionTreeOptimizer",
}
29
30
31
class SearchProcess:
    """Drive a single optimization run (one process) from start to finish.

    The class resolves the optimizer class from ``gradient_free_optimizers``,
    seeds the random number generators per process, runs the initialization
    and iteration loops while recording timings, and finally copies the
    outcome onto ``self.res``.

    NOTE(review): ``self.cand`` (candidate/score provider) and ``self.res``
    (result container) are used by the methods below but never assigned in
    this class; presumably a subclass or the caller sets them before
    ``search`` is invoked -- confirm.
    """

    def __init__(
        self,
        nth_process,
        p_bar,
        model,
        search_space,
        search_name,
        n_iter,
        training_data,
        optimizer,
        n_jobs,
        init_para,
        memory,
        random_state,
    ):
        """Store the search configuration and resolve the optimizer class.

        Parameters
        ----------
        nth_process : index of this process.
        p_bar : progress-bar object; must offer ``init_p_bar`` and
            ``close_p_bar``.
        model : stored and handed to the progress bar and the results.
        search_space : stored and copied to the results.
        search_name : accepted for interface compatibility; it is not
            stored by this class.
        n_iter : total number of iterations (initialization included).
        training_data : stored as-is.
        optimizer : either a key of ``optimizer_dict`` or a one-entry dict
            ``{name: parameter_dict}``.
        n_jobs : stored as-is.
        init_para : stored as-is.
        memory : stored and copied to the results.
        random_state : base seed, or ``None`` to draw one at search time.

        Raises
        ------
        KeyError
            If the optimizer name is not a key of ``optimizer_dict``.
        """
        self.nth_process = nth_process
        self.p_bar = p_bar
        self.model = model
        self.search_space = search_space
        self.n_iter = n_iter
        self.training_data = training_data
        self.optimizer = optimizer
        self.n_jobs = n_jobs
        self.init_para = init_para
        self.memory = memory
        self.random_state = random_state

        # Normalizes self.optimizer to a plain name and fills
        # self.opt_para / self.n_positions.
        self._process_arguments()

        self.iter_times = []
        self.eval_times = []

        module = import_module("gradient_free_optimizers")
        # Look up with the normalized name: the raw ``optimizer`` argument may
        # be a dict, which is unhashable and cannot index optimizer_dict.
        self.opt_class = getattr(module, optimizer_dict[self.optimizer])

    def _time_exceeded(self, start_time, max_time):
        """Return True if a time limit is set and has been exceeded.

        A falsy ``max_time`` (``None`` or ``0``) means "no limit".
        """
        if not max_time:
            return False
        return (time.time() - start_time) > max_time

    def _initialize_search(self, nth_process):
        """Seed the RNGs, set up the progress bar and build the optimizer."""
        self._set_random_seed(nth_process)

        self.p_bar.init_p_bar(nth_process, self.n_iter, self.model)
        init_positions = self.cand.init.set_start_pos(self.n_positions)
        # NOTE(review): an empty opt_para is passed here although
        # self.opt_para holds the user-supplied optimizer parameters --
        # confirm whether this is intentional.
        self.opt = self.opt_class(init_positions, self.cand.space.dim, opt_para={})

    def _process_arguments(self):
        """Split ``self.optimizer`` into a name and its parameter dict.

        When ``self.optimizer`` is a dict, its first key becomes the
        optimizer name and the associated value becomes ``self.opt_para``;
        otherwise ``self.opt_para`` is empty. ``self.n_positions`` is derived
        from ``self.opt_para`` in both cases.
        """
        if isinstance(self.optimizer, dict):
            optimizer = list(self.optimizer)[0]
            self.opt_para = self.optimizer[optimizer]
            self.optimizer = optimizer
        else:
            self.opt_para = {}

        self.n_positions = self._get_n_positions()

    def _get_n_positions(self):
        """Return the number of start positions requested via ``self.opt_para``.

        The keys below are checked in order; if several are present the last
        one wins. Defaults to 1 when none is present.
        """
        n_positions_strings = [
            "n_positions",
            "system_temperatures",
            "n_particles",
            "individuals",
        ]

        n_positions = 1
        for n_pos_name in n_positions_strings:
            if n_pos_name in self.opt_para:
                n_positions = self.opt_para[n_pos_name]
                # "system_temperatures" holds a sequence (one entry per
                # system), so the count is its length. The key name must be
                # compared here, not the looked-up value.
                if n_pos_name == "system_temperatures":
                    n_positions = len(n_positions)

        return n_positions

    def _save_results(self):
        """Copy timings, best result and configuration onto ``self.res``."""
        self.res.nth_process = self.nth_process
        self.res.eval_times = self.eval_times
        self.res.iter_times = self.iter_times
        self.res.memory_dict_new = self.cand.memory_dict_new
        self.res.para_best = self.cand.para_best
        self.res.score_best = self.cand.score_best
        self.res.model = self.model
        self.res.search_space = self.search_space
        self.res.memory = self.memory

    def _set_random_seed(self, nth_process):
        """Sets the random seed separately for each thread (to avoid getting the same results in each thread)"""
        if self.random_state is None:
            # An explicit 64-bit dtype keeps the upper bound valid on
            # platforms whose default NumPy integer is 32 bits wide; int()
            # yields a plain Python int that random.seed accepts.
            self.random_state = int(
                np.random.randint(0, high=2 ** 32 - 2, dtype=np.int64)
            )

        # np.random.seed only accepts values in [0, 2**32 - 1]; wrap around so
        # a large base seed plus the process offset cannot overflow.
        seed = int(self.random_state + nth_process) % (2 ** 32)

        random.seed(seed)
        np.random.seed(seed)

    def search(self, start_time, max_time, nth_process):
        """Run the initialization and iteration loops and return the results.

        Parameters
        ----------
        start_time : reference timestamp (as from ``time.time()``) against
            which ``max_time`` is measured.
        max_time : time limit in seconds; a falsy value disables the limit.
        nth_process : index of this process (offsets the random seed and is
            passed to the progress bar).

        Returns
        -------
        ``self.res`` after it has been filled by ``_save_results``.
        """
        self._initialize_search(nth_process)
        n_init = len(self.opt.init_positions)

        # loop to initialize N positions
        for nth_init in range(n_init):
            start_time_iter = time.time()
            pos_new = self.opt.init_pos(nth_init)

            start_time_eval = time.time()
            score_new = self.cand.get_score(pos_new, nth_init)
            self.eval_times.append(time.time() - start_time_eval)

            self.opt.evaluate(score_new)
            self.iter_times.append(time.time() - start_time_iter)

        # loop to do the iterations
        for nth_iter in range(n_init, self.n_iter):
            start_time_iter = time.time()
            pos_new = self.opt.iterate(nth_iter)

            start_time_eval = time.time()
            score_new = self.cand.get_score(pos_new, nth_iter)
            self.eval_times.append(time.time() - start_time_eval)

            self.opt.evaluate(score_new)
            # Per-iteration duration, consistent with the init loop above
            # (not the time elapsed since the search started).
            self.iter_times.append(time.time() - start_time_iter)

            # The time limit is only checked during the iteration phase.
            if self._time_exceeded(start_time, max_time):
                break

        self.p_bar.close_p_bar()
        self._save_results()

        return self.res
160
161