Completed
Push — master ( 15b7f6...48255b )
by Raphael
02:17
created

OnDiskDataset.__init__()   C

Complexity

Conditions 8

Size

Total Lines 20

Duplication

Lines 0
Ratio 0 %

Importance

Changes 5
Bugs 0 Features 0
Metric Value
cc 8
dl 0
loc 20
rs 6.6666
c 5
b 0
f 0
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import types
from . import Dataset
from deepy.utils import FakeGenerator, StreamPickler, global_rand

import logging as loggers
logging = loggers.getLogger(__name__)
class OnDiskDataset(Dataset):
    """
    Load a large dataset that lives on disk.

    The data must have been dumped with ``deepy.utils.StreamPickler``, and
    must already be split into mini-batches before being dumped to the file.
    """

    def __init__(self, train_path, valid_path=None, test_path=None, train_size=None,
                 cached=False, post_processing=None, shuffle_memory=False, curriculum=None):
        """
        :param train_path: path to the StreamPickler-dumped training data
        :param valid_path: optional path to validation data
        :param test_path: optional path to test data
        :param train_size: number of training batches; overwritten with the
            actual batch count when ``cached`` is True
        :param cached: load the whole training set into memory up front
        :param post_processing: callable applied to every loaded batch
            (identity when None)
        :param shuffle_memory: shuffle the cached training data on each call
            to :meth:`train_set` (only meaningful together with ``cached``)
        :param curriculum: optional callable ``(data, round) -> generator``
            implementing curriculum learning; requires ``cached``
        :raises Exception: if ``curriculum`` is not callable, or if it is
            given without ``cached``
        """
        self._train_path = train_path
        self._valid_path = valid_path
        self._test_path = test_path
        self._train_size = train_size
        self._cache_on_memory = cached
        self._cached_train_data = None
        self._post_processing = post_processing if post_processing else lambda x: x
        self._shuffle_memory = shuffle_memory
        self._curriculum = curriculum
        self._curriculum_count = 0
        if curriculum and not callable(curriculum):
            raise Exception("curriculum function must be callable")
        if curriculum and not cached:
            raise Exception("curriculum learning needs training data to be cached")
        if self._cache_on_memory:
            logging.info("Cache on memory")
            # Fix: close the data file deterministically (the original leaked
            # the handle returned by open()).
            # NOTE(review): StreamPickler presumably reads pickled binary
            # data — on Python 3 this open may need mode "rb"; confirm
            # against StreamPickler's implementation.
            with open(self._train_path) as train_file:
                self._cached_train_data = list(
                    map(self._post_processing, StreamPickler.load(train_file)))
            self._train_size = len(self._cached_train_data)

    def curriculum_train_data(self):
        """Produce the next curriculum round over the cached training data."""
        self._curriculum_count += 1
        logging.info("curriculum learning: round {}".format(self._curriculum_count))
        return self._curriculum(self._cached_train_data, self._curriculum_count)

    def generate_train_data(self):
        """Stream post-processed training batches from disk, one at a time."""
        # `with` ensures the file is closed when the generator is exhausted
        # or garbage-collected (the original never closed it).
        with open(self._train_path) as train_file:
            for data in StreamPickler.load(train_file):
                yield self._post_processing(data)

    def generate_valid_data(self):
        """Stream post-processed validation batches from disk, one at a time."""
        with open(self._valid_path) as valid_file:
            for data in StreamPickler.load(valid_file):
                yield self._post_processing(data)

    def generate_test_data(self):
        """Stream post-processed test batches from disk, one at a time."""
        with open(self._test_path) as test_file:
            for data in StreamPickler.load(test_file):
                yield self._post_processing(data)

    def train_set(self):
        """Return the training data.

        Depending on configuration this is the cached in-memory list, a
        restartable curriculum generator, a restartable on-disk streaming
        generator, or None when no training path was given.
        """
        if self._cache_on_memory:
            if self._shuffle_memory:
                logging.info("shuffle on-memory data")
                global_rand.shuffle(self._cached_train_data)
            if self._curriculum:
                # NOTE(review): this probe invokes the curriculum callable
                # once (with round number 1) purely to validate that it
                # returns a generator; the probe's result is discarded.
                if not isinstance(self._curriculum(self._cached_train_data, 1), types.GeneratorType):
                    raise Exception("Curriculum function must be a generator.")
                return FakeGenerator(self, "curriculum_train_data")
            else:
                return self._cached_train_data
        if not self._train_path:
            return None
        return FakeGenerator(self, "generate_train_data")

    def valid_set(self):
        """Return a restartable validation-data generator, or None."""
        if not self._valid_path:
            return None
        return FakeGenerator(self, "generate_valid_data")

    def test_set(self):
        """Return a restartable test-data generator, or None."""
        if not self._test_path:
            return None
        return FakeGenerator(self, "generate_test_data")

    def train_size(self):
        """Return the number of training batches (exact when cached,
        otherwise whatever was passed to the constructor, possibly None)."""
        return self._train_size
89