1
|
|
|
# Author: Simon Blanke |
2
|
|
|
# Email: [email protected] |
3
|
|
|
# License: MIT License |
4
|
|
|
|
5
|
|
|
import warnings |
6
|
|
|
|
7
|
|
|
|
8
|
|
|
def try_ray_import():
    """Try to import ray and report whether it is already initialized.

    Returns
    -------
    tuple
        ``(ray_module, initialized)``. ``ray_module`` is the imported ``ray``
        module and ``initialized`` is ``True`` when ``ray.is_initialized()``
        is truthy. If an ``ImportError`` is raised, an ``ImportWarning`` is
        issued and ``(None, False)`` is returned instead.
    """
    try:
        import ray as ray_module

        # Kept inside the ``try`` so an ImportError raised while querying ray
        # is handled the same way as a failed import.
        initialized = bool(ray_module.is_initialized())
    except ImportError:
        warnings.warn("failed to import ray", ImportWarning)
        return None, False

    return ray_module, initialized
22
|
|
|
|
23
|
|
|
|
24
|
|
|
class Distribution:
    """Run an optimizer's search either in-process or through ray.

    After a successful run the five values returned by the optimizer's
    ``search`` are stored on the instance as ``results``, ``pos``,
    ``scores``, ``eval_times`` and ``opt_times``.
    """

    def dist(self, optimizer_class, _main_args_, _opt_args_):
        """Run the search with ray if it is initialized, otherwise locally.

        Parameters
        ----------
        optimizer_class : type
            Class instantiated as ``optimizer_class(_main_args_, _opt_args_)``
            and exposing a ``search`` method.
        _main_args_ : object
            Passed through to the optimizer; ``dist_ray`` also reads its
            ``n_jobs`` attribute.
        _opt_args_ : object
            Passed through to the optimizer.
        """
        ray, rayInit = try_ray_import()

        if rayInit:
            self.dist_ray(optimizer_class, _main_args_, _opt_args_, ray)
        else:
            self.dist_default(optimizer_class, _main_args_, _opt_args_)

    def dist_default(self, optimizer_class, _main_args_, _opt_args_):
        """Instantiate the optimizer in this process and store its search output."""
        _optimizer_ = optimizer_class(_main_args_, _opt_args_)
        (
            self.results,
            self.pos,
            self.scores,
            self.eval_times,
            self.opt_times,
        ) = _optimizer_.search()

    def dist_ray(self, optimizer_class, _main_args_, _opt_args_, ray):
        """Run ``_main_args_.n_jobs`` searches as ray actors and store the first result.

        Parameters
        ----------
        optimizer_class : type
            Class wrapped with ``ray.remote`` and instantiated once per job.
        _main_args_ : object
            Passed to every actor; its ``n_jobs`` attribute sets the number
            of actors created.
        _opt_args_ : object
            Passed to every actor.
        ray : module
            The imported ``ray`` module (or an object with the same
            ``remote`` / ``get`` / ``shutdown`` interface).

        Notes
        -----
        ``ray.shutdown()`` is called when this method exits, including when
        actor creation or a search raises, so a failed run does not leave
        ray running.
        """
        try:
            remote_optimizer = ray.remote(optimizer_class)
            opts = [
                remote_optimizer.remote(_main_args_, _opt_args_)
                for _ in range(_main_args_.n_jobs)
            ]
            # Each actor's search receives its job index and rayInit=True.
            searches = [
                opt.search.remote(job, rayInit=True)
                for job, opt in enumerate(opts)
            ]
            # All searches are awaited, but only the first job's output is kept.
            # NOTE(review): results of the remaining jobs are discarded here --
            # confirm this is intended.
            (
                self.results,
                self.pos,
                self.scores,
                self.eval_times,
                self.opt_times,
            ) = ray.get(searches)[0]
        finally:
            ray.shutdown()
53
|
|
|
|