Passed
Push — master ( 051bec...b89548 )
by Simon
01:31
created

LongTermMemory.save_on_iteration()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 2
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import os
6
from hyperactive_long_term_memory import LongTermMemory as _LongTermMemory_
7
8
9
class LongTermMemory:
    """Adapter around ``hyperactive_long_term_memory.LongTermMemory``.

    Every storage operation (load, save, cleanup, data-type setup) is delegated
    to the wrapped object held in ``self.ltm_origin``. The adapter also exposes
    ``ltm_obj_func_wrapper``, a callback of the form ``(results, para)``. The
    ``save_on`` argument chooses between two behaviors for it: a no-op, or
    forwarding each result to ``save_on_iteration``.
    """

    def __init__(
        self,
        model_id,
        experiment_id="default",
        save_on="finish",
    ):
        """Create the adapter and select the per-iteration callback.

        Parameters
        ----------
        model_id
            Forwarded unchanged to the wrapped long-term-memory object.
        experiment_id : str, default "default"
            Forwarded unchanged to the wrapped long-term-memory object.
        save_on : {"finish", "iteration"}, default "finish"
            ``"finish"`` makes ``ltm_obj_func_wrapper`` a no-op. Presumably
            the caller then persists everything via ``save_on_finish``.
            ``"iteration"`` makes it forward every result to
            ``save_on_iteration``.

        Raises
        ------
        ValueError
            If ``save_on`` is neither ``"finish"`` nor ``"iteration"``.
        """
        # Resolve the callback first. A typo in ``save_on`` then fails here,
        # before the wrapped object is constructed, instead of surfacing later
        # as a missing ``ltm_obj_func_wrapper`` attribute.
        if save_on == "finish":
            wrapper = self._no_ltm_wrapper
        elif save_on == "iteration":
            wrapper = self._ltm_wrapper
        else:
            raise ValueError(
                f"save_on must be 'finish' or 'iteration', got {save_on!r}"
            )

        # The wrapped object always receives "." as its ``path`` argument.
        self.ltm_origin = _LongTermMemory_(
            model_id=model_id, experiment_id=experiment_id, path="."
        )
        self.save_on = save_on
        self.ltm_obj_func_wrapper = wrapper

    def _no_ltm_wrapper(self, results, para):
        """Per-iteration callback for ``save_on="finish"``: does nothing."""
        pass

    def _ltm_wrapper(self, results, para):
        """Merge one iteration's parameters and results and save them.

        Parameters
        ----------
        results
            Either a bare score, or a tuple whose first element is the score
            and whose second element is a dict of extra result values.
        para : dict
            Parameter values of this iteration.

        The saved record is ``{**para, **extra, "score": score}``. The
        caller's ``extra`` dict is copied, so it is never modified.
        """
        if isinstance(results, tuple):
            score = results[0]
            # Copy so that adding "score" does not mutate the caller's dict.
            results_dict = dict(results[1])
        else:
            score = results
            results_dict = {}

        results_dict["score"] = score
        ltm_dict = {**para, **results_dict}
        self.save_on_iteration(ltm_dict)

    def clean_files(self):
        """Delegate to ``clean_files`` of the wrapped object."""
        self.ltm_origin.clean_files()

    def init_data_types(self, search_space):
        """Delegate ``search_space`` to ``init_data_types`` of the wrapped object."""
        self.ltm_origin.init_data_types(search_space)

    def load(self):
        """Return whatever ``load()`` of the wrapped object returns."""
        return self.ltm_origin.load()

    def save_on_finish(self, dataframe):
        """Delegate ``dataframe`` to ``save_on_finish`` of the wrapped object."""
        self.ltm_origin.save_on_finish(dataframe)

    def save_on_iteration(self, data_dict):
        """Delegate ``data_dict`` to ``save_on_iteration`` of the wrapped object."""
        self.ltm_origin.save_on_iteration(data_dict)
59