Completed
Push — master ( 91b7c0...d52c79 )
by Raphael
01:31
created

parse_arguments()   A

Complexity

Conditions 1

Size

Total Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 1
dl 0
loc 13
rs 9.4285
c 1
b 0
f 0
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
# This script is based on the launch of platoon.
4
5
6
from __future__ import print_function
7
import os
8
import time
9
import shlex
10
import argparse
11
import subprocess
12
import logging
13
logging.basicConfig(level=logging.INFO)
14
15
16
def parse_arguments(argv=None):
    """Parse the command line of the multi-GPU launcher.

    Args:
        argv: Optional list of argument strings to parse.  When ``None``
            (the default) argparse reads the process command line, which
            is what the script entry point relies on.

    Returns:
        An ``argparse.Namespace`` with the attributes ``worker_path``,
        ``gpu_list`` (a list with at least one device id), ``port``,
        ``learning_rate``, ``start_halving_at``, ``end_at``, ``step_len``,
        ``valid_freq``, ``easgd_alpha`` (the string ``"auto"`` unless
        overridden; no type conversion is applied) and ``workers_args``
        (``None`` when ``-w`` is not given).
    """
    ap = argparse.ArgumentParser(description="Launch a multi-GPU expreiment")

    # Positional arguments: the worker script and the devices to run it on.
    ap.add_argument('worker_path', help='Path of worker')
    ap.add_argument('gpu_list', nargs='+', type=str, help='The list of Theano GPU ids (Ex: gpu0, cuda1) the script will use. 1 GPU id = 1 worker launched.')

    # Options that the entry point forwards to the scheduling server.
    ap.add_argument("--port", type=int, default=5567)
    ap.add_argument("--learning_rate", type=float, default=0.01)
    ap.add_argument("--start_halving_at", type=int, default=5)
    ap.add_argument("--end_at", type=int, default=10)
    ap.add_argument("--step_len", type=int, default=10)
    ap.add_argument("--valid_freq", type=int, default=1500)
    ap.add_argument("--easgd_alpha", default="auto")

    # Free-form argument string handed to every worker.
    ap.add_argument('-w', '--workers-args', required=False, help='The arguments that will be passed to your workers. (Ex: -w="learning_rate=0.1")')

    return ap.parse_args(argv)
29
30
31
def launch_process(is_server, args, device, path=""):
    """Start the scheduling server or one worker as a child process.

    Args:
        is_server: When true, run the ``deepy.multigpu.scheduled_server``
            module; otherwise run the script at ``path``.
        args: Iterable of extra command-line arguments appended to the
            command, or ``None`` for no extra arguments.
        device: Device name appended to ``THEANO_FLAGS`` as
            ``device=<device>`` in the child's environment.
        path: Worker script to execute; only used when ``is_server`` is
            false.

    Returns:
        The ``subprocess.Popen`` object of the started child.
    """
    # Local import: this function cannot rely on the module header
    # importing ``sys``.
    import sys

    print("Starting {0} on {1} ...".format("server" if is_server else "worker", device), end=' ')
    # Push the partial line out now so it is not reordered with the
    # (unbuffered) output of the child.
    sys.stdout.flush()

    # The child inherits the launcher's environment plus the device flag.
    env = dict(os.environ)
    device_flag = 'device={}'.format(device)
    inherited_flags = env.get('THEANO_FLAGS', '')
    if inherited_flags:
        env['THEANO_FLAGS'] = '{},{}'.format(inherited_flags, device_flag)
    else:
        # Avoid producing a value with a leading comma (",device=...").
        env['THEANO_FLAGS'] = device_flag

    # Run children with the interpreter that runs the launcher so they see
    # the same installation; fall back to PATH lookup if it is unknown.
    interpreter = sys.executable or "python"
    if is_server:
        command = [interpreter, "-u", "-m", "deepy.multigpu.scheduled_server"]
    else:
        command = [interpreter, "-u", path]
    if args is not None:
        command.extend(args)

    process = subprocess.Popen(command, bufsize=0, env=env)
    print("Done")
    return process
45
46
def _exit_code_from_wait_status(status):
    """Convert a raw ``os.wait()`` status into a Popen-style return code.

    ``os.wait()`` reports an encoded status word, not the exit code itself.
    Returns the exit code for a normal exit, the negated signal number when
    the child was killed by a signal, and the raw status otherwise.
    """
    if os.WIFEXITED(status):
        return os.WEXITSTATUS(status)
    if os.WIFSIGNALED(status):
        return -os.WTERMSIG(status)
    return status


if __name__ == '__main__':
    args = parse_arguments()

    # pid -> (display name, Popen object) for every child still running.
    process_map = {}

    # "auto" means one over the number of workers.
    easgd_alpha = args.easgd_alpha
    if easgd_alpha == "auto":
        easgd_alpha = 1.0 / len(args.gpu_list)

    # Build the server argument list directly instead of formatting a
    # string and re-splitting it, so values are passed through verbatim.
    controller_args = [
        "--port", str(args.port),
        "--learning_rate", str(args.learning_rate),
        "--start_halving_at", str(args.start_halving_at),
        "--end_at", str(args.end_at),
        "--step_len", str(args.step_len),
        "--valid_freq", str(args.valid_freq),
        "--easgd_alpha", str(easgd_alpha),
    ]
    p = launch_process(True, controller_args, "cpu")
    process_map[p.pid] = ('scheduling server', p)

    # One worker per requested device, all sharing the same extra arguments.
    for device in args.gpu_list:
        worker_process = launch_process(False, shlex.split(args.workers_args or ''), device, args.worker_path)
        process_map[worker_process.pid] = ("worker_{}".format(device),
                                           worker_process)

    print("\n### Waiting on experiment to finish ...")

    # Silly error handling but that will do for now.
    while process_map:
        pid, status = os.wait()
        if pid not in process_map:
            print("Recieved status for unknown process {}".format(pid))
            # Not one of ours: skip it instead of failing on the lookup.
            continue

        name, p = process_map.pop(pid)
        returncode = _exit_code_from_wait_status(status)
        print("{} terminated with return code: {}.".format(name, returncode))
        if returncode != 0:
            print("\nWARNING! An error has occurred.")
            # One child failed: terminate the rest and reap each of them
            # once, rather than spinning on kill()/poll().
            for _, survivor in list(process_map.values()):
                try:
                    survivor.kill()
                    survivor.wait()
                except OSError:
                    # The child is already gone; nothing left to clean up.
                    pass
            process_map.clear()
90