Completed
Push — master ( d47136...8eb192 )
by Tinghui
57s
created

CASASFuel.get_set_list()   A

Complexity

Conditions 1

Size

Total Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
dl 0
loc 7
rs 9.4285
c 0
b 0
f 0
1
import os
2
import pickle
3
import logging
4
from fuel.datasets import H5PYDataset
5
6
# Module-level logger, named after this module, used by CASASFuel for debug and error reports.
logger = logging.getLogger(__name__)
7
8
9
class CASASFuel(object):
    """Access a CASAS smart home dataset as a Fuel dataset object.

    Args:
        dir_name (:obj:`str`):
            Directory path that contains the HDF5 dataset file ``data.hdf5`` and the
            complementary dataset information pickle file ``info.pkl``.

    Attributes:
        data_filename (:obj:`str`): Path to the ``data.hdf5`` dataset file.
        info (:obj:`dict`): Complementary dataset information loaded from ``info.pkl``.
            Left as an empty dict when ``info.pkl`` cannot be found.
            Keys of ``info`` read by this class:

            * ``split_sets``: names of the splits stored in the dataset.
            * ``index_to_feature``: feature strings, looked up by feature index.
            * ``index_to_activity``: activity labels, looked up by activity index.
    """
    def __init__(self, dir_name):
        logger.debug('Load Casas H5PYDataset from %s', dir_name)
        self.data_filename = os.path.join(dir_name, 'data.hdf5')
        # Always define ``info`` so that a missing info.pkl does not surface later as an
        # unrelated AttributeError on the instance.
        self.info = {}
        info_filename = os.path.join(dir_name, 'info.pkl')
        if os.path.isfile(info_filename):
            # NOTE: unpickling can execute arbitrary code; only load dataset directories
            # that come from a trusted source.
            # The context manager closes the file even if unpickling fails.
            with open(info_filename, 'rb') as f:
                self.info = pickle.load(f)
        else:
            logger.error('Cannot find info.pkl from current H5PYDataset directory %s', dir_name)

    def get_dataset(self, which_sets, load_in_memory=False, **kwargs):
        """Return fuel dataset object specified by which_sets tuple

        Args:
            which_sets (:obj:`tuple` of :obj:`str`): Names of the splits to load.
                Valid values are determined by the ``info.pkl`` loaded.
                You can get the list of split set names by :meth:`get_set_list()`.
                Usually, if the dataset is split by weeks, the split name is in the form of ``week <num>``.
                If the dataset is split by days, the split name is in the form of ``day <num>``.
            load_in_memory (:obj:`bool`, Optional): Default to False.
                Whether to load the data in main memory.
            **kwargs: Extra keyword arguments forwarded unchanged to
                :class:`fuel.datasets.h5py.H5PYDataset`.

        Returns:
            :class:`fuel.datasets.base.Dataset`: A Fuel dataset object created by
                :class:`fuel.datasets.h5py.H5PYDataset`
        """
        # Report every requested split that the metadata does not list. The request is
        # still forwarded as-is; presumably H5PYDataset rejects unknown splits itself
        # (TODO confirm against the Fuel documentation).
        known_sets = self.info['split_sets']
        for set_name in which_sets:
            if set_name not in known_sets:
                logger.error('set %s not found in splits', set_name)
        # Load specified splits and return
        return H5PYDataset(file_or_path=self.data_filename,
                           which_sets=which_sets,
                           load_in_memory=load_in_memory, **kwargs)

    def get_set_list(self):
        """Get the split set list

        Returns:
            The ``split_sets`` entry of :attr:`info`, i.e. the collection of split set names
            (presumably a tuple or list of :obj:`str`; the exact type depends on ``info.pkl``).
        """
        return self.info['split_sets']

    def get_input_dims(self):
        """Get the dimension of features

        Returns:
            :obj:`int` : the input feature length, i.e. the number of entries in
                ``info['index_to_feature']``
        """
        return len(self.info['index_to_feature'])

    def get_output_dims(self):
        """Get the dimension of targets

        Returns:
            :obj:`int` : the number of activities, i.e. the number of entries in
                ``info['index_to_activity']``
        """
        return len(self.info['index_to_activity'])

    def get_activity_by_index(self, index):
        """Get activity name by index

        Args:
            index (:obj:`int`): Activity index

        Returns:
            :obj:`str`: Activity label, or an empty string (after logging an error)
                when ``index`` is negative or not smaller than the number of activities.
        """
        activities = self.info['index_to_activity']
        activity_len = len(activities)
        # Reject negative indices too: they would otherwise wrap around to the end of a
        # list (or miss in a dict) instead of being reported as out of bound.
        if 0 <= index < activity_len:
            return activities[index]
        logger.error('Activity index %d out of bound. Dataset has %d activities', index, activity_len)
        return ''

    def get_feature_by_index(self, index):
        """Get feature string by index

        Args:
            index (:obj:`int`): Feature index

        Returns:
            :obj:`str`: Feature string, or an empty string (after logging an error)
                when ``index`` is negative or not smaller than the number of features.
        """
        features = self.info['index_to_feature']
        feature_len = len(features)
        # Same bound rule as get_activity_by_index: negative indices are out of bound.
        if 0 <= index < feature_len:
            return features[index]
        logger.error('Feature index %d out of bound. Dataset has %d features', index, feature_len)
        return ''
114
115