Inspection completed: push to master (48255b...bf2b0c) by Raphael, created 01:13

OnDiskDataset._process_data() (rating: A)

Complexity: 2 conditions
Size: 5 total lines
Duplication: 0 lines (ratio 0 %)
Importance: 0 changes
Metric   Value
cc       2
c        0
b        0
f        0
dl       0
loc      5
rs       9.4285

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from . import Dataset
# Explicit relative import; the original implicit import only works on Python 2.
from .data_processor import DataProcessor
from deepy.utils import FakeGenerator, StreamPickler, global_rand

import logging as loggers
logging = loggers.getLogger(__name__)


class OnDiskDataset(Dataset):
    """
    Load a large on-disk dataset.
    The data should be dumped with deepy.utils.StreamPickler.
    You must convert the data into mini-batches before dumping them to a file.
    """

    def __init__(self, train_path, valid_path=None, test_path=None, train_size=None,
                 cached=False, post_processing=None, shuffle_memory=False, data_processor=None):
        """
        :type data_processor: DataProcessor
        """
        self._train_path = train_path
        self._valid_path = valid_path
        self._test_path = test_path
        self._train_size = train_size
        self._cache_on_memory = cached
        self._cached_train_data = None
        self._post_processing = post_processing if post_processing else lambda x: x
        self._shuffle_memory = shuffle_memory
        self._epoch = 0
        self._data_processor = data_processor
        if data_processor and not isinstance(data_processor, DataProcessor):
            raise TypeError("data_processor must be an instance of DataProcessor.")
        if self._cache_on_memory:
            logging.info("Cache on memory")
            # Pickled streams must be opened in binary mode.
            self._cached_train_data = list(map(self._post_processing,
                                               StreamPickler.load(open(self._train_path, "rb"))))
            self._train_size = len(self._cached_train_data)
            if self._shuffle_memory:
                logging.info("Shuffle on-memory data")
                global_rand.shuffle(self._cached_train_data)

    def _process_data(self, split, epoch, dataset):
        # Hand the raw batch stream to the user-supplied DataProcessor, if any.
        if self._data_processor:
            return self._data_processor.process(split, epoch, dataset)
        else:
            return dataset

    def generate_train_data(self):
        self._epoch += 1
        if self._cache_on_memory:
            # Cached batches were already post-processed in __init__,
            # so they must not be passed through _post_processing again.
            for data in self._process_data('train', self._epoch, self._cached_train_data):
                yield data
        else:
            data_source = StreamPickler.load(open(self._train_path, "rb"))
            for data in self._process_data('train', self._epoch, data_source):
                yield self._post_processing(data)

    def generate_valid_data(self):
        data_source = StreamPickler.load(open(self._valid_path, "rb"))
        for data in self._process_data('valid', self._epoch, data_source):
            yield self._post_processing(data)

    def generate_test_data(self):
        data_source = StreamPickler.load(open(self._test_path, "rb"))
        for data in self._process_data('test', self._epoch, data_source):
            yield self._post_processing(data)

    def train_set(self):
        if not self._train_path:
            return None
        # FakeGenerator re-invokes the named generator method on each pass,
        # so the data set can be iterated over multiple epochs.
        return FakeGenerator(self, "generate_train_data")

    def valid_set(self):
        if not self._valid_path:
            return None
        return FakeGenerator(self, "generate_valid_data")

    def test_set(self):
        if not self._test_path:
            return None
        return FakeGenerator(self, "generate_test_data")

    def train_size(self):
        return self._train_size
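
The method singled out by this report, _process_data, simply delegates the per-epoch batch stream to the optional DataProcessor hook. A minimal sketch of a custom processor, assuming a subclass only needs to override process(split, epoch, dataset) as called above; the import path and the shuffling behaviour are illustrative, not part of the inspected file:

from deepy.dataset.data_processor import DataProcessor  # import path is an assumption
from deepy.utils import global_rand

class ShufflingProcessor(DataProcessor):
    """Hypothetical processor: reshuffles training batches every epoch."""

    def process(self, split, epoch, dataset):
        # `dataset` may be a lazy generator, so materialize it before shuffling.
        batches = list(dataset)
        if split == 'train':
            global_rand.shuffle(batches)
        return batches

Passing data_processor=ShufflingProcessor() to OnDiskDataset then gives a different batch order on every epoch, at the cost of holding one epoch's batches in memory inside process.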
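
Per the class docstring, the on-disk file must contain pre-batched data written with deepy.utils.StreamPickler. A minimal end-to-end sketch, assuming StreamPickler exposes a dump_one(obj, fileobj) writer paired with the load(fileobj) reader used above; the writer name, the deepy.dataset import path, and the toy batch shapes are all assumptions:

import numpy as np
from deepy.utils import StreamPickler
from deepy.dataset import OnDiskDataset  # import path is an assumption

# Write 100 mini-batches, one pickled record each (binary mode for pickle).
with open("train.pkl", "wb") as f:
    for _ in range(100):
        x = np.random.rand(32, 10)            # toy batch: 32 examples, 10 features
        y = np.random.randint(0, 2, size=32)  # toy binary labels
        StreamPickler.dump_one((x, y), f)     # writer name is an assumption

# Stream the batches back lazily; train_size avoids a counting pass.
train_data = OnDiskDataset("train.pkl", train_size=100)
for x, y in train_data.train_set():
    pass  # feed each mini-batch to the trainer here

Writing one pickle record per batch is what lets load yield batches lazily, so the whole file never has to fit in memory unless cached=True is requested.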