Passed
Push — master ( b478a2...16a37b )
by Simon
01:57 queued 10s
created

InitSearchPosition._scatter_train()   A

Complexity

Conditions 2

Size

Total Lines 14
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 10
nop 2
dl 0
loc 14
rs 9.9
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
from .util import sort_for_best
6
import numpy as np
7
8
9
class InitSearchPosition:
    """Pick the position in the search space from which a search starts.

    Depending on ``space.init_type`` the start position comes from a warm
    start configuration (``"warm_start"``), from the best of several randomly
    scattered positions (``"scatter_init"``), or from a single random
    position (any other value).
    """

    def __init__(self, space, model, _main_args_):
        """Store the collaborators and the training data.

        Args:
            space: Search-space object. This class reads its ``search_space``,
                ``init_para`` and ``init_type`` attributes and calls its
                ``get_random_pos``, ``get_random_pos_scalar`` and ``pos2para``
                methods.
            model: Object whose ``train_model(para)`` returns a
                ``(score, eval_time, model)`` triple.
            _main_args_: Object providing ``init_config``, ``X`` and ``y``.
        """
        self._space_ = space
        self._model_ = model
        self.init_config = _main_args_.init_config
        self.X = _main_args_.X
        self.y = _main_args_.y

    def _warm_start(self):
        """Build a start position from the warm start parameters.

        For every hyperparameter of the search space the index of its warm
        start value inside the list of candidate values is used. A
        hyperparameter without a warm start value gets a random scalar
        position instead.

        Returns:
            numpy.ndarray: One position entry per hyperparameter, in the
            iteration order of ``space.search_space``.

        Raises:
            ValueError: If a warm start value is not among the candidate
                values of its hyperparameter (raised by ``.index``).
        """
        search_space = self._space_.search_space
        init_para = self._space_.init_para

        pos = []
        for hyperpara_name in search_space:
            if hyperpara_name in init_para:
                search_position = search_space[hyperpara_name].index(
                    init_para[hyperpara_name]
                )
            else:
                search_position = self._space_.get_random_pos_scalar(hyperpara_name)
            pos.append(search_position)

        return np.array(pos)

    def _set_start_pos(self, _info_):
        """Return the start position selected by ``space.init_type``.

        Args:
            _info_: Object exposing ``warm_start()``, ``scatter_start()`` and
                ``random_start()``; the one matching the chosen strategy is
                called before the position is computed (presumably to report
                the strategy to the user -- confirm against its definition).

        Returns:
            The start position produced by the selected strategy.
        """
        init_type = self._space_.init_type

        if init_type == "warm_start":
            _info_.warm_start()
            return self._warm_start()

        if init_type == "scatter_init":
            _info_.scatter_start()
            return self._scatter_init()

        _info_.random_start()
        return self._space_.get_random_pos()

    def _scatter_init(self):
        """Evaluate several random positions and return the best one.

        The number of candidates is ``space.init_para["scatter_init"]``.

        Returns:
            The first position of the ranking returned by ``sort_for_best``.
        """
        n_candidates = self._space_.init_para["scatter_init"]
        pos_list = [self._space_.get_random_pos() for _ in range(n_candidates)]

        pos_best_list, score_best_list = self._scatter_train(pos_list)
        pos_best_sorted, _ = sort_for_best(pos_best_list, score_best_list)

        return pos_best_sorted[0]

    def _scatter_train(self, pos_list):
        """Train the model once for every candidate position.

        Args:
            pos_list: Positions to evaluate.

        Returns:
            tuple: ``(positions, scores)`` as two lists of equal length, in
            the order of ``pos_list``.
        """
        # NOTE(review): the sub-sample is drawn but never handed to
        # ``train_model`` (whose signature is not visible here). The call is
        # kept so the global NumPy random state advances exactly as before.
        self._get_random_sample(self.X, self.y)

        pos_best_list = []
        score_best_list = []
        for pos in pos_list:
            para = self._space_.pos2para(pos)
            score, _eval_time, _model = self._model_.train_model(para)

            pos_best_list.append(pos)
            score_best_list.append(score)

        return pos_best_list, score_best_list

    def _get_random_sample(self, X, y):
        """Draw a random subset of the training data without replacement.

        The subset holds ``len(X) // init_para["scatter_init"]`` rows; the
        same row indices are used for ``X`` and ``y``.

        Args:
            X: Feature data.
            y: Target data.

        Returns:
            tuple: ``(X_sample, y_sample)``. When ``X`` or ``y`` is not a
            NumPy array the inputs are returned unchanged, so callers can
            always unpack the result.
        """
        if not (isinstance(X, np.ndarray) and isinstance(y, np.ndarray)):
            # Only NumPy arrays are sub-sampled; hand anything else back as
            # is instead of implicitly returning None.
            return X, y

        n_samples = X.shape[0] // self._space_.init_para["scatter_init"]
        idx = np.random.choice(len(X), n_samples, replace=False)

        return X[idx], y[idx]
81