Completed
Push — master ( 96a8e9...51ce31 )
by Tinghui, created 01:08

CASASFuel   (Rating: A)

Complexity

Total Complexity 19

Size / Duplication

Total Lines 143
Duplicated Lines 0%

Importance

Changes 2
Bugs 0
Features 0
Metric   Value
c        2     (changes)
b        0     (bugs)
f        0     (features)
dl       0     (duplicated lines)
loc      143   (lines of code)
rs       10
wmc      19    (total complexity)

9 Methods

Rating  Name                      Duplication  Size  Complexity
A       get_input_dims()          0            8     1
A       get_activity_by_index()   0            15    2
A       files_exist()             0            7     1
A       get_output_dims()         0            8     1
A       get_feature_by_index()    0            15    2
B       get_dataset()             0            24    3
A       __init__()                0            9     2
B       back_annotate()           0            30    6
A       get_set_list()            0            7     1
import os
import pickle
import logging
from fuel.datasets import H5PYDataset

logger = logging.getLogger(__name__)


class CASASFuel(object):
    """CASASFuel class to retrieve CASAS smart home data as a Fuel dataset object.

    Args:
        dir_name (:obj:`str`):
            Directory path that contains the HDF5 dataset file and the complementary
            dataset information pkl file.

    Attributes:
        data_filename (:obj:`str`): Path to the ``data.hdf5`` dataset file.
        info (:obj:`dict`): Complementary dataset information stored in dict format.
            Its keys include ``split_sets``, ``split_timearray``, ``index_to_feature``
            and ``index_to_activity``.
    """
    def __init__(self, dir_name):
        logger.debug('Load CASAS H5PYDataset from ' + dir_name)
        self.data_filename = os.path.join(dir_name, 'data.hdf5')
        info_filename = os.path.join(dir_name, 'info.pkl')
        if os.path.isfile(info_filename):
            with open(info_filename, 'rb') as f:
                self.info = pickle.load(f)
        else:
            logger.error('Cannot find info.pkl in H5PYDataset directory %s' % dir_name)

    def get_dataset(self, which_sets, load_in_memory=False, **kwargs):
        """Return the Fuel dataset object specified by which_sets and optionally load it in memory.

        Args:
            which_sets (:obj:`tuple` of :obj:`str`): Names of the splits to load.
                Valid values are determined by the ``info.pkl`` loaded; the list of
                split set names is available through :meth:`get_set_list()`.
                Usually, if the dataset is split by weeks the split names take the form
                ``week <num>``, and if it is split by days they take the form ``day <num>``.
            load_in_memory (:obj:`bool`, optional): Defaults to False.
                Whether to load the data into main memory.

        Returns:
            :class:`fuel.datasets.base.Dataset`: A Fuel dataset object created by
                :class:`fuel.datasets.hdf5.H5PYDataset`.
        """
        # Check that every requested set exists as a split name in the metadata
        for set_name in which_sets:
            if set_name not in self.info['split_sets']:
                logger.error('set %s not found in splits' % set_name)
        # Load the specified splits and return
        return H5PYDataset(file_or_path=self.data_filename,
                           which_sets=which_sets,
                           load_in_memory=load_in_memory, **kwargs)
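
    # Usage sketch (illustrative, not part of this module). The directory path and
    # split name below are assumptions; query get_set_list() for the actual names.
    #
    #     casas_data = CASASFuel('datasets/twor.2009')
    #     print(casas_data.get_set_list())            # e.g. ('week 1', 'week 2', ...)
    #     train_set = casas_data.get_dataset(('week 1',), load_in_memory=True)
    #     handle = train_set.open()
    #     batch = train_set.get_data(handle, slice(0, 128))  # ordered as train_set.sources
    #     train_set.close(handle)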

    def get_set_list(self):
        """Get the split set list.

        Returns:
            :obj:`tuple` of :obj:`str`: A list of split set names.
        """
        return self.info['split_sets']

    def get_input_dims(self):
        """Get the dimension of the feature vector.

        Returns:
            :obj:`int`: The input feature length.
        """
        return len(self.info['index_to_feature'])

    def get_output_dims(self):
        """Get the dimension of the target space.

        Returns:
            :obj:`int`: The number of target activity classes.
        """
        return len(self.info['index_to_activity'])
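
    # Sketch: the dims are handy for sizing arrays, e.g. one-hot encoding integer
    # targets with numpy (assumes `casas_data = CASASFuel(...)` and an integer
    # label array `targets`; both names are illustrative):
    #
    #     import numpy
    #     num_classes = casas_data.get_output_dims()
    #     one_hot = numpy.eye(num_classes)[targets]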

    def get_activity_by_index(self, index):
        """Get an activity name by index.

        Args:
            index (:obj:`int`): Activity index.

        Returns:
            :obj:`str`: Activity label.
        """
        activity_len = len(self.info['index_to_activity'])
        if index < activity_len:
            return self.info['index_to_activity'][index]
        else:
            logger.error('Activity index %d out of bound. Dataset has %d activities' % (index, activity_len))
            return ''

    def get_feature_by_index(self, index):
        """Get a feature string by index.

        Args:
            index (:obj:`int`): Feature index.

        Returns:
            :obj:`str`: Feature string.
        """
        feature_len = len(self.info['index_to_feature'])
        if index < feature_len:
            return self.info['index_to_feature'][index]
        else:
            logger.error('Feature index %d out of bound. Dataset has %d features' % (index, feature_len))
            return ''

    def back_annotate(self, fp, prediction, split_id=-1, split_name=None):
        """Back-annotate the predictions of a split set into a file.

        Args:
            fp (:obj:`file`): File object for the back-annotation output.
            prediction (:obj:`numpy.ndarray`): Numpy array containing prediction labels.
            split_id (:obj:`int`): Index of the split set to annotate (required if split_name is not given).
            split_name (:obj:`str`): Name of the split set to annotate (required if split_id is not given).
        """
        # Resolve the split id from the split name if necessary
        if split_id == -1:
            if split_name in self.info['split_sets']:
                split_id = self.info['split_sets'].index(split_name)
            else:
                logger.error('Failed to find split set with name %s.' % split_name)
                return
        if 0 <= split_id < len(self.info['split_sets']):
            time_array = self.info['split_timearray'][split_id]
        else:
            logger.error('Split set index %d out of bound.' % split_id)
            return
        # Check that the number of predictions matches the number of time points
        if prediction.shape[0] != len(time_array):
            logger.error('Prediction size mismatch. There are %d time points but only %d labels given.' %
                         (len(time_array), prediction.shape[0]))
            return
        # Write one "<timestamp> <activity>" line per prediction
        for i in range(len(time_array)):
            fp.write('%s %s\n' % (time_array[i].strftime('%Y-%m-%d %H:%M:%S'),
                                  self.get_activity_by_index(prediction[i])))
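
    # Usage sketch (illustrative): write back-annotated predictions for one split.
    # `casas_data` and `predictions` are assumed to exist, with `predictions` an
    # integer numpy array aligned with the split's time points.
    #
    #     with open('week1_annotated.txt', 'w') as fp:
    #         casas_data.back_annotate(fp, predictions, split_name='week 1')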

    @staticmethod
    def files_exist(dir_name):
        """Check whether the CASAS Fuel dataset files exist under dir_name."""
        data_filename = os.path.join(dir_name, 'data.hdf5')
        info_filename = os.path.join(dir_name, 'info.pkl')
        return os.path.isfile(data_filename) and os.path.isfile(info_filename)
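
To stream mini-batches during training, the H5PYDataset returned by get_dataset() can be wrapped in a Fuel DataStream. The sketch below is illustrative: the dataset directory, split name and batch size are placeholder assumptions, while DataStream and SequentialScheme are standard Fuel components.

    from fuel.streams import DataStream
    from fuel.schemes import SequentialScheme

    if CASASFuel.files_exist('datasets/twor.2009'):
        casas_data = CASASFuel('datasets/twor.2009')
        train_set = casas_data.get_dataset(('week 1',))
        scheme = SequentialScheme(examples=train_set.num_examples, batch_size=64)
        stream = DataStream(dataset=train_set, iteration_scheme=scheme)
        for batch in stream.get_epoch_iterator():
            pass  # each batch is a tuple ordered as train_set.sources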