Completed
Push to master ( f83d60...5a91c7 ) by Raphael, created 01:26

MultiGPUTrainer.fix_costs() (rated A)

Complexity: Conditions 2
Size: Total Lines 2
Duplication: Lines 0, Ratio 0 %
Importance: Changes 1, Bugs 0, Features 0
Metric   Value
cc       2
c        1
b        0
f        0
dl       0
loc      2
rs       10
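
The rated method is the two-line fix_costs() shown in the listing below. It casts the cost values collected after a validation or test run to plain Python floats, presumably so that the values serialize cleanly when the worker reports results back to the controller (a plain-JSON channel, for instance, cannot encode numpy float32 values). A minimal standalone sketch of the same conversion, using a hypothetical stand-in for last_run_costs:

import numpy as np

# Hypothetical stand-in for trainer.last_run_costs: (name, numpy scalar) pairs.
last_run_costs = [("J", np.float32(0.42)), ("err", np.float32(0.11))]

# The same cast fix_costs() applies: numpy scalars become plain Python floats.
last_run_costs = [(name, float(value)) for name, value in last_run_costs]

The full source of the file under review follows.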
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import time
import os
from collections import OrderedDict
import numpy as np

from platoon.channel import Worker
from platoon.param_sync import EASGD
from deepy.trainers import GeneralNeuralTrainer

import logging


class MultiGPUTrainer(GeneralNeuralTrainer):
    """
    Multi-GPU neural network trainer, run as a platoon worker.
    """
    def __init__(self,
                 network, config=None, method=None,
                 server_port=5567,
                 start_halving_at=6, end_at=10, step_len=10,
                 valid_freq=1500, learning_rate=None
                 ):
        super(MultiGPUTrainer, self).__init__(network, config, method)
        self._report_time = False
        self._port = server_port
        self.logger = logging.getLogger('MultiGPUTrainingWorker')
        self.epoch = 0
        if not learning_rate:
            learning_rate = float(self.config.learning_rate.get_value())
        self._schedule_params = {
            'learning_rate': learning_rate,
            'start_halving_at': start_halving_at,
            'end_at': end_at,
            'step_len': step_len,
            'valid_freq': valid_freq
        }

    def create_param_map(self):
        param_map = OrderedDict()
        for i, param in enumerate(self.training_params()):
            param_map["param_{}".format(i)] = param
        return param_map

    def sync_hyperparams(self, param_map):
        self.logger.info("(proc {}) sync hyperparameters".format(os.getpid()))
        if 'epoch' in param_map:
            self.epoch = param_map['epoch']
        if 'learning_rate' in param_map:
            self.config.learning_rate.set_value(param_map['learning_rate'])

    def fix_costs(self):
        # Cast cost values to plain Python floats so they can be serialized
        # and reported back to the controller.
        self.last_run_costs = [(a, float(b)) for (a, b) in self.last_run_costs]

    def train(self, train_set, valid_set=None, test_set=None, train_size=None):
        """
        Train the model in a multi-GPU environment.
        """
        server_port = self._port
        param_map = self.create_param_map()
        # Initialize the worker
        worker = Worker(control_port=server_port)
        if self.config.learning_rate:
            worker.send_req({'init_schedule': self._schedule_params})
        self.sync_hyperparams(worker.send_req('sync_hyperparams')['sync_hyperparams'])
        easgd_alpha = worker.send_req('get_easgd_alpha')
        worker.init_shared_params(param_map.values(), param_sync_rule=EASGD(easgd_alpha))
        worker.copy_to_local()
        worker.send_req({
            "set_names": None,
            "training_names": self.training_names,
            "evaluation_names": self.evaluation_names
        })
        # Load all training batches; this can consume a large amount of memory
        self.logger.info("started process {}".format(os.getpid()))
        self.logger.info("(proc {}) load training data".format(os.getpid()))
        train_batches = list(train_set)
        network_callback = bool(self.network.training_callbacks)
        trainer_callback = bool(self._iter_callbacks)
        while True:
            resp = worker.send_req('next')
            if resp == 'stop':
                break
            elif resp == 'wait':
                time.sleep(1)
            elif resp == 'get_num_batches':
                worker.send_req({'get_num_batches_done': len(train_batches)})
            elif 'eval' in resp:
                self.best_cost = resp['best_valid_cost']
                worker.copy_to_local()
                valid_costs = None
                test_costs = None
                if valid_set:
                    self._run_valid(self.epoch, valid_set)
                    self.fix_costs()
                    valid_costs = self.last_run_costs
                if test_set:
                    self._run_test(self.epoch, test_set)
                    self.fix_costs()
                    test_costs = self.last_run_costs
                worker.send_req({
                    "eval_done": None,
                    "valid_costs": valid_costs,
                    "test_costs": test_costs,
                    "auto_save": self.config.auto_save
                })
            elif 'valid' in resp:
                self.best_cost = resp['best_valid_cost']
                worker.copy_to_local()
                if valid_set:
                    self._run_valid(self.epoch, valid_set, dry_run=True)
                    self.fix_costs()
                worker.send_req({
                    "valid_done": None,
                    "valid_costs": self.last_run_costs,
                    "auto_save": self.config.auto_save
                })
            elif 'train' in resp:
                batch_ids = resp['train']
                batch_costs = [[] for _ in self.training_names]
                for batch_id in batch_ids:
                    x = train_batches[batch_id]
                    cost_x = self.learn(*x)
                    for i, cost in enumerate(cost_x):
                        batch_costs[i].append(cost)
                    self.last_cost = cost_x[0]
                if network_callback:
                    self.network.training_callback()
                if trainer_callback:
                    for func in self._iter_callbacks:
                        func(self)
                worker.sync_params(synchronous=True)
                worker.send_req({'train_done': None, 'costs': [float(np.mean(c)) for c in batch_costs]})
            elif 'sync_hyperparams' in resp:
                self.sync_hyperparams(resp['sync_hyperparams'])
        worker.close()
        return []
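
For context, a rough usage sketch of this worker-side trainer. The network, config, and data iterators are assumed to be built with deepy elsewhere and are purely illustrative, and a platoon controller implementing the schedule protocol used above must already be listening on the control port:

# Illustrative only: `network`, `config`, and the data iterators are assumed
# to come from deepy; a controller process must already be running on port 5567.
trainer = MultiGPUTrainer(network, config,
                          server_port=5567,
                          start_halving_at=6, end_at=10,
                          valid_freq=1500)
trainer.train(train_data, valid_set=valid_data, test_set=test_data)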