Completed: push to master (3e1d4c...f31f72) by Bart · 27s

DataStream (rating A)

Complexity
    Total Complexity    18

Size/Duplication
    Total Lines         62
    Duplicated Lines    0 %

Metric    Value
dl        0
loc       62
rs        10
wmc       18

8 Methods

Rating  Name                  Duplication  Size  Complexity
C       __init__()            0            16    8
A       next_epoch()          0            2     1
A       close()               0            2     1
A       default_stream()      0            4     1
A       get_data()            0            3     1
A       get_epoch_iterator()  0            7     2
A       reset()               0            3     1
A       sources()             0            5     1

from abc import ABCMeta, abstractmethod

import zmq
from six import add_metaclass, iteritems

from fuel.iterator import DataIterator
from fuel.server import recv_arrays


@add_metaclass(ABCMeta)
class AbstractDataStream(object):
    """A stream of data separated into epochs.

    A data stream is an iterable stream of examples/minibatches. It shares
    similarities with Python file handles returned by the ``open`` method.
    Data streams can be closed using the :meth:`close` method and reset
    using :meth:`reset` (similar to ``f.seek(0)``).

    Parameters
    ----------
    iteration_scheme : :class:`.IterationScheme`, optional
        The iteration scheme to use when retrieving data. Note that not
        all datasets support the same iteration schemes: some datasets
        require one, and others don't support any. When the data stream
        wraps another data stream, the choice of supported iteration
        schemes is typically even more limited. Be sure to read the
        documentation of the dataset or data stream in question.
    axis_labels : dict, optional
        Maps source names to tuples of strings describing axis semantics,
        one per axis. Defaults to `None`, i.e. no information is available.

    Attributes
    ----------
    iteration_scheme : :class:`.IterationScheme`
        The iteration scheme used to retrieve data. Can be ``None`` when
        not used.
    sources : tuple of strings
        The names of the data sources returned by this data stream, as
        given by the dataset.
    produces_examples : bool
        Whether this data stream produces examples (as opposed to batches
        of examples).

    """
    def __init__(self, iteration_scheme=None, axis_labels=None):
        self.iteration_scheme = iteration_scheme
        self.axis_labels = axis_labels

    @property
    def produces_examples(self):
        if self.iteration_scheme:
            return self.iteration_scheme.requests_examples
        elif not hasattr(self, '_produces_examples'):
            raise ValueError("cannot infer type of stream for {} instance; "
                             "set the produces_examples attribute to True "
                             "(for example streams) or False (for batch "
                             "streams).".format(self.__class__.__name__))
        else:
            return self._produces_examples

    @produces_examples.setter
    def produces_examples(self, value):
        if self.iteration_scheme:
            raise ValueError("cannot set produces_examples on {} instance; "
                             "determined by iteration scheme {}".format(
                                 self.__class__.__name__,
                                 self.iteration_scheme))
        self._produces_examples = value

    @abstractmethod
    def get_data(self, request=None):
        """Request data from the dataset or the wrapped stream.

        Parameters
        ----------
        request : object
            A request fetched from the `request_iterator`.

        """

    @abstractmethod
    def reset(self):
        """Reset the data stream."""

    @abstractmethod
    def close(self):
        """Gracefully close the data stream, e.g. releasing file handles."""

    @abstractmethod
    def next_epoch(self):
        """Switch the data stream to the next epoch."""

    @abstractmethod
    def get_epoch_iterator(self, as_dict=False):
        return DataIterator(self, self.iteration_scheme.get_request_iterator()
                            if self.iteration_scheme else None,
                            as_dict=as_dict)

    def iterate_epochs(self, as_dict=False):
        """Allow iteration through all epochs.

        Notes
        -----
        This method uses the :meth:`get_epoch_iterator` method to retrieve
        the :class:`DataIterator` for each epoch. The default
        implementation of this method resets the state of the data stream
        so that each new epoch reads the data from the beginning.
        However, this behavior only works as long as the epochs are
        iterated over lazily, e.g. using ``for epoch in
        stream.iterate_epochs()``. If you create the data iterators in
        advance (e.g. using ``for i, epoch in zip(range(10),
        stream.iterate_epochs())`` in legacy Python), you must call the
        :meth:`reset` method yourself.

        """
        while True:
            yield self.get_epoch_iterator(as_dict=as_dict)


class DataStream(AbstractDataStream):
    """A stream of data from a dataset.

    Parameters
    ----------
    dataset : instance of :class:`Dataset`
        The dataset from which the data is fetched.

    """
    def __init__(self, dataset, **kwargs):
        if dataset.axis_labels:
            kwargs.setdefault('axis_labels', dataset.axis_labels.copy())
        super(DataStream, self).__init__(**kwargs)
        # A DataStream with no iteration scheme is considered an example stream
        # by default
        if not self.iteration_scheme:
            self.produces_examples = True
        # If the data stream produces examples, remove 'batch' from axis labels
        if self.produces_examples and self.axis_labels:
            for source, labels in iteritems(self.axis_labels):
                self.axis_labels[source] = tuple(
                    label for label in labels if label != 'batch')
        self.dataset = dataset
        self.data_state = self.dataset.open()
        self._fresh_state = True

    @property
    def sources(self):
        if hasattr(self, '_sources'):
            return self._sources
        return self.dataset.sources

    @sources.setter
    def sources(self, value):
        self._sources = value

    def close(self):
        self.data_state = self.dataset.close(self.data_state)

    def reset(self):
        self.data_state = self.dataset.reset(self.data_state)
        self._fresh_state = True

    def next_epoch(self):
        self.data_state = self.dataset.next_epoch(self.data_state)

    def get_data(self, request=None):
        """Get data from the dataset."""
        return self.dataset.get_data(self.data_state, request)

    def get_epoch_iterator(self, **kwargs):
        """Get an epoch iterator for the data stream."""
        if not self._fresh_state:
            self.next_epoch()
        else:
            self._fresh_state = False
        return super(DataStream, self).get_epoch_iterator(**kwargs)

    @classmethod
    def default_stream(cls, dataset, **kwargs):
        data_stream = cls(dataset, **kwargs)
        return dataset.apply_default_transformers(data_stream)


class ServerDataStream(AbstractDataStream):
    """A data stream that receives batches from a Fuel server.

    Parameters
    ----------
    sources : tuple of strings
        The names of the data sources returned by this data stream.
    produces_examples : bool
        Whether this data stream produces examples (as opposed to batches
        of examples).
    host : str, optional
        The host to connect to. Defaults to ``localhost``.
    port : int, optional
        The port to connect on. Defaults to 5557.
    hwm : int, optional
        The `ZeroMQ high-water mark (HWM)
        <http://zguide.zeromq.org/page:all#High-Water-Marks>`_ on the
        receiving socket. Increasing this increases the buffer, which can
        be useful if your data preprocessing times are highly variable.
        However, it will increase memory usage. There is no easy way to
        tell how many batches will actually be queued with a particular
        HWM. Defaults to 10. Be sure to set the corresponding HWM on the
        server's end as well.
    axis_labels : dict, optional
        Maps source names to tuples of strings describing axis semantics,
        one per axis. Defaults to `None`, i.e. no information is available.

    """
    def __init__(self, sources, produces_examples, host='localhost', port=5557,
                 hwm=10, axis_labels=None):
        super(ServerDataStream, self).__init__(axis_labels=axis_labels)
        self.sources = sources
        self.produces_examples = produces_examples
        self.host = host
        self.port = port
        self.hwm = hwm
        self.connect()

    def connect(self):
        context = zmq.Context()
        self.socket = socket = context.socket(zmq.PULL)
        socket.set_hwm(self.hwm)
        socket.connect("tcp://{}:{}".format(self.host, self.port))
        self.connected = True

    def get_data(self, request=None):
        # Requests are not supported: batches arrive ready-made from the
        # server, so passing an explicit request is an error.
        if request is not None:
            raise ValueError
        if not self.connected:
            self.connect()
        data = recv_arrays(self.socket)
        return tuple(data)

    def get_epoch_iterator(self, **kwargs):
        return super(ServerDataStream, self).get_epoch_iterator(**kwargs)

    def close(self):
        pass

    def next_epoch(self):
        pass

    def reset(self):
        pass

    def __getstate__(self):
        # The ZeroMQ socket cannot be pickled; drop it and reconnect
        # lazily in get_data after unpickling.
        state = self.__dict__.copy()
        state['connected'] = False
        del state['socket']
        return state
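
The listing above only defines the stream classes. As a rough illustration of how they are typically wired together, here is a minimal usage sketch; it assumes Fuel's IndexableDataset and SequentialScheme (which live in fuel.datasets and fuel.schemes, not in this file) and simply runs one batch-wise epoch over a small in-memory dataset.

# Usage sketch, not part of the analyzed module.
import numpy

from fuel.datasets import IndexableDataset
from fuel.schemes import SequentialScheme
from fuel.streams import DataStream

# Ten examples of two features each.
dataset = IndexableDataset({'features': numpy.arange(20).reshape(10, 2)})
stream = DataStream(
    dataset,
    iteration_scheme=SequentialScheme(dataset.num_examples, batch_size=2))

# One pass over the data in batches of two examples; get_epoch_iterator()
# starts a new epoch on each call after the first, and reset() rewinds
# the stream entirely.
for (features,) in stream.get_epoch_iterator():
    print(features.shape)
stream.close()

On the consuming side of a fuel-server setup, ServerDataStream(('features',), produces_examples=False) plays the same role: instead of reading a dataset directly, it pulls ready-made batches over ZeroMQ, by default from tcp://localhost:5557 as set up in connect() above.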