Completed
Pull Request — master (#280)
by Dmitry
01:40
created

fuel.start_server()   A

Complexity

Conditions 3

Size

Total Lines 48

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 3
dl 0
loc 48
rs 9.125
1
import logging
2
3
import numpy
4
import zmq
5
from numpy.lib.format import header_data_from_array_1_0
6
7
from fuel.utils import buffer_
8
9
logger = logging.getLogger(__name__)
10
11
12
def send_arrays(socket, arrays, stop=False):
    """Send NumPy arrays using the buffer interface and some metadata.

    Parameters
    ----------
    socket : :class:`zmq.Socket`
        The socket to send data over.
    arrays : list
        A list of :class:`numpy.ndarray` to transfer. Ignored when `stop`
        is true. If it is empty or ``None``, only an empty header list is
        sent (as a single, complete message).
    stop : bool, optional
        Instead of sending a series of NumPy arrays, send a JSON object
        with a single `stop` key. The :func:`recv_arrays` will raise
        ``StopIteration`` when it receives this.

    Notes
    -----
    The protocol is very simple: A single JSON object describing the array
    format (using the same specification as ``.npy`` files) is sent first.
    Subsequently the arrays are sent as bytestreams (through NumPy's
    support of the buffering protocol).

    """
    if stop:
        # The stop signal carries no arrays, so there is nothing to convert.
        socket.send_json({'stop': True})
        return
    if not arrays:
        # No array frames follow, so the header has to be the final frame.
        # Sending it with SNDMORE and then indexing ``arrays[-1]`` would
        # raise after the first frame had already been handed to the socket.
        socket.send_json([])
        return
    # The buffer protocol only works on contiguous arrays
    arrays = [numpy.ascontiguousarray(array) for array in arrays]
    headers = [header_data_from_array_1_0(array) for array in arrays]
    socket.send_json(headers, zmq.SNDMORE)
    # Every frame except the last is flagged as "more to come".
    for array in arrays[:-1]:
        socket.send(array, zmq.SNDMORE)
    socket.send(arrays[-1])
46
47
48
def recv_arrays(socket):
    """Receive a list of NumPy arrays.

    A JSON header is read first; one raw frame is then read from the
    socket for every entry in that header and decoded into an array.

    Parameters
    ----------
    socket : :class:`zmq.Socket`
        The socket to receive the arrays on.

    Returns
    -------
    list
        A list of :class:`numpy.ndarray` objects, one per header entry.

    Raises
    ------
    StopIteration
        If the first JSON object received contains the key `stop`,
        signifying that the server has finished a single epoch.

    """
    headers = socket.recv_json()
    if 'stop' in headers:
        raise StopIteration

    def _receive_one(header):
        # One frame per header entry; ``buffer_`` wraps the received bytes
        # before NumPy interprets them (helper from ``fuel.utils``).
        raw = buffer_(socket.recv())
        flat = numpy.frombuffer(raw, dtype=numpy.dtype(header['descr']))
        shape = header['shape']
        if header['fortran_order']:
            # Column-major data: view it with the reversed shape, then
            # transpose to obtain the declared shape.
            return flat.reshape(shape[::-1]).transpose()
        return flat.reshape(shape)

    return [_receive_one(header) for header in headers]
82
83
84
def start_server(data_stream, port=5557, hwm=10):
    """Start a data processing server.

    Binds a ZeroMQ ``PUSH`` socket in the current process and then loops
    forever, pulling batches from the given data stream and pushing them
    over the socket with :func:`send_arrays`. Whenever the epoch iterator
    is exhausted a stop message is sent and a fresh epoch iterator is
    requested, so this function never returns.

    Parameters
    ----------
    data_stream : :class:`.DataStream`
        The data stream to return examples from.
    port : int, optional
        The TCP port the socket is bound to (on all interfaces).
        Defaults to 5557.
    hwm : int, optional
        The `ZeroMQ high-water mark (HWM)
        <http://zguide.zeromq.org/page:all#High-Water-Marks>`_ set on the
        sending socket. Defaults to 10. Be sure to set the corresponding
        HWM on the receiving end as well.

    """
    logging.basicConfig(level='INFO')

    ctx = zmq.Context()
    push_socket = ctx.socket(zmq.PUSH)
    push_socket.set_hwm(hwm)
    push_socket.bind('tcp://*:{}'.format(port))

    epoch_iterator = data_stream.get_epoch_iterator()
    # Unique marker returned by ``next`` once the current epoch is done.
    exhausted = object()

    logger.info('server started')
    while True:
        batch = next(epoch_iterator, exhausted)
        if batch is exhausted:
            # Start the next epoch before telling the receiver to stop.
            epoch_iterator = data_stream.get_epoch_iterator()
            logger.info("sending StopIteration")
            send_arrays(push_socket, None, stop=True)
        else:
            logger.info("sending {} arrays".format(len(batch)))
            send_arrays(push_socket, batch, stop=False)
132