Passed
Push — master ( d03f6d...d5da96 )
by Simon
01:24
created

hyperactive.distribution.try_ray_import()   A

Complexity

Conditions 3

Size

Total Lines 14
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 11
nop 0
dl 0
loc 14
rs 9.85
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import warnings
6
7
8
def try_ray_import():
    """Import ray if available and report whether a ray runtime is already up.

    Returns
    -------
    tuple
        ``(module, initialized)``:
        - ``module`` is the imported ``ray`` module, or ``None`` if the import
          raised ``ImportError``.
        - ``initialized`` is a plain ``bool`` taken from ``ray.is_initialized()``.
          It is always ``False`` when ray is unavailable.

    Notes
    -----
    When the import fails, an ``ImportWarning`` with the message
    "failed to import ray" is issued. Python's default warning filters ignore
    ``ImportWarning``, so it is only visible when warnings are enabled.
    """
    try:
        import ray

        running = bool(ray.is_initialized())
    except ImportError:
        warnings.warn("failed to import ray", ImportWarning)
        return None, False
    return ray, running
22
23
24
class Distribution:
    """Run an optimizer's search either in-process or as ray actors.

    After ``dist`` (or one of the backend methods) returns, the instance holds
    the search outcome in ``results``, ``pos``, ``scores``, ``eval_times`` and
    ``opt_times``.
    """

    def dist(self, optimizer_class, _main_args_, _opt_args_):
        """Choose a backend and run the search.

        Ray is used only when it can be imported *and* ``ray.is_initialized()``
        already returns true. Otherwise the search runs in the current process.
        """
        ray, ray_running = try_ray_import()

        if ray_running:
            self.dist_ray(optimizer_class, _main_args_, _opt_args_, ray)
        else:
            self.dist_default(optimizer_class, _main_args_, _opt_args_)

    def dist_default(self, optimizer_class, _main_args_, _opt_args_):
        """Run a single search in this process and store its five outputs."""
        optimizer = optimizer_class(_main_args_, _opt_args_)
        # search() must return a 5-tuple; the unpacking below raises otherwise.
        self.results, self.pos, self.scores, self.eval_times, self.opt_times = (
            optimizer.search()
        )

    def dist_ray(self, optimizer_class, _main_args_, _opt_args_, ray):
        """Run ``_main_args_.n_jobs`` searches as ray actors and keep job 0's output.

        Each actor is built with ``(_main_args_, _opt_args_)``. Its ``search``
        method is called remotely with the job index and ``rayInit=True``.
        ``ray.shutdown()`` is always called before returning, including when
        actor creation or a remote search raises.
        """
        try:
            remote_class = ray.remote(optimizer_class)
            actors = [
                remote_class.remote(_main_args_, _opt_args_)
                for _ in range(_main_args_.n_jobs)
            ]
            futures = [
                actor.search.remote(job, rayInit=True)
                for job, actor in enumerate(actors)
            ]
            # ray.get blocks until every job has finished; only the first job's
            # 5-tuple is stored.
            # NOTE(review): outputs of the other jobs are discarded — confirm intended.
            (
                self.results,
                self.pos,
                self.scores,
                self.eval_times,
                self.opt_times,
            ) = ray.get(futures)[0]
        finally:
            # Release the ray runtime even on failure, so its worker processes
            # are not left running.
            ray.shutdown()
53