Passed
Push — master ( 06915f...21ec56 )
by Simon
01:13
created

hyperactive.search.Search._get_results()   A

Complexity

Conditions 2

Size

Total Lines 30
Code Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 24
nop 2
dl 0
loc 30
rs 9.304
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
6
import numpy as np
7
import pandas as pd
8
9
from multiprocessing import Pool
10
from importlib import import_module
11
12
13
class Search:
    """Run one or more search processes and collect their results.

    Each entry of ``search_processes`` is expected to provide a
    ``search(start_time, max_time, nth_process)`` method. The objects that
    method returns are read by ``_get_results`` and ``_save_memory``.
    """

    def __init__(self, function_parameter, search_processes, verb):
        self.function_parameter = function_parameter
        self.search_processes = search_processes
        # NOTE(review): ``verb`` is stored but not read by any method in this
        # class; result printing in ``_get_results`` is unconditional.
        self.verb = verb

        self.n_processes = len(search_processes)
        self._n_process_range = range(0, self.n_processes)

        # NOTE(review): of the dicts below, only ``position_results`` is
        # populated by this class (in ``_get_results``); the others stay empty
        # here. Kept for backward compatibility with external readers.
        self.results = {}
        self.eval_times = {}
        self.iter_times = {}
        self.best_scores = {}
        self.pos_list = {}
        self.score_list = {}
        self.position_results = {}

    def _get_results(self, results_list):
        """Collect per-search results into dicts keyed by ``search_name``.

        The ``*_dict`` attributes are reset on every call.
        ``self.position_results`` is updated in place, so entries from
        earlier calls persist. A short summary of each result is printed.

        Args:
            results_list: iterable of result objects exposing ``search_name``,
                ``eval_times``, ``iter_times``, ``para_best``, ``score_best``,
                ``memory_dict_new``, ``search_space``, ``nth_process`` and
                ``model`` (which must have a ``__name__``).
        """
        self.eval_times_dict = {}
        self.iter_times_dict = {}
        self.para_best_dict = {}
        self.score_best_dict = {}
        self.memory_dict_new = {}

        for results in results_list:
            search_name = results.search_name

            self.eval_times_dict[search_name] = results.eval_times
            self.iter_times_dict[search_name] = results.iter_times
            self.para_best_dict[search_name] = results.para_best
            self.score_best_dict[search_name] = results.score_best
            self.memory_dict_new[search_name] = results.memory_dict_new
            self.position_results[search_name] = self._memory_dict2dataframe(
                results.memory_dict_new, results.search_space
            )

            print(
                "Process",
                results.nth_process,
                "->",
                results.model.__name__,
                "search results:",
            )
            print("best parameter =", results.para_best)
            print("best score     =", results.score_best, "\n")

    def _run_job(self, nth_process):
        """Run the ``nth_process``-th search process and return its result.

        The selected process is also stored on ``self.process``.
        """
        self.process = self.search_processes[nth_process]
        return self.process.search(self.start_time, self.max_time, nth_process)

    def _run_multiple_jobs(self):
        """Run all search processes in parallel and return their results.

        Each worker receives the integer index of the process it runs. The
        results are returned in process-index order, as ``Pool.map`` does.
        """
        # The context manager guarantees the pool's worker processes are
        # cleaned up; ``map`` blocks until every result is available, so
        # leaving the block (which terminates the pool) loses nothing.
        with Pool(self.n_processes) as pool:
            results_list = pool.map(self._run_job, self._n_process_range)

        for _ in range(self.n_processes // 2 + 2):
            print("\n")  # make room in cmd for prints

        return results_list

    def _memory_dict2dataframe(self, memory_dict, search_space):
        """Convert a ``{position_tuple: result}`` mapping into a DataFrame.

        The position tuples become one column per key of ``search_space``, in
        the dict's iteration order. The result values are expanded by
        ``pd.DataFrame`` and concatenated column-wise to the right.

        Args:
            memory_dict: mapping from position tuples (one entry per search
                space dimension) to per-position results.
            search_space: mapping whose keys name the position columns.

        Returns:
            A DataFrame. It is empty, with just the position columns, when
            ``memory_dict`` is empty.
        """
        columns = list(search_space.keys())

        if not memory_dict:
            return pd.DataFrame([], columns=columns)

        pos_tuple_list = list(memory_dict.keys())
        result_list = list(memory_dict.values())

        results_df = pd.DataFrame(result_list)
        np_pos = np.array(pos_tuple_list)

        pd_pos = pd.DataFrame(np_pos, columns=columns)
        # Both frames have a default RangeIndex of the same length, so an
        # axis=1 concat aligns rows by position.
        dataframe = pd.concat([pd_pos, results_df], axis=1)

        return dataframe

    def _run(self, start_time, max_time):
        """Run all searches, collect their results and persist memory.

        A single process runs in the current process; more than one uses a
        multiprocessing pool.
        """
        self.start_time = start_time
        self.max_time = max_time

        if len(self.search_processes) == 1:
            results_list = [self._run_job(0)]
        else:
            results_list = self._run_multiple_jobs()

        self._get_results(results_list)
        self._save_memory(results_list)

    def run(self, start_time, max_time):
        """Public entry point; see ``_run``."""
        self._run(start_time, max_time)

    def _save_memory(self, results):
        """Call ``save_long_term_memory()`` on each result whose ``memory`` is "long"."""
        for result in results:
            if result.memory == "long":
                result.save_long_term_memory()
111
112