OnDiskDataset (overall rating: A)

Complexity

Total Complexity     23

Size/Duplication

Total Lines          71
Duplicated Lines     0 %

Importance

Changes              4
Bugs                 0
Features             0

Metric    Value
c         4       (changes)
b         0       (bugs)
f         0       (features)
dl        0       (duplicated lines)
loc       71      (lines of code)
rs        10
wmc       23      (weighted methods per class / total complexity)

9 Methods

Rating    Name                      Duplication    Size    Complexity
A         _process_data()           0              5       2
C         __init__()                0              24      7
A         test_set()                0              4       2
A         generate_test_data()      0              4       2
A         train_set()               0              4       2
A         valid_set()               0              4       2
A         train_size()              0              2       1
A         generate_train_data()     0              5       3
A         generate_valid_data()     0              4       2
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import types

from . import Dataset
from .data_processor import DataProcessor
from deepy.utils import FakeGenerator, StreamPickler
from deepy.core import env

import logging as loggers
logging = loggers.getLogger(__name__)


class OnDiskDataset(Dataset):
    """
    Load a large on-disk dataset.
    The data should be dumped with deepy.utils.StreamPickler.
    You must convert the data to mini-batches before dumping it to a file.
    """

    def __init__(self, train_path, valid_path=None, test_path=None, train_size=None,
                 cached=False, post_processing=None, shuffle_memory=False, data_processor=None):
        """
        :type data_processor: DataProcessor
        """
        self._train_path = train_path
        self._valid_path = valid_path
        self._test_path = test_path
        self._train_size = train_size
        self._cache_on_memory = cached
        self._cached_train_data = None
        self._post_processing = post_processing if post_processing else lambda x: x
        self._shuffle_memory = shuffle_memory
        self._epoch = 0
        self._data_processor = data_processor
        if data_processor and not isinstance(data_processor, DataProcessor):
            raise Exception("data_processor must be an instance of DataProcessor.")
        if self._cache_on_memory:
            # Load the whole training set into memory once, applying post-processing up front.
            logging.info("Cache on memory")
            self._cached_train_data = list(map(self._post_processing, StreamPickler.load(open(self._train_path))))
            self._train_size = len(self._cached_train_data)
            if self._shuffle_memory:
                logging.info("Shuffle on-memory data")
                env.numpy_rand.shuffle(self._cached_train_data)

    def _process_data(self, split, epoch, dataset):
        # Delegate to the user-supplied DataProcessor, if any.
        if self._data_processor:
            return self._data_processor.process(split, epoch, dataset)
        else:
            return dataset

    def generate_train_data(self):
        self._epoch += 1
        data_source = self._cached_train_data if self._cache_on_memory else StreamPickler.load(open(self._train_path))
        for data in self._process_data('train', self._epoch, data_source):
            yield self._post_processing(data)

    def generate_valid_data(self):
        data_source = StreamPickler.load(open(self._valid_path))
        for data in self._process_data('valid', self._epoch, data_source):
            yield self._post_processing(data)

    def generate_test_data(self):
        data_source = StreamPickler.load(open(self._test_path))
        for data in self._process_data('test', self._epoch, data_source):
            yield self._post_processing(data)

    def train_set(self):
        if not self._train_path:
            return None
        return FakeGenerator(self, "generate_train_data")

    def valid_set(self):
        if not self._valid_path:
            return None
        return FakeGenerator(self, "generate_valid_data")

    def test_set(self):
        if not self._test_path:
            return None
        return FakeGenerator(self, "generate_test_data")

    def train_size(self):
        return self._train_size
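
For context, here is a minimal usage sketch. The file paths, the ShuffleEachEpoch processor, and the deepy.dataset import locations are illustrative assumptions, not taken from this file; only the constructor arguments and the DataProcessor.process(split, epoch, dataset) call are confirmed by the code above. The pickle files are assumed to hold mini-batches dumped with deepy.utils.StreamPickler (as the class docstring requires), and iterating train_set() assumes FakeGenerator acts as an iterable wrapper around the named generator method.

# Minimal usage sketch (hypothetical paths and import locations).
from deepy.dataset import OnDiskDataset                  # assumed import path
from deepy.dataset.data_processor import DataProcessor   # assumed import path


class ShuffleEachEpoch(DataProcessor):
    """Hypothetical processor: reorder the training mini-batches every epoch."""

    def process(self, split, epoch, dataset):
        # The (split, epoch, dataset) signature mirrors the call in _process_data().
        if split != 'train':
            return dataset
        batches = list(dataset)
        # ... reorder `batches` here, e.g. with numpy.random.shuffle ...
        return batches


# "train_batches.pkl" / "valid_batches.pkl" are assumed to contain mini-batches
# dumped with deepy.utils.StreamPickler beforehand.
dataset = OnDiskDataset(
    train_path="train_batches.pkl",
    valid_path="valid_batches.pkl",
    cached=True,                        # load all training batches into memory at init
    shuffle_memory=True,                # shuffle the cached batches once after loading
    data_processor=ShuffleEachEpoch(),
)

for batch in dataset.train_set():       # train_set() wraps generate_train_data()
    pass  # hand each mini-batch to a trainer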