Completed
Push — master ( dccd0d...37cade )
by Raphael
01:33
created

OnDiskDataset.train_set()   B

Complexity

Conditions 5

Size

Total Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 5
dl 0
loc 11
rs 8.5454
c 1
b 0
f 0
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
#!/usr/bin/env python
5
# -*- coding: utf-8 -*-
6
7
import types
8
from . import Dataset
9
from deepy.utils import FakeGenerator, StreamPickler, global_rand
10
11
import logging as loggers
12
# NOTE: in this module `logging` is a Logger instance for this module's name,
# not the standard-library module (which is imported above as `loggers`).
logging = loggers.getLogger(__name__)
13
14
class OnDiskDataset(Dataset):
    """
    Load a large on-disk dataset.

    The data should be dumped with deepy.utils.StreamPickler.
    You must convert the data to mini-batches before dumping it to a file.

    Arguments:
        train_path: path of the training data file.
        valid_path: path of the validation data file; ``valid_set()``
            returns None when it is not given.
        test_path: path of the test data file; ``test_set()`` returns None
            when it is not given.
        train_size: value reported by ``train_size()``; replaced by the
            number of cached items when ``cached`` is true.
        cached: load the whole training file into memory at construction.
        post_processing: callable applied to every loaded item
            (identity when omitted).
        shuffle_memory: shuffle the cached training data once after loading
            (only has an effect together with ``cached``).
        curriculum: callable ``curriculum(cached_train_data, round)`` that
            returns a generator of training data; requires ``cached``.

    Raises:
        TypeError: if ``curriculum`` is given but is not callable.
        ValueError: if ``curriculum`` is given without ``cached``.
    """

    def __init__(self, train_path, valid_path=None, test_path=None, train_size=None,
                 cached=False, post_processing=None, shuffle_memory=False, curriculum=None):
        self._train_path = train_path
        self._valid_path = valid_path
        self._test_path = test_path
        self._train_size = train_size
        self._cache_on_memory = cached
        self._cached_train_data = None
        self._post_processing = post_processing if post_processing else lambda x: x
        self._shuffle_memory = shuffle_memory
        self._curriculum = curriculum
        self._curriculum_count = 0
        # TypeError / ValueError are subclasses of Exception, so callers that
        # catch the generic Exception keep working.
        if curriculum and not callable(curriculum):
            raise TypeError("curriculum function must be callable")
        if curriculum and not cached:
            raise ValueError("curriculum learning needs training data to be cached")
        if self._cache_on_memory:
            logging.info("Cache on memory")
            # Materialize the whole stream; the helper closes the file once
            # the stream is exhausted.
            self._cached_train_data = list(self._load_stream(self._train_path))
            self._train_size = len(self._cached_train_data)
            if self._shuffle_memory:
                logging.info("Shuffle on-memory data")
                global_rand.shuffle(self._cached_train_data)

    def _load_stream(self, path):
        """
        Yield every post-processed item stored in the file at ``path``.

        The file is opened lazily (on the first ``next()``) and is closed as
        soon as the stream is exhausted or the generator is closed.
        """
        # Binary mode: pickle data is a byte stream, and text mode breaks it
        # on Python 3 (presumably StreamPickler wraps pickle -- see its name
        # and the class docstring).
        with open(path, "rb") as stream:
            for data in StreamPickler.load(stream):
                yield self._post_processing(data)

    def curriculum_train_data(self):
        """
        Start the next curriculum round and return its training data.

        Increments the round counter and returns whatever the curriculum
        callable produces for the cached data and the new round number.
        """
        self._curriculum_count += 1
        logging.info("curriculum learning: round {}".format(self._curriculum_count))
        return self._curriculum(self._cached_train_data, self._curriculum_count)

    # The three methods below are looked up by name through FakeGenerator,
    # so their names must stay stable. They are kept as generator functions.

    def generate_train_data(self):
        """Yield post-processed items from the training file."""
        for data in self._load_stream(self._train_path):
            yield data

    def generate_valid_data(self):
        """Yield post-processed items from the validation file."""
        for data in self._load_stream(self._valid_path):
            yield data

    def generate_test_data(self):
        """Yield post-processed items from the test file."""
        for data in self._load_stream(self._test_path):
            yield data

    def train_set(self):
        """
        Return the training data.

        Returns the cached list when caching is enabled without a curriculum,
        a FakeGenerator over ``curriculum_train_data`` when a curriculum is
        set, a FakeGenerator over ``generate_train_data`` when reading from
        disk, or None when no training path was given.

        Raises:
            TypeError: if the curriculum callable does not return a generator.
        """
        if self._cache_on_memory:
            if not self._curriculum:
                return self._cached_train_data
            # Probe call: for a generator function this only creates the
            # generator object without running its body; any other callable
            # is executed here.
            probe = self._curriculum(self._cached_train_data, 1)
            if not isinstance(probe, types.GeneratorType):
                raise TypeError("Curriculum function must be a generator.")
            return FakeGenerator(self, "curriculum_train_data")
        if not self._train_path:
            return None
        return FakeGenerator(self, "generate_train_data")

    def valid_set(self):
        """Return a FakeGenerator over the validation data, or None if no path was given."""
        if not self._valid_path:
            return None
        return FakeGenerator(self, "generate_valid_data")

    def test_set(self):
        """Return a FakeGenerator over the test data, or None if no path was given."""
        if not self._test_path:
            return None
        return FakeGenerator(self, "generate_test_data")

    def train_size(self):
        """Return the training size (the cached item count when caching is enabled)."""
        return self._train_size
86