MultiGPUTrainer (overall rating: A)

Complexity

Total Complexity 32

Size/Duplication

Total Lines 147
Duplicated Lines 0 %

Importance

Changes 11
Bugs 0 Features 0
Metric                            Value
c (changes)                       11
b (bugs)                          0
f (features)                      0
dl (duplicated lines)             0
loc (lines of code)               147
rs (rating score)                 9.6
wmc (weighted method complexity)  32

5 Methods

Rating  Name                Duplication  Size  Complexity
F       train()             0            101   23
A       create_param_map()  0            5     2
A       fix_costs()         0            2     2
A       __init__()          0            22    2
A       sync_hyperparams()  0            6     3
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import time
import os
import logging
from collections import OrderedDict

import numpy as np

from deepy.trainers import GeneralNeuralTrainer

class MultiGPUTrainer(GeneralNeuralTrainer):
    """
    Multi-GPU neural network trainer built on platoon workers.
    """

    def __init__(self,
                 network, config=None, method='sgd',
                 server_port=5567,
                 start_halving_at=6, end_at=10, sync_freq=3,
                 valid_freq=1500, learning_rate=None, halving_freq=1,
                 using_easgd=True
                 ):
        super(MultiGPUTrainer, self).__init__(network, method, config)
        self._report_time = False
        self._port = server_port
        self.logger = logging.getLogger('MultiGPUTrainingWorker')
        self.epoch = 0
        self._using_easgd = using_easgd
        # Fall back to the learning rate stored in the config when none is given
        if not learning_rate:
            learning_rate = float(self.config.learning_rate.get_value())
        # Training schedule parameters, handed to the controller in train()
        self._schedule_params = {
            'learning_rate': learning_rate,
            'start_halving_at': start_halving_at,
            'end_at': end_at,
            'sync_freq': sync_freq,
            'valid_freq': valid_freq,
            'halving_freq': halving_freq
        }

    def create_param_map(self):
        # Give every trainable parameter a stable name for platoon's shared params
        param_map = OrderedDict()
        for i, param in enumerate(self.training_params()):
            param_map["param_{}".format(i)] = param
        return param_map

    def sync_hyperparams(self, param_map):
        # Update local hyperparameters with values pushed by the controller
        self.logger.info("(proc {}) sync hyperparameters".format(os.getpid()))
        if 'epoch' in param_map:
            self.epoch = param_map['epoch']
        if 'learning_rate' in param_map:
            self.config.learning_rate.set_value(param_map['learning_rate'])

    def fix_costs(self):
        # Cast numpy scalars to plain Python floats, presumably so the cost
        # list serializes cleanly when reported to the controller
        self.last_run_costs = [(a, float(b)) for (a, b) in self.last_run_costs]

    def train(self, train_set, valid_set=None, test_set=None, train_size=None):
        """
        Train the model in a multi-GPU environment.
        """
        from platoon.channel import Worker
        from platoon.param_sync import EASGD, ASGD
        server_port = self._port
        param_map = self.create_param_map()
        # Initialize the worker and hand the training schedule to the controller
        worker = Worker(control_port=server_port)
        if self.config.learning_rate:
            worker.send_req({'init_schedule': self._schedule_params})
        self.sync_hyperparams(worker.send_req('sync_hyperparams')['sync_hyperparams'])
        easgd_alpha = worker.send_req('get_easgd_alpha')
        if self._using_easgd:
            self.logger.info("using EASGD with alpha={}".format(easgd_alpha))
        else:
            self.logger.info("using ASGD rule")
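        # Parameter sync rule (summary for context; see platoon.param_sync for
        # the implementations). With EASGD (Zhang et al., 2015), each worker and
        # a central copy of the parameters are pulled toward each other on sync:
        #     x_i      <- x_i      - alpha * (x_i - x_center)
        #     x_center <- x_center + alpha * (x_i - x_center)
        # With ASGD, the worker's accumulated local update is simply added to
        # the central parameters.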
        rule = EASGD(easgd_alpha) if self._using_easgd else ASGD()
        worker.init_shared_params(param_map.values(), param_sync_rule=rule)
        worker.send_req({
            "set_names": None,
            "training_names": self.training_names,
            "evaluation_names": self.evaluation_names
        })
        # Load all training batches; this may consume a lot of memory
        self.logger.info("started process {}".format(os.getpid()))
        self.logger.info("(proc {}) load training data".format(os.getpid()))
        train_batches = list(train_set)
        network_callback = bool(self.network.training_callbacks)
        trainer_callback = bool(self._iter_controllers)
        # Start with a validation run, so the performance at the moment a
        # worker joins is known
        worker.copy_to_local()
        if valid_set:
            self._run_valid(self.epoch, valid_set, dry_run=True)
            self.fix_costs()
        worker.send_req({
            "valid_done": None,
            "valid_costs": self.last_run_costs,
            "auto_save": self.config.auto_save
        })
        worker.copy_to_local()
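        # Main loop: keep asking the controller for the next action. It replies
        # with 'train' (a list of batch ids to learn on), 'valid', 'eval',
        # 'wait', 'stop', or a 'sync_hyperparams' update, all handled below.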
        while True:
            resp = worker.send_req('next')
            if resp == 'stop':
                break
            elif resp == 'wait':
                time.sleep(1)
            elif resp == 'get_num_batches':
                worker.send_req({'get_num_batches_done': len(train_batches)})
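            # 'eval' runs both the validation and test sets and reports both
            # cost lists; the 'valid' branch below re-checks the validation
            # set only.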
            elif 'eval' in resp:
                self.best_cost = resp['best_valid_cost']
                worker.copy_to_local()
                valid_costs = None
                test_costs = None
                if valid_set:
                    self._run_valid(self.epoch, valid_set)
                    self.fix_costs()
                    valid_costs = self.last_run_costs
                if test_set:
                    self._run_test(self.epoch, test_set)
                    self.fix_costs()
                    test_costs = self.last_run_costs
                worker.send_req({
                    "eval_done": None,
                    "valid_costs": valid_costs,
                    "test_costs": test_costs,
                    "auto_save": self.config.auto_save
                })
            elif 'valid' in resp:
                self.best_cost = resp['best_valid_cost']
                worker.copy_to_local()
                if valid_set:
                    self._run_valid(self.epoch, valid_set, dry_run=True)
                    self.fix_costs()
                worker.send_req({
                    "valid_done": None,
                    "valid_costs": self.last_run_costs,
                    "auto_save": self.config.auto_save
                })
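            # The controller assigns a chunk of batch ids (presumably sized by
            # sync_freq); the worker learns on each batch, then synchronizes
            # the shared parameters and reports the mean of each tracked cost.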
            elif 'train' in resp:
                batch_ids = resp['train']
                batch_costs = [[] for _ in self.training_names]
                for batch_id in batch_ids:
                    x = train_batches[batch_id]
                    cost_x = self.learn(*x)
                    for i, cost in enumerate(cost_x):
                        batch_costs[i].append(cost)
                    self.last_cost = cost_x[0]
                if network_callback:
                    self.network.training_callback()
                if trainer_callback:
                    for func in self._iter_controllers:
                        func(self)
                worker.sync_params(synchronous=True)
                worker.send_req({'train_done': None, 'costs': [float(np.mean(c)) for c in batch_costs]})
            elif 'sync_hyperparams' in resp:
                self.sync_hyperparams(resp['sync_hyperparams'])
        worker.close()
        return []
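
For context, a minimal worker-side usage sketch. The import path, the chosen
method, and the prepared network and data sets are assumptions rather than
part of this file; it also assumes a platoon controller process is already
listening on the control port.

#!/usr/bin/env python
# Hypothetical worker script: one copy runs per GPU, next to a platoon
# controller process listening on control port 5567.
import logging
logging.basicConfig(level=logging.INFO)

from deepy.trainers import MultiGPUTrainer  # assumed import path

# `network`, `train_batches` and `valid_batches` are assumed to be built
# elsewhere: a deepy network plus iterables of mini-batch tuples that
# trainer.learn(*batch) accepts.
trainer = MultiGPUTrainer(network,
                          method='sgd',
                          server_port=5567,
                          sync_freq=3,
                          using_easgd=True)
trainer.train(train_batches, valid_set=valid_batches)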