Passed
Push — master ( cba5ed...7e35c2 )
by Simon
04:17
created

LongTermMemory.init_study()   A

Complexity

Conditions 1

Size

Total Lines 7
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 7
nop 4
dl 0
loc 7
rs 10
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: simon.blanke@yahoo.com
3
# License: MIT License
4
5
import os
6
from hyperactive_long_term_memory import LongTermMemory as _LongTermMemory_
7
8
9
class LongTermMemory:
    """Adapter around ``hyperactive_long_term_memory.LongTermMemory``.

    Stores the study/model identifiers, builds the backend object and, based
    on ``save_on``, selects which callback is exposed as
    ``ltm_obj_func_wrapper``:

    * ``"finish"``    -> ``_no_ltm_wrapper``: a no-op per evaluation
      (presumably results are persisted later via ``save_on_finish``).
    * ``"iteration"`` -> ``_ltm_wrapper``: every evaluation is forwarded to
      ``save_on_iteration``.
    """

    def __init__(
        self,
        study_id,
        model_id,
        save_on="finish",
    ):
        """Create the adapter.

        Parameters
        ----------
        study_id
            Identifier forwarded to the backend's ``init_study``.
        model_id
            Identifier forwarded to the backend's ``init_study``.
        save_on : {"finish", "iteration"}
            Selects the per-evaluation callback (see class docstring).

        Raises
        ------
        ValueError
            If ``save_on`` is not ``"finish"`` or ``"iteration"``.
        """
        # Validate first so an unsupported option fails immediately instead
        # of surfacing later as a missing ``ltm_obj_func_wrapper`` attribute.
        if save_on == "finish":
            wrapper = self._no_ltm_wrapper
        elif save_on == "iteration":
            wrapper = self._ltm_wrapper
        else:
            raise ValueError(
                "save_on must be 'finish' or 'iteration', got {!r}".format(save_on)
            )

        self.study_id = study_id
        self.model_id = model_id
        self.save_on = save_on
        self.ltm_obj_func_wrapper = wrapper

        # The backend is rooted at "." (presumably the current working
        # directory -- confirm against hyperactive_long_term_memory).
        self.ltm_origin = _LongTermMemory_(path=".")

    def _no_ltm_wrapper(self, results, para):
        """No-op callback selected when ``save_on == "finish"``."""

    def _ltm_wrapper(self, results, para):
        """Forward one evaluation to ``save_on_iteration``.

        Parameters
        ----------
        results
            Either a bare score, or a ``(score, extra)`` tuple where
            ``extra`` is a mapping of additional values to store.
        para : mapping
            Parameter values of this evaluation; merged into the stored row.

        The stored row is ``{**para, **extra, "score": score}``. Requires
        ``init_study`` to have run first, because it sets ``nth_process``.
        """
        if isinstance(results, tuple):
            score, results_dict = results[0], results[1]
        else:
            score, results_dict = results, {}

        # Build a fresh dict instead of writing "score" into ``results_dict``,
        # which may be the caller's own object.
        ltm_dict = {**para, **results_dict, "score": score}
        self.save_on_iteration(ltm_dict, self.nth_process)

    def init_study(self, objective_function, search_space, nth_process):
        """Record ``nth_process`` and initialise the study in the backend.

        ``objective_function`` and ``search_space`` are passed through
        unchanged, together with this adapter's ``study_id`` and ``model_id``.
        """
        self.nth_process = nth_process
        self.ltm_origin.init_study(
            objective_function,
            search_space,
            study_id=self.study_id,
            model_id=self.model_id,
        )

    def load(self):
        """Return whatever the backend's ``load()`` returns."""
        return self.ltm_origin.load()

    def save_on_finish(self, dataframe):
        """Pass ``dataframe`` (presumably a pandas DataFrame) to the backend."""
        self.ltm_origin.save_on_finish(dataframe)

    def save_on_iteration(self, data_dict, nth_process):
        """Pass one result row and its process index to the backend."""
        self.ltm_origin.save_on_iteration(data_dict, nth_process)
61