Passed
Push — master ( 41483f...5037dc )
by Simon
01:18
created

InitSearchPosition._create_warm_start()   A

Complexity

Conditions 3

Size

Total Lines 20
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 10
dl 0
loc 20
rs 9.9
c 0
b 0
f 0
cc 3
nop 2
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
from ..util import sort_for_best
6
import numpy as np
7
8
9
class InitSearchPosition:
10
    def __init__(self, space, model, _main_args_):
11
        self._space_ = space
12
        self._model_ = model
13
        self.init_config = _main_args_.init_config
14
        self.X = _main_args_.X
15
        self.y = _main_args_.y
16
17
    def _warm_start(self):
18
        pos = []
19
20
        for hyperpara_name in self._space_.search_space.keys():
21
            if hyperpara_name not in list(self._space_.init_para.keys()):
22
                # print(hyperpara_name, "not in warm_start selecting random scalar")
23
                search_position = self._space_.get_random_pos_scalar(hyperpara_name)
24
25
            else:
26
                search_position = self._space_.search_space[hyperpara_name].index(
27
                    self._space_.init_para[hyperpara_name]
28
                )
29
            pos.append(search_position)
30
31
        return np.array(pos)
32
33
    def _set_start_pos(self):
34
        if self._space_.init_type == "warm_start":
35
            pos = self._warm_start()
36
        elif self._space_.init_type == "scatter_init":
37
            pos = self._scatter_init()
38
        else:
39
            pos = self._space_.get_random_pos()
40
41
        return pos
42
43
    def _warm_start_scatter_init(self, nth_process):
44
        if self.n_warm_start_keys > nth_process:
45
            pos = self._create_warm_start()
46
        else:
47
            pos = self._scatter_init()
48
49
        return pos
50
51
    def _scatter_init(self):
        """Pick a start position from ``init_para["scatter_init"]`` random candidates.

        Each candidate is scored by ``_scatter_train``; the candidates are then
        ordered by ``sort_for_best`` and the first one is returned (presumably
        the best-scoring one, per that helper's contract).
        """
        n_candidates = self._space_.init_para["scatter_init"]
        candidates = [self._space_.get_random_pos() for _ in range(n_candidates)]

        positions, scores = self._scatter_train(candidates)
        ranked_positions, _ = sort_for_best(positions, scores)

        return ranked_positions[0]
61
62
    def _scatter_train(self, pos_list):
63
        pos_best_list = []
64
        score_best_list = []
65
66
        X, y = self._get_random_sample(self.X, self.y)
67
68
        for pos in pos_list:
69
            para = self._space_.pos2para(pos)
70
            score, eval_time, model = self._model_.train_model(para)
71
72
            pos_best_list.append(pos)
73
            score_best_list.append(score)
74
75
        return pos_best_list, score_best_list
76
77
    def _get_random_sample(self, X, y):
78
        if isinstance(X, np.ndarray) and isinstance(y, np.ndarray):
79
            n_samples = int(X.shape[0] / self._space_.init_para["scatter_init"])
80
81
            idx = np.random.choice(np.arange(len(X)), n_samples, replace=False)
82
83
            X_sample = X[idx]
84
            y_sample = y[idx]
85
86
            return X_sample, y_sample
87