Total Complexity | 7 |
Total Lines | 54 |
Duplicated Lines | 0 % |
Changes | 0 |
1 | # Author: Simon Blanke |
||
2 | # Email: [email protected] |
||
3 | # License: MIT License |
||
4 | |||
5 | import warnings |
||
6 | |||
def try_ray_import():
    """Try to import ray and report whether it has been initialized.

    Returns:
        A ``(ray, rayInit)`` tuple. ``ray`` is the imported module, or
        ``None`` when the import fails. ``rayInit`` is ``True`` only when
        the module was imported and ``ray.is_initialized()`` is truthy.

    An ``ImportWarning`` is issued when an ``ImportError`` is raised while
    importing or querying ray.
    """
    try:
        import ray

        # Queried inside the ``try`` so an ImportError raised here is
        # handled the same way as a failed import.
        ray_is_up = bool(ray.is_initialized())
    except ImportError:
        warnings.warn("failed to import ray", ImportWarning)
        return None, False

    return ray, ray_is_up
||
21 | |||
22 | |||
def dist(optimizer_class, _main_args_, _opt_args_):
    """Dispatch the optimization run to ray or to the local fallback.

    ``dist_ray`` is used when ``try_ray_import`` reports an initialized ray
    session; otherwise ``dist_default`` is used. All three arguments are
    forwarded unchanged to the selected function. Nothing is returned.
    """
    # Only the initialization flag is needed here; the module is ignored.
    _, ray_is_up = try_ray_import()

    run = dist_ray if ray_is_up else dist_default
    run(optimizer_class, _main_args_, _opt_args_)
||
30 | |||
31 | |||
def dist_default(optimizer_class, _main_args_, _opt_args_):
    """Run a single optimizer in the current process.

    Builds ``optimizer_class(_main_args_, _opt_args_)`` and calls its
    ``search()`` method, which must yield exactly three values.
    Nothing is returned.
    """
    searcher = optimizer_class(_main_args_, _opt_args_)
    outcome = searcher.search()

    # Unpacking checks that ``search()`` produced a 3-item result; the
    # values themselves are not used any further here.
    params_results, pos_list, score_list = outcome
||
39 | |||
def dist_ray(optimizer_class, _main_args_, _opt_args_):
    """Run one remote optimizer per job through ray.

    Wraps ``optimizer_class`` with ``ray.remote``, creates
    ``_main_args_.n_jobs`` remote instances, starts ``search`` on each of
    them and fetches the results. Ray is shut down afterwards. Nothing is
    returned.

    Args:
        optimizer_class: Class to be wrapped by ``ray.remote``; instances
            are built from ``(_main_args_, _opt_args_)`` and must provide
            a ``search(job, rayInit=...)`` method.
        _main_args_: Forwarded to the optimizer; its ``n_jobs`` attribute
            sets how many remote instances are created.
        _opt_args_: Forwarded to the optimizer.

    Raises:
        ImportError: If ray cannot be imported.
    """
    # ``ray`` and ``rayInit`` are locals of ``dist`` and were never defined
    # in this scope, so they are obtained here instead of relying on names
    # that do not exist in this function.
    ray, rayInit = try_ray_import()
    if ray is None:
        raise ImportError("ray is required to run dist_ray")

    remote_optimizer = ray.remote(optimizer_class)

    opts = [
        remote_optimizer.remote(_main_args_, _opt_args_)
        for _ in range(_main_args_.n_jobs)
    ]
    searches = [
        opt.search.remote(job, rayInit=rayInit) for job, opt in enumerate(opts)
    ]

    # All searches are fetched, but only the first job's result is unpacked.
    # NOTE(review): confirm that discarding the other jobs is intended.
    params_results, pos_list, score_list = ray.get(searches)[0]

    ray.shutdown()