Passed
Push — master ( 4bb259...06915f )
by Simon
04:09
created

SearchProcess.__init__()   A

Complexity

Conditions 1

Size

Total Lines 34
Code Lines 30

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 30
nop 13
dl 0
loc 34
rs 9.16
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameters also tend to become inconsistent when you need more or different data.

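There are several approaches to avoid long parameter lists: introduce a parameter object that groups related values, pass a whole object instead of several of its fields, or let the method obtain a value itself (for example via a method call) instead of receiving it as a parameter.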

1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import time
6
import random
7
import numpy as np
8
import pandas as pd
9
10
from importlib import import_module
11
12
13
# Maps the optimizer names accepted by SearchProcess to the class names
# that are looked up (via getattr) in the ``gradient_free_optimizers`` package.
optimizer_dict = dict(
    HillClimbing="HillClimbingOptimizer",
    StochasticHillClimbing="StochasticHillClimbingOptimizer",
    TabuSearch="TabuOptimizer",
    RandomSearch="RandomSearchOptimizer",
    RandomRestartHillClimbing="RandomRestartHillClimbingOptimizer",
    RandomAnnealing="RandomAnnealingOptimizer",
    SimulatedAnnealing="SimulatedAnnealingOptimizer",
    StochasticTunneling="StochasticTunnelingOptimizer",
    ParallelTempering="ParallelTemperingOptimizer",
    ParticleSwarm="ParticleSwarmOptimizer",
    EvolutionStrategy="EvolutionStrategyOptimizer",
    Bayesian="BayesianOptimizer",
    TPE="TreeStructuredParzenEstimators",
    DecisionTree="DecisionTreeOptimizer",
)
29
30
31
class SearchProcess:
    """Run one optimization process and collect its timings and results.

    The optimizer class is resolved by name from ``gradient_free_optimizers``
    through ``optimizer_dict``. Several attributes used by the search methods
    (``self.cand``, ``self.res``, ``self.verb``) are not assigned in this
    class; they must be provided elsewhere (presumably by a subclass) before
    ``search`` / ``print_best_para`` are called.
    """

    def __init__(
        self,
        nth_process,
        p_bar,
        model,
        search_space,
        search_name,
        n_iter,
        training_data,
        optimizer,
        n_jobs,
        init_para,
        memory,
        random_state,
    ):
        """Store the search configuration and resolve the optimizer class.

        ``optimizer`` is either an optimizer name (a key of ``optimizer_dict``)
        or a dict whose first key is that name and whose value holds the
        optimizer parameters; ``_process_arguments`` normalizes it to the name
        and stores the parameters in ``self.opt_para``.

        Raises:
            KeyError: if the optimizer name is not a key of ``optimizer_dict``.
        """
        self.nth_process = nth_process
        self.p_bar = p_bar
        self.model = model
        self.search_space = search_space
        self.search_name = search_name
        self.n_iter = n_iter
        self.training_data = training_data
        self.optimizer = optimizer
        self.n_jobs = n_jobs
        self.init_para = init_para
        self.memory = memory
        self.random_state = random_state

        # Normalizes self.optimizer to a name and sets opt_para / n_positions.
        self._process_arguments()

        self.iter_times = []
        self.eval_times = []

        # Use the *normalized* name: the raw ``optimizer`` argument may be a
        # dict, which is unhashable and cannot index optimizer_dict.
        module = import_module("gradient_free_optimizers")
        self.opt_class = getattr(module, optimizer_dict[self.optimizer])

    def _time_exceeded(self, start_time, max_time):
        """Return True if more than ``max_time`` seconds passed since ``start_time``.

        A falsy ``max_time`` (``None`` or ``0``) means "no time limit".
        """
        if not max_time:
            return False
        return time.time() - start_time > max_time

    def _initialize_search(self, nth_process):
        """Seed the RNGs, start the progress bar and build the optimizer."""
        self._set_random_seed(nth_process)

        self.p_bar.init_p_bar(nth_process, self.n_iter, self.model)
        init_positions = self.cand.init.set_start_pos(self.n_positions)
        # NOTE(review): self.opt_para (parsed in _process_arguments) is not
        # forwarded here; confirm whether opt_para=self.opt_para is intended.
        self.opt = self.opt_class(init_positions, self.cand.space.dim, opt_para={})

    def _process_arguments(self):
        """Split ``self.optimizer`` into an optimizer name and its parameters.

        If ``self.optimizer`` is a dict, its first key becomes the optimizer
        name and the corresponding value becomes ``self.opt_para``; otherwise
        ``self.opt_para`` is an empty dict. Also sets ``self.n_positions``.
        """
        if isinstance(self.optimizer, dict):
            name, opt_para = list(self.optimizer.items())[0]
            self.optimizer = name
            self.opt_para = opt_para
        else:
            self.opt_para = {}

        self.n_positions = self._get_n_positions()

    def _get_n_positions(self):
        """Return how many positions the optimizer works with.

        Checks ``self.opt_para`` for the known population-size keys below; if
        several are present, the last one in the list wins. Defaults to 1.
        """
        n_positions_strings = [
            "n_positions",
            "system_temperatures",
            "n_particles",
            "individuals",
        ]

        n_positions = 1
        for n_pos_name in n_positions_strings:
            if n_pos_name in self.opt_para:
                n_positions = self.opt_para[n_pos_name]
                # system_temperatures is presumably a sequence with one
                # temperature per position, so its length is the count.
                if n_pos_name == "system_temperatures":
                    n_positions = len(n_positions)

        return n_positions

    def _save_results(self):
        """Copy timings, best parameters/score and the model onto ``self.res``."""
        self.res.eval_times = self.eval_times
        self.res.iter_times = self.iter_times
        self.res.memory_dict_new = self.cand.memory_dict_new
        self.res.para_best = self.cand.para_best
        self.res.score_best = self.cand.score_best
        self.res.model = self.model

    def _set_random_seed(self, nth_process):
        """Seed ``random`` and ``numpy.random`` differently for each process.

        If no ``random_state`` was given, one is drawn and stored so later
        calls reuse it. Adding ``nth_process`` keeps processes from producing
        identical results. The sum is cast to ``int`` (``random.seed`` rejects
        numpy integers on Python 3.11+) and reduced modulo 2**32 because
        ``np.random.seed`` only accepts values in [0, 2**32 - 1].
        """
        if self.random_state is None:
            self.random_state = np.random.randint(0, high=2 ** 32 - 2)

        seed = int(self.random_state + nth_process) % (2 ** 32)
        random.seed(seed)
        np.random.seed(seed)

    def store_memory(self, memory):
        """Placeholder for persisting ``memory``; currently does nothing."""

    def print_best_para(self):
        """Delegate to ``self.verb.info.print_start_point()``.

        ``self.verb`` is not assigned in this class; it must be set elsewhere.
        """
        self.verb.info.print_start_point()

    def search(self, start_time, max_time, nth_process):
        """Evaluate the initial positions, then iterate up to ``n_iter`` steps.

        Args:
            start_time: reference timestamp (``time.time()``) for ``max_time``.
            max_time: time budget in seconds; falsy means unlimited. It is
                only checked in the iteration phase, after each step.
            nth_process: index of this process, used for seeding and the
                progress bar.

        Returns:
            ``self.res`` with timings and best results filled in.
        """
        self._initialize_search(nth_process)

        n_init = len(self.opt.init_positions)

        # Evaluate every initial position.
        for nth_init in range(n_init):
            start_time_iter = time.time()
            pos_new = self.opt.init_pos(nth_init)
            self._evaluate_position(pos_new, nth_init, start_time_iter)

        # Regular optimization iterations.
        for nth_iter in range(n_init, self.n_iter):
            start_time_iter = time.time()
            pos_new = self.opt.iterate(nth_iter)
            self._evaluate_position(pos_new, nth_iter, start_time_iter)

            if self._time_exceeded(start_time, max_time):
                break

        self.p_bar.close_p_bar()
        self._save_results()

        return self.res

    def _evaluate_position(self, pos_new, nth, start_time_iter):
        """Score ``pos_new``, pass the score to the optimizer, record timings."""
        start_time_eval = time.time()
        score_new = self.cand.get_score(pos_new, nth)
        self.eval_times.append(time.time() - start_time_eval)

        self.opt.evaluate(score_new)
        # Duration of this single step, measured from its own start.
        self.iter_times.append(time.time() - start_time_iter)
165
166