| Total Complexity | 5 |
| Total Lines | 33 |
| Duplicated Lines | 0 % |
| 1 | import six |
||
class DataIterator(six.Iterator):
    """Step through a data stream, covering a single epoch.

    Parameters
    ----------
    data_stream : :class:`DataStream` or :class:`Transformer`
        The stream whose ``get_data`` method produces each step's data.
    request_iterator : iterator, optional
        When given, one request is drawn from it per step and passed to
        ``data_stream.get_data``. Once it is exhausted, the resulting
        ``StopIteration`` ends this iterator as well. When ``None``
        (default), ``get_data`` is called with no arguments.
    as_dict : bool, optional
        If `True`, each step yields a dictionary pairing the names in
        ``data_stream.sources`` with the corresponding entries of the
        data. If `False` (default), the data is returned exactly as
        ``get_data`` produced it (tuples in the same order as
        ``data_stream.sources``).

    """
    def __init__(self, data_stream, request_iterator=None, as_dict=False):
        self.data_stream = data_stream
        self.request_iterator = request_iterator
        self.as_dict = as_dict

    def __iter__(self):
        """Return this object; it is its own (single-pass) iterator."""
        return self

    def __next__(self):
        """Fetch the next piece of data from the stream.

        Raises
        ------
        StopIteration
            Propagated unchanged from ``next(request_iterator)`` when the
            request iterator runs out.

        """
        requests = self.request_iterator
        if requests is None:
            data = self.data_stream.get_data()
        else:
            # ``get_data`` is resolved before the request is drawn, matching
            # the evaluation order of ``get_data(next(requests))``.
            data = self.data_stream.get_data(next(requests))
        if not self.as_dict:
            return data
        # zip stops at the shorter of sources/data.
        return dict(zip(self.data_stream.sources, data))
||
| 37 |