Passed
Push — master ( 7b482c...5d2e95 )
by Simon
08:24
created

hyperactive.process_arguments.stop_warnings()   A

Complexity

Conditions 1

Size

Total Lines 8
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 5
nop 0
dl 0
loc 8
rs 10
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
6
import random
7
import numpy as np
8
import multiprocessing
9
10
from .checks import check_hyperactive_para, check_search_para
11
12
13
def stop_warnings():
    """Silence all ``warnings.warn`` calls for the rest of the process.

    scikit-learn can emit the same warning many times during a search, so
    ``warnings.warn`` is replaced module-wide by a function that accepts any
    arguments and does nothing. The replacement is not undone by this module.
    """
    import warnings

    def _silent_warn(*_args, **_kwargs):
        # Intentionally ignore every warning request.
        return None

    warnings.warn = _silent_warn
21
22
23
class ProcessArguments:
    """Normalise the arguments of a search call into attributes.

    Positional ``args`` are folded into ``kwargs`` (a callable becomes
    ``kwargs["objective_function"]``, a dict becomes
    ``kwargs["search_space"]``), missing keys are filled with defaults, and
    the frequently used entries are exposed as attributes.

    Note: ``kwargs`` is stored by reference and mutated in place.
    """

    def __init__(self, args, kwargs, random_state):
        """Collect and normalise the search arguments.

        Args:
            args: positional arguments; callables are stored as
                ``kwargs["objective_function"]`` and dicts as
                ``kwargs["search_space"]`` (a later one overrides an earlier).
            kwargs: keyword arguments; defaults are filled in place.
            random_state: base seed used by ``set_random_seed``; ``None``
                means one is drawn the first time a seed is needed.

        Raises:
            KeyError: if no search space was given via ``args`` or ``kwargs``.
            IndexError: if ``optimizer`` is an empty dict.
        """
        self.kwargs = kwargs
        self._set_default()
        self._add_args2kwargs(args)

        self.function_parameter = self.kwargs["function_parameter"]
        self.search_space = self.kwargs["search_space"]
        self.optimizer = self.kwargs["optimizer"]
        self.random_state = random_state
        self.n_jobs = self.kwargs["n_jobs"]
        self.init_para = self.kwargs["init_para"]

        self.set_n_jobs()

        if isinstance(self.optimizer, dict):
            # ``{"OptimizerName": {...parameters...}}`` form: only the first
            # key is used; its value holds the optimizer parameters.
            optimizer = list(self.optimizer)[0]
            self.opt_para = self.optimizer[optimizer]
            self.optimizer = optimizer
        else:
            self.opt_para = {}

        self.n_positions = self._get_n_positions()

    def _get_n_positions(self):
        """Return the number of positions implied by ``self.opt_para``.

        Defaults to 1. The value of ``"system_temperatures"`` is treated as a
        sequence and its length is used; the other keys hold the count
        directly. If several keys are present, the last one in the lookup
        order below wins.
        """
        n_positions_strings = (
            "n_positions",
            "system_temperatures",
            "n_particles",
            "individuals",
        )

        n_positions = 1
        for n_pos_name in n_positions_strings:
            if n_pos_name in self.opt_para:
                n_positions = self.opt_para[n_pos_name]
                # Compare the key (not the value) so a list of temperatures
                # is converted into a count.
                if n_pos_name == "system_temperatures":
                    n_positions = len(n_positions)

        return n_positions

    def set_n_jobs(self):
        """Sets the number of jobs to run in parallel.

        ``-1`` or any value above the CPU count is replaced by the CPU count.
        """
        num_cores = multiprocessing.cpu_count()
        if self.n_jobs == -1 or self.n_jobs > num_cores:
            self.n_jobs = num_cores

    def get_process_para(self):
        """Placeholder; currently does nothing and returns None."""
        pass

    @staticmethod
    def _check_parameter(kwargs):
        """Placeholder for parameter validation; currently a no-op.

        A staticmethod so it can be called both on the class and on an
        instance (the original signature lacked ``self``).
        """
        pass

    def _add_args2kwargs(self, args):
        """Fold positional arguments into ``self.kwargs``.

        Callables become ``"objective_function"`` and dicts become
        ``"search_space"``; other argument types are ignored.
        """
        for arg in args:
            if callable(arg):
                self.kwargs["objective_function"] = arg
            elif isinstance(arg, dict):
                self.kwargs["search_space"] = arg

    def set_random_seed(self, thread):
        """Sets the random seed separately for each thread (to avoid getting the same results in each thread).

        Both ``random`` and ``numpy.random`` are seeded with
        ``random_state + thread`` (wrapped into ``[0, 2**32 - 1]``). If
        ``random_state`` is None, a base seed is drawn once and stored.
        """
        if self.random_state is None:
            # dtype=np.int64: the default integer dtype is 32-bit on some
            # platforms, where ``high`` would be out of bounds.
            # int(...): random.seed rejects NumPy integers on Python >= 3.11.
            self.random_state = int(
                np.random.randint(0, high=2 ** 32 - 2, dtype=np.int64)
            )

        # np.random.seed only accepts seeds in [0, 2**32 - 1]; wrap so a large
        # base seed plus the thread offset stays valid. int() keeps the seed a
        # plain Python int for random.seed.
        seed = int(self.random_state + thread) % (2 ** 32)
        random.seed(seed)
        np.random.seed(seed)

    def _set_default(self):
        """Fill missing keyword arguments with their default values."""
        self.kwargs.setdefault("function_parameter", None)
        self.kwargs.setdefault("memory", None)
        self.kwargs.setdefault("optimizer", "RandomSearch")
        self.kwargs.setdefault("n_iter", 10)
        self.kwargs.setdefault("n_jobs", 1)
        self.kwargs.setdefault("init_para", [])
        self.kwargs.setdefault("distribution", None)
101