|
1
|
|
|
import logging |
|
2
|
|
|
|
|
3
|
|
|
import numpy |
|
4
|
|
|
import zmq |
|
5
|
|
|
from numpy.lib.format import header_data_from_array_1_0 |
|
6
|
|
|
|
|
7
|
|
|
from fuel.utils import buffer_ |
|
8
|
|
|
|
|
9
|
|
|
logger = logging.getLogger(__name__) |
|
10
|
|
|
|
|
11
|
|
|
|
|
12
|
|
|
def send_arrays(socket, arrays, stop=False):
    """Send NumPy arrays using the buffer interface and some metadata.

    Parameters
    ----------
    socket : :class:`zmq.Socket`
        The socket to send data over.
    arrays : list
        A list of :class:`numpy.ndarray` to transfer. May be empty, in
        which case only the (empty) header list is sent as a complete
        message. Not inspected when `stop` is true, so ``None`` is
        acceptable in that case.
    stop : bool, optional
        Instead of sending a series of NumPy arrays, send a JSON object
        with a single `stop` key. The :func:`recv_arrays` will raise
        ``StopIteration`` when it receives this.

    Notes
    -----
    The protocol is very simple: A single JSON object describing the array
    format (using the same specification as ``.npy`` files) is sent first.
    Subsequently the arrays are sent as bytestreams (through NumPy's
    support of the buffering protocol). The header and the arrays are
    sent as successive frames of one multipart message.

    """
    if stop:
        # The stop marker is a complete, single-frame message.
        socket.send_json({'stop': True})
        return

    if arrays:
        # The buffer protocol only works on contiguous arrays
        arrays = [numpy.ascontiguousarray(array) for array in arrays]
    headers = [header_data_from_array_1_0(array) for array in arrays]

    if not arrays:
        # Nothing follows the header, so it must not be flagged SNDMORE;
        # otherwise the multipart message would never be completed.
        socket.send_json(headers)
        return

    socket.send_json(headers, zmq.SNDMORE)
    for array in arrays[:-1]:
        socket.send(array, zmq.SNDMORE)
    # The final frame is sent without SNDMORE to close the message.
    socket.send(arrays[-1])
|
46
|
|
|
|
|
47
|
|
|
|
|
48
|
|
|
def recv_arrays(socket):
    """Receive a list of NumPy arrays.

    Reads one JSON header message from the socket and then one raw frame
    per header entry, rebuilding each array from its frame.

    Parameters
    ----------
    socket : :class:`zmq.Socket`
        The socket to receive the arrays on.

    Returns
    -------
    list
        A list of :class:`numpy.ndarray` objects, one per header entry
        and in the same order.

    Raises
    ------
    StopIteration
        If the first JSON object received contains the key `stop`,
        signifying that the server has finished a single epoch.

    """
    def _read_one(meta):
        # One raw frame is consumed per header entry.
        payload = buffer_(socket.recv())
        result = numpy.frombuffer(payload, dtype=numpy.dtype(meta['descr']))
        result.shape = meta['shape']
        if meta['fortran_order']:
            # Column-major data: view it with the dimensions reversed and
            # transpose to recover the advertised shape.
            result.shape = meta['shape'][::-1]
            result = result.transpose()
        return result

    headers = socket.recv_json()
    if 'stop' in headers:
        raise StopIteration
    return [_read_one(meta) for meta in headers]
|
82
|
|
|
|
|
83
|
|
|
|
|
84
|
|
|
def start_server(data_stream, port=5557, hwm=10):
    """Start a data processing server.

    This command starts a server in the current process that performs the
    actual data processing (by retrieving data from the given data stream)
    and pushes every batch over a ZeroMQ ``PUSH`` socket. When an epoch
    iterator is exhausted a stop message is sent and a new epoch iterator
    is requested, so the function loops until it is interrupted or an
    error occurs. This function does not itself start a broker or any
    other process.

    Parameters
    ----------
    data_stream : :class:`.DataStream`
        The data stream to return examples from.
    port : int, optional
        The port the server and the client (training loop) will use to
        communicate. Defaults to 5557.
    hwm : int, optional
        The `ZeroMQ high-water mark (HWM)
        <http://zguide.zeromq.org/page:all#High-Water-Marks>`_ on the
        sending socket. Increasing this increases the buffer, which can be
        useful if your data preprocessing times are very random. However,
        it will increase memory usage. There is no easy way to tell how
        many batches will actually be queued with a particular HWM.
        Defaults to 10. Be sure to set the corresponding HWM on the
        receiving end as well.

    """
    logging.basicConfig(level='INFO')

    context = zmq.Context()
    socket = context.socket(zmq.PUSH)
    try:
        socket.set_hwm(hwm)
        socket.bind('tcp://*:{}'.format(port))

        it = data_stream.get_epoch_iterator()

        logger.info('server started')
        while True:
            # Only ``next`` is expected to raise StopIteration, so keep the
            # guarded region down to that single call.
            try:
                data = next(it)
            except StopIteration:
                # End of epoch: start a fresh iterator and tell the client.
                it = data_stream.get_epoch_iterator()
                logger.info("sending StopIteration")
                send_arrays(socket, None, stop=True)
            else:
                logger.info("sending {} arrays".format(len(data)))
                send_arrays(socket, data, stop=False)
    finally:
        # The loop only exits through an exception (e.g. KeyboardInterrupt).
        # Release the ZeroMQ resources; linger=0 drops queued messages so
        # that terminating the context cannot block on undelivered batches.
        socket.close(linger=0)
        context.term()
|
132
|
|
|
|