Completed
Push — master ( 872a28...3bf564 )
by Simon
04:34 queued 11s
created

InitSearchPosition._create_warm_start()   A

Complexity

Conditions 2

Size

Total Lines 13
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 8
dl 0
loc 13
rs 10
c 0
b 0
f 0
cc 2
nop 2
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
from ..util import sort_for_best
6
import numpy as np
7
8
9
class InitSearchPosition:
10
    def __init__(self, space, model, warm_start, scatter_init):
11
        self._space_ = space
12
        self._model_ = model
13
        self.warm_start = warm_start
14
        self.scatter_init = scatter_init
15
16
        if self.warm_start:
17
            self.n_warm_start_keys = len(list(self.warm_start.keys()))
18
        else:
19
            self.n_warm_start_keys = 0
20
21
    def _create_warm_start(self, nth_process):
22
        pos = []
23
24
        for layer_key in self._space_.para_space.keys():
25
            layer_str, para_str = layer_key.rsplit(".", 1)
26
27
            search_position = self._space_.para_space[layer_key].index(
28
                *self.warm_start[layer_str][para_str]
29
            )
30
31
            pos.append(search_position)
32
33
        return np.array(pos)
34
35
    def _set_start_pos(self, nth_process, X, y):
36
        if self.warm_start and self.scatter_init:
37
            pos = self._warm_start_scatter_init(nth_process, X, y)
38
        elif self.warm_start:
39
            pos = self._warm_start(nth_process)
40
        elif self.scatter_init:
41
            pos = self._scatter_init(nth_process, X, y)
42
        else:
43
            pos = self._space_.get_random_pos()
44
45
        return pos
46
47
    def _warm_start_scatter_init(self, nth_process, X, y):
48
        if self.n_warm_start_keys > nth_process:
49
            pos = self._create_warm_start(nth_process)
50
        else:
51
            pos = self._scatter_init(nth_process, X, y)
52
53
        return pos
54
55
    def _warm_start(self, nth_process):
56
        if self.n_warm_start_keys > nth_process:
57
            pos = self._create_warm_start(nth_process)
58
        else:
59
            pos = self._space_.get_random_pos()
60
61
        return pos
62
63
    def _scatter_init(self, nth_process, X, y):
        """Start position taken from the ranked results of a random scatter.

        Draws ``self.scatter_init`` random positions, scores each one with
        ``_scatter_train`` and ranks them with ``sort_for_best``. Process
        ``nth_process`` takes the ranked candidate at index
        ``nth_process - n_warm_start_keys``. The offset is zero when no warm
        start is configured.

        When fewer ranked candidates exist than processes asking for one, a
        fresh random position is returned. Previously this case raised
        IndexError. The fallback matches ``_warm_start``, which also returns
        a random position once warm starts are exhausted.

        Args:
            nth_process: index of the calling process.
            X, y: training data passed through to ``_scatter_train``.

        Returns:
            The chosen start position.
        """
        candidates = [
            self._space_.get_random_pos() for _ in range(self.scatter_init)
        ]

        pos_best_list, score_best_list = self._scatter_train(X, y, candidates)
        # NOTE(review): assumes sort_for_best orders positions best-first;
        # confirm against ..util.
        pos_best_sorted, _ = sort_for_best(pos_best_list, score_best_list)

        nth_best_pos = nth_process - self.n_warm_start_keys
        if nth_best_pos >= len(pos_best_sorted):
            # More processes than scatter candidates: fall back to random
            # instead of crashing on an out-of-range index.
            return self._space_.get_random_pos()

        return pos_best_sorted[nth_best_pos]
76
77
    def _scatter_train(self, X, y, pos_list):
78
        pos_best_list = []
79
        score_best_list = []
80
81
        X, y = self._get_random_sample(X, y)
82
83
        for pos in pos_list:
84
            para = self._space_.pos2para(pos)
85
            score, _ = self._model_.train_model(para, X, y)
86
87
            pos_best_list.append(pos)
88
            score_best_list.append(score)
89
90
        return pos_best_list, score_best_list
91
92
    def _get_random_sample(self, X, y):
93
        if isinstance(X, np.ndarray) and isinstance(y, np.ndarray):
94
            n_samples = int(X.shape[0] / self.scatter_init)
95
96
            idx = np.random.choice(np.arange(len(X)), n_samples, replace=False)
97
98
            X_sample = X[idx]
99
            y_sample = y[idx]
100
101
            return X_sample, y_sample
102