Completed: Push to master ( 91b7c0...d52c79 ) by Raphael, created 01:31

MultiGPUTrainer (rating: A)

Complexity
    Total Complexity    25

Size/Duplication
    Total Lines         86
    Duplicated Lines    0%

Importance
    Changes     1
    Bugs        0
    Features    0
Metric  Value
dl      0    (duplicated lines)
loc     86   (lines of code)
rs      10
c       1    (changes)
b       0    (bugs)
f       0    (features)
wmc     25   (weighted method count, i.e. total complexity)

4 Methods

Rating  Name                Duplication  Size  Complexity
F       train()             0            64    19
A       create_param_map()  0            5     2
A       __init__()          0            4     1
A       sync_hyperparams()  0            6     3
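
The class-level wmc of 25 is simply the sum of the four method complexities (19 + 2 + 1 + 3), and train() alone accounts for 19 of it, which is why it draws the F rating. These per-method figures can be approximated locally with the radon package; the short sketch below assumes radon is installed and that the listing below is saved as multi_gpu_trainer.py (both assumptions, not part of the project), and radon's counts may differ slightly from the analyzer that produced this report.

from radon.complexity import cc_visit, cc_rank

with open("multi_gpu_trainer.py") as f:   # hypothetical path to the listing below
    source = f.read()

# cc_visit() returns one entry per function/method with its cyclomatic complexity;
# cc_rank() maps that number onto the A-F scale used in the table above.
for block in sorted(cc_visit(source), key=lambda b: -b.complexity):
    print("{} {}() complexity={}".format(cc_rank(block.complexity),
                                         block.name, block.complexity))
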
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import time
import os
import logging
from collections import OrderedDict

import numpy as np

from platoon.channel import Worker
from platoon.param_sync import EASGD
from deepy.trainers import GeneralNeuralTrainer


class MultiGPUTrainer(GeneralNeuralTrainer):
    """
    Neural network trainer that runs as a Platoon worker in a multi-GPU setup.
    """

    def __init__(self, network, config=None, method=None):
        super(MultiGPUTrainer, self).__init__(network, config, method)
        self.logger = logging.getLogger('MultiGPUTrainingWorker')
        self.epoch = 0

    def create_param_map(self):
        # Give every training parameter a stable name so it can be shared
        # with the controller process.
        param_map = OrderedDict()
        for i, param in enumerate(self.training_params()):
            param_map["param_{}".format(i)] = param
        return param_map

    def sync_hyperparams(self, param_map):
        self.logger.info("(proc {}) sync hyperparameters".format(os.getpid()))
        if 'epoch' in param_map:
            self.epoch = param_map['epoch']
        if 'learning_rate' in param_map:
            self.config.learning_rate.set_value(param_map['learning_rate'])

    def train(self, train_set, valid_set=None, test_set=None, train_size=None):
        """
        Train the model in a multi-GPU environment.
        """
        server_port = self.config.get("server_port", 5567)
        param_map = self.create_param_map()
        # Initialize the worker and fetch the initial hyperparameters.
        worker = Worker(control_port=server_port)
        self.sync_hyperparams(worker.send_req('sync_hyperparams')['sync_hyperparams'])
        easgd_alpha = worker.send_req('get_easgd_alpha')
        worker.init_shared_params(param_map.values(), param_sync_rule=EASGD(easgd_alpha))
        worker.copy_to_local()
        # Load all training batches up front; this can consume a lot of memory.
        self.logger.info("started process {}".format(os.getpid()))
        self.logger.info("(proc {}) load training data".format(os.getpid()))
        train_batches = list(train_set)
        network_callback = bool(self.network.training_callbacks)
        trainer_callback = bool(self._iter_callbacks)
        # Serve requests from the controller until it replies 'stop'.
        while True:
            resp = worker.send_req('next')
            if resp == 'stop':
                break
            elif resp == 'wait':
                time.sleep(1)
            elif resp == 'get_num_batches':
                worker.send_req({'get_num_batches_done': len(train_batches)})
            elif resp == 'eval':
                # Evaluate on the validation and test sets and report the log lines.
                worker.copy_to_local()
                messages = []
                if valid_set:
                    self._run_valid(self.epoch, valid_set)
                    messages.append(self.network.train_logger.log_pool[-1])
                if test_set:
                    self._run_test(self.epoch, test_set)
                    messages.append(self.network.train_logger.log_pool[-1])
                worker.send_req({"eval_done": messages})
            elif resp == 'valid':
                # Validation-only pass.
                worker.copy_to_local()
                messages = []
                if valid_set:
                    # TODO: set and send the best cost
                    self._run_valid(self.epoch, valid_set, dry_run=True)
                    messages.append(self.network.train_logger.log_pool[-1])
                worker.send_req({"valid_done": messages})
            elif 'train' in resp:
                # Train on the batch indices assigned by the controller.
                batch_ids = resp['train']
                batch_costs = [[] for _ in self.training_names]
                for batch_id in batch_ids:
                    x = train_batches[batch_id]
                    cost_x = self.learn(*x)
                    for i, cost in enumerate(cost_x):
                        batch_costs[i].append(cost)
                    self.last_cost = cost_x[0]
                if network_callback:
                    self.network.training_callback()
                if trainer_callback:
                    for func in self._iter_callbacks:
                        func(self)
                # Push local updates to the shared parameters and report the costs.
                worker.sync_params(synchronous=True)
                worker.send_req({'train_done': None, 'costs': [float(np.mean(c)) for c in batch_costs]})
            elif 'sync_hyperparams' in resp:
                self.sync_hyperparams(resp['sync_hyperparams'])
        worker.close()
        return []
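
The train() loop above is one side of a request/reply protocol: the worker keeps asking 'next' and acts on 'stop', 'wait', 'eval', 'valid', a {'train': [...]} batch assignment, or a 'sync_hyperparams' update, reporting results back with the corresponding *_done messages. The sketch below is not the Platoon Controller API and not deepy's actual controller; it is only a toy dispatcher showing which replies drive each branch of that loop, with step counts and hyperparameter values chosen arbitrarily. In a real run, a Platoon controller process plays this role, with one worker process per GPU.

class ToyController(object):
    """Toy stand-in for the controller side of the protocol (illustration only)."""

    def __init__(self, num_batches, batches_per_step=3, max_steps=10):
        self.num_batches = num_batches
        self.batches_per_step = batches_per_step
        self.max_steps = max_steps
        self.step = 0

    def handle(self, req):
        # Handshake requests issued before the worker's main loop.
        if req == 'sync_hyperparams':
            return {'sync_hyperparams': {'epoch': 0, 'learning_rate': 0.01}}
        if req == 'get_easgd_alpha':
            return 0.5  # EASGD moving rate; arbitrary value for this sketch
        # The worker asks what to do next.
        if req == 'next':
            if self.step >= self.max_steps:
                return 'stop'  # worker breaks out of its while-loop
            self.step += 1
            start = ((self.step - 1) * self.batches_per_step) % self.num_batches
            ids = [(start + i) % self.num_batches for i in range(self.batches_per_step)]
            return {'train': ids}  # worker trains on these batch indices
        # Result reports from the worker; the loop above ignores these replies.
        if isinstance(req, dict):
            if ('train_done' in req or 'eval_done' in req or 'valid_done' in req
                    or 'get_num_batches_done' in req):
                return 'ok'
        return 'wait'  # anything unrecognized: tell the worker to sleep and retry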