Passed
Push — master ( 626f23...be3b1e )
by Simon
01:48
created

EvolutionStrategyOptimizer.__init__()   A

Complexity

Conditions 1

Size

Total Lines 16
Code Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 13
nop 6
dl 0
loc 16
rs 9.75
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import random
6
import numpy as np
7
8
from .base_population_optimizer import BasePopulationOptimizer
9
from ...search import Search
10
from ._individual import Individual
11
12
13
class EvolutionStrategyOptimizer(BasePopulationOptimizer, Search):
    """Population-based optimizer mixing per-individual steps and crossover.

    On every iteration one individual is picked at random.  It then either
    performs its own ``iterate()`` step ("mutation") or its position is
    recombined with a second individual's position and the result is
    assigned to the lowest-ranked individual ("crossover").
    """

    name = "Evolution Strategy"

    def __init__(
        self,
        *args,
        population=10,
        mutation_rate=0.7,
        crossover_rate=0.3,
        **kwargs,
    ):
        """Store the hyper-parameters and build the population.

        Parameters
        ----------
        *args, **kwargs
            Forwarded unchanged to the parent constructors.
        population
            Stored as ``self.population``; presumably read by
            ``_create_population`` (defined outside this file) to size the
            population -- TODO confirm.
        mutation_rate
            Relative weight of the mutation branch in ``iterate``.
        crossover_rate
            Relative weight of the crossover branch in ``iterate``.
        """
        super().__init__(*args, **kwargs)

        self.population = population
        self.mutation_rate = mutation_rate
        self.crossover_rate = crossover_rate

        self.individuals = self._create_population(Individual)
        # Same list under a second attribute name.
        self.optimizers = self.individuals

    def _random_cross(self, array_list):
        """Recombine several position arrays element-wise at random.

        Parameters
        ----------
        array_list
            Sequence of equally sized position arrays (assumed 1-D, since
            the index mask built here is 1-D).

        Returns
        -------
        numpy.ndarray
            Array of the same size as the inputs in which every element is
            taken from one of the arrays in ``array_list``.
        """
        n_arrays = len(array_list)
        size = array_list[0].size

        if size < 2:
            # Too short to force one element from each parent; a fixed
            # two-element mask would be broadcast by ``np.choose`` and
            # change the length of the result.
            choice = np.random.randint(n_arrays, size=size)
        else:
            # First element from the second array, second element from the
            # first array, so both always contribute; the remaining
            # elements are drawn uniformly from all arrays.
            forced = np.array([1 % n_arrays, 0])
            drawn = np.random.randint(n_arrays, size=size - 2)
            choice = np.concatenate((forced, drawn))

        return np.choose(choice, array_list)

    def _sort_best(self):
        """Return the individuals ordered by ``score_current``, highest first."""
        scores = np.array([ind.score_current for ind in self.individuals])
        order = scores.argsort()[::-1]

        return [self.individuals[idx] for idx in order]

    def _cross(self):
        """Create a new position by crossover and assign it to the worst individual.

        Relies on ``self.n_ind``, ``self.rnd_int``, ``self.ind_sorted`` and
        ``self.p_current`` having been set by ``iterate``.

        Returns
        -------
        numpy.ndarray
            The recombined position, also stored as ``pos_new`` on the
            last-ranked individual, which becomes ``self.p_current``.
        """
        # With more than two individuals the last rank is left out of the
        # partner pool, since that individual is the one being overwritten.
        if len(self.individuals) > 2:
            upper = self.n_ind - 1
        else:
            upper = self.n_ind
        rnd_int2 = random.choice(
            [i for i in range(upper) if i != self.rnd_int]
        )

        p_sec = self.ind_sorted[rnd_int2]
        p_worst = self.ind_sorted[-1]

        parent_positions = [self.p_current.pos_current, p_sec.pos_current]
        pos_new = self._random_cross(parent_positions)

        # The following ``evaluate`` call is routed to the replaced individual.
        self.p_current = p_worst
        p_worst.pos_new = pos_new

        return pos_new

    def init_pos(self, pos):
        """Initialise one individual with ``pos``.

        The individual is selected as ``self.nth_iter`` modulo the
        population size; ``nth_iter`` is not set in this class and is
        presumably maintained by a parent class -- TODO confirm.
        """
        nth_pop = self.nth_iter % len(self.individuals)

        self.p_current = self.individuals[nth_pop]
        self.p_current.init_pos(pos)

    def iterate(self):
        """Perform one optimization step and return the proposed position.

        With a single individual the step is delegated to it directly.
        Otherwise a random individual is selected and the mutation branch
        is taken with probability
        ``mutation_rate / (mutation_rate + crossover_rate)``; the crossover
        branch is taken in the remaining cases.
        """
        self.n_ind = len(self.individuals)

        if self.n_ind == 1:
            self.p_current = self.individuals[0]
            return self.p_current.iterate()

        self.ind_sorted = self._sort_best()
        self.rnd_int = random.randint(0, len(self.ind_sorted) - 1)
        self.p_current = self.ind_sorted[self.rnd_int]

        total_rate = self.mutation_rate + self.crossover_rate
        rand = np.random.uniform(low=0, high=total_rate)

        if rand <= self.mutation_rate:
            return self.p_current.iterate()
        return self._cross()

    def evaluate(self, score_new):
        """Forward ``score_new`` to the individual selected by the last step."""
        self.p_current.evaluate(score_new)
103