Passed
Push — master ( cba5ed...7e35c2 )
by Simon
04:17
created

hyperactive.data_tools.ltm_wrapper   A

Complexity

Total Complexity 10

Size/Duplication

Total Lines 61
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 42
dl 0
loc 61
rs 10
c 0
b 0
f 0
wmc 10

7 Methods

Rating   Name   Duplication   Size   Complexity  
A LongTermMemory.save_on_finish() 0 2 1
A LongTermMemory._no_ltm_wrapper() 0 2 1
A LongTermMemory.load() 0 2 1
A LongTermMemory.save_on_iteration() 0 2 1
A LongTermMemory._ltm_wrapper() 0 11 2
A LongTermMemory.__init__() 0 18 3
A LongTermMemory.init_study() 0 7 1
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import os
6
from hyperactive_long_term_memory import LongTermMemory as _LongTermMemory_
7
8
9
class LongTermMemory:
    """Adapter around the ``hyperactive_long_term_memory`` backend.

    The instance stores the study/model identifiers, creates the backend
    object and exposes ``ltm_obj_func_wrapper``, a callable selected by
    ``save_on`` that is either a no-op (``"finish"``) or forwards every
    evaluated result to the backend (``"iteration"``).
    """

    def __init__(
        self,
        study_id,
        model_id,
        save_on="finish",
    ):
        """Create the backend and select the per-evaluation wrapper.

        Parameters
        ----------
        study_id
            Identifier forwarded to the backend's ``init_study``.
        model_id
            Identifier forwarded to the backend's ``init_study``.
        save_on : {"finish", "iteration"}
            ``"finish"`` makes ``ltm_obj_func_wrapper`` a no-op,
            ``"iteration"`` makes it store every evaluated result.

        Raises
        ------
        ValueError
            If ``save_on`` is neither ``"finish"`` nor ``"iteration"``.
        """
        # Validate before touching the backend so that a bad mode fails
        # here instead of surfacing later as a missing attribute.
        if save_on == "finish":
            wrapper = self._no_ltm_wrapper
        elif save_on == "iteration":
            wrapper = self._ltm_wrapper
        else:
            raise ValueError(
                "save_on must be 'finish' or 'iteration', got {!r}".format(
                    save_on
                )
            )

        self.study_id = study_id
        self.model_id = model_id
        self.save_on = save_on
        self.ltm_obj_func_wrapper = wrapper

        # NOTE(review): the backend is created with path="." exactly as
        # before; an earlier unused local computed this module's directory,
        # so confirm which location was actually intended.
        self.ltm_origin = _LongTermMemory_(path=".")

    def _no_ltm_wrapper(self, results, para):
        """Ignore the evaluation; used when ``save_on == "finish"``."""
        pass

    def _ltm_wrapper(self, results, para):
        """Forward one evaluated parameter set to ``save_on_iteration``.

        Parameters
        ----------
        results
            Either the score itself, or a tuple whose first item is the
            score and whose second item is a dict of additional results.
        para : dict
            The evaluated parameters.

        Notes
        -----
        Reads ``self.nth_process``, which is set by ``init_study``.
        """
        if isinstance(results, tuple):
            score = results[0]
            results_dict = results[1]
        else:
            score = results
            results_dict = {}

        # Build the merged record without writing "score" into the dict the
        # caller handed in; keys, values and ordering match the previous
        # in-place variant.
        ltm_dict = {**para, **results_dict, "score": score}
        self.save_on_iteration(ltm_dict, self.nth_process)

    def init_study(self, objective_function, search_space, nth_process):
        """Remember ``nth_process`` and initialise the study on the backend.

        ``objective_function`` and ``search_space`` are passed through
        together with the ``study_id`` and ``model_id`` given at
        construction time.
        """
        self.nth_process = nth_process
        self.ltm_origin.init_study(
            objective_function,
            search_space,
            study_id=self.study_id,
            model_id=self.model_id,
        )

    def load(self):
        """Return whatever the backend's ``load()`` returns."""
        return self.ltm_origin.load()

    def save_on_finish(self, dataframe):
        """Pass ``dataframe`` to the backend's ``save_on_finish``."""
        self.ltm_origin.save_on_finish(dataframe)

    def save_on_iteration(self, data_dict, nth_process):
        """Pass one record and the process index to the backend."""
        self.ltm_origin.save_on_iteration(data_dict, nth_process)
61