Completed
Push — master ( 5a91c7...ef4013 )
by Raphael
01:39
created

launch_process()   B

Complexity

Conditions 6

Size

Total Lines 16

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 6
c 1
b 0
f 0
dl 0
loc 16
rs 8
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
# This script is based on the launcher script of Platoon.
4
5
6
from __future__ import print_function
7
import os
8
import shlex
9
import argparse
10
import subprocess
11
import logging
12
logging.basicConfig(level=logging.INFO)
13
14
15
def parse_arguments(argv=None):
    """Parse the launcher's command-line arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the behaviour of the original zero-argument call.

    Returns:
        argparse.Namespace with the attributes ``worker_path``, ``gpu_list``
        (a non-empty list of strings), ``port`` (int), ``easgd_alpha``
        (the string given on the command line, ``"auto"`` by default),
        ``log`` (string or ``None``) and ``workers_args`` (string or
        ``None``).
    """
    ap = argparse.ArgumentParser(description="Launch a multi-GPU experiment")
    ap.add_argument('worker_path', help='Path of worker')
    ap.add_argument(
        'gpu_list', nargs='+', type=str,
        help='The list of Theano GPU ids (Ex: gpu0, cuda1) the script will use. 1 GPU id = 1 worker launched.')
    ap.add_argument("--port", type=int, default=5567)
    ap.add_argument("--easgd_alpha", default="auto")
    ap.add_argument("--log", type=str, default=None)
    ap.add_argument(
        '-w', '--workers-args', required=False,
        help='The arguments that will be passed to your workers. (Ex: -w="learning_rate=0.1")')
    return ap.parse_args(argv)
24
25
26
def launch_process(is_server, args, device, path=""):
    """Start the server or one worker as a child Python process.

    Args:
        is_server: When true, run the ``deepy.multigpu.server`` module;
            otherwise run the script at ``path``.
        args: Sequence of extra command-line arguments appended to the
            command, or ``None`` for no extra arguments. It is not mutated.
        device: Value placed in ``THEANO_FLAGS`` as ``device=<device>`` for
            the child process.
        path: Path of the worker script; only used when ``is_server`` is
            false.

    Returns:
        The ``subprocess.Popen`` object of the started child.
    """
    # Local import so this function does not depend on the file's import
    # header being extended.
    import sys

    print("Starting {0} on {1} ...".format("server" if is_server else "worker", device), end=' ')

    env = dict(os.environ)
    # Keep any flags inherited from the environment and append the device
    # last (presumably so it takes precedence -- TODO confirm against Theano).
    # Joining only non-empty parts avoids a leading comma when THEANO_FLAGS
    # is unset or empty.
    inherited_flags = env.get('THEANO_FLAGS', '')
    flag_parts = [inherited_flags] if inherited_flags else []
    flag_parts.append('device={}'.format(device))
    env['THEANO_FLAGS'] = ','.join(flag_parts)

    # Use the interpreter that runs this launcher so the children share its
    # environment; fall back to a PATH lookup if it cannot be determined.
    interpreter = sys.executable or "python"
    if is_server:
        command = [interpreter, "-u", "-m", "deepy.multigpu.server"]
    else:
        command = [interpreter, "-u", path]
    if args is not None:
        command.extend(args)

    # NOTE(review): worker stdout/stderr are piped but nothing in this
    # function reads them. Unless the caller drains the pipes, a worker that
    # writes more than the OS pipe buffer holds will block. The server
    # inherits the launcher's stdout/stderr instead.
    worker_stream = None if is_server else subprocess.PIPE
    process = subprocess.Popen(command, bufsize=0, env=env,
                               stdout=worker_stream,
                               stderr=worker_stream)
    print("Done")
    return process
42
43
if __name__ == '__main__':
    # Entry point: start one scheduling server on the CPU and one worker per
    # requested device, then wait until every child has terminated.
    args = parse_arguments()

    # pid -> (display name, Popen object) for every child still running.
    process_map = {}

    easgd_alpha = args.easgd_alpha
    if easgd_alpha == "auto":
        easgd_alpha = 1.0 / len(args.gpu_list)

    # Build the argument list directly instead of formatting a string and
    # re-splitting it, so a log path containing spaces or quotes survives.
    controller_args = ["--port", str(args.port),
                       "--easgd_alpha", str(easgd_alpha)]
    # Only forward --log when one was given; otherwise the server would
    # receive the literal string "None".
    if args.log is not None:
        controller_args += ["--log", args.log]
    p = launch_process(True, controller_args, "cpu")
    process_map[p.pid] = ('scheduling server', p)

    worker_args = shlex.split(args.workers_args or '')
    for device in args.gpu_list:
        worker_process = launch_process(False, worker_args, device, args.worker_path)
        process_map[worker_process.pid] = ("worker_{}".format(device),
                                           worker_process)

    print("\n### Waiting on experiment to finish ...")

    # Minimal error handling: as soon as any child fails, kill the rest.
    while process_map:
        # os.wait() reports a raw wait status, not an exit code.
        pid, status = os.wait()
        if pid not in process_map:
            print("Received status for unknown process {}".format(pid))
            # Not one of ours; there is no entry to look up or remove.
            continue

        name, p = process_map.pop(pid)
        # Decode the status the same way subprocess does: the exit code for a
        # normal exit, the negated signal number when killed by a signal.
        if os.WIFSIGNALED(status):
            returncode = -os.WTERMSIG(status)
        else:
            returncode = os.WEXITSTATUS(status)
        print("{} terminated with return code: {}.".format(name, returncode))
        if returncode != 0:
            print("\nWARNING! An error has occurred.")
            for name, p in list(process_map.values()):
                try:
                    p.kill()
                except OSError:
                    # The process is already gone; nothing left to kill.
                    pass
                # Block until the child is reaped instead of busy-polling.
                p.wait()
            process_map.clear()
86