CASASFuel.get_input_dims()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
dl 0
loc 8
rs 9.4285
c 0
b 0
f 0
1
import os
2
import pickle
3
import logging
4
import numpy as np
5
from fuel.datasets import H5PYDataset
6
7
logger = logging.getLogger(__name__)


class CASASFuel(object):
    """CASASFuel Class to retrieve CASAS smart home data as a fuel dataset object.

    Args:
        dir_name (:obj:`string`):
            Directory path that contains HDF5 dataset file and complementary dataset information pkl file

    Attributes:
        data_filename (:obj:`str`): Path to `data.hdf5` dataset file
        info (:obj:`dict`): complementary dataset information stored in dict format.
            Keys read by this class:

            * ``split_sets``: sequence of split set names.
            * ``split_timearray``: per-split sequence of timestamps (objects providing ``strftime``),
              aligned position-by-position with ``split_sets``.
            * ``index_to_activity``: activity label for each output (target) index.
            * ``index_to_feature``: feature name for each input feature index.
    """
    def __init__(self, dir_name):
        logger.debug('Load Casas H5PYDataset from %s', dir_name)
        # os.path.join keeps paths consistent with files_exist() and avoids '//' for trailing slashes.
        self.data_filename = os.path.join(dir_name, 'data.hdf5')
        info_filename = os.path.join(dir_name, 'info.pkl')
        if os.path.isfile(info_filename):
            # SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file.
            # Only point this class at dataset directories from a trusted source.
            with open(info_filename, 'rb') as f:
                self.info = pickle.load(f)
        else:
            # self.info is left unset here; any later accessor will raise AttributeError.
            logger.error('Cannot find info.pkl from current H5PYDataset directory %s', dir_name)

    def get_dataset(self, which_sets, load_in_memory=False, **kwargs):
        """Return fuel dataset object specified by which_sets tuple and load it in memory

        Args:
            which_sets (:obj:`tuple` of :obj:`str`):  containing the name of splits to load.
                Valid value are determined by the ``info.pkl`` loaded.
                You can get the list of split set names by :meth:`get_set_list()`.
                Usually, if the dataset is split by weeks, the split name is in the form of ``week <num>``.
                If the dataset is split by days, the split name is in the form of ``day <num>``.
            load_in_memory (:obj:`bool`, Optional): Default to False.
                Whether to load the data in main memory.
            **kwargs: Forwarded unchanged to :class:`fuel.datasets.h5py.H5PYDataset`.

        Returns:
            :class:`fuel.datasets.base.Dataset`: A Fuel dataset object created by
                :class:`fuel.datasets.h5py.H5PYDataset`
        """
        # Unknown split names are only logged here; H5PYDataset itself decides whether to fail.
        for set_name in which_sets:
            if set_name not in self.info['split_sets']:
                logger.error('set %s not found in splits', set_name)
        return H5PYDataset(file_or_path=self.data_filename,
                           which_sets=which_sets,
                           load_in_memory=load_in_memory, **kwargs)

    def get_set_list(self):
        """Get the split set list

        Returns:
            :obj:`tuple` of :obj:`str`: A list of split set names (``info['split_sets']`` as stored)
        """
        return self.info['split_sets']

    def get_input_dims(self):
        """Get the dimension of features

        Returns:
            :obj:`int` : the input feature length (number of entries in ``info['index_to_feature']``)
        """
        return len(self.info['index_to_feature'])

    def get_output_dims(self):
        """Get the dimension of target indices

        Returns:
            :obj:`int` : the number of activity classes (entries in ``info['index_to_activity']``)
        """
        return len(self.info['index_to_activity'])

    def get_activity_by_index(self, index):
        """Get activity name by index

        Args:
            index (:obj:`int`): Activity index, must satisfy ``0 <= index < get_output_dims()``.

        Returns:
            :obj:`str`: Activity label, or ``''`` (after logging an error) if index is out of range.
        """
        activities = self.info['index_to_activity']
        activity_len = len(activities)
        # Reject negatives explicitly: Python indexing would otherwise wrap -1 to the last label.
        if 0 <= index < activity_len:
            return activities[index]
        logger.error('Activity index %d out of bound. Dataset has %d activities', index, activity_len)
        return ''

    def get_feature_by_index(self, index):
        """Get feature string by index

        Args:
            index (:obj:`int`): Feature index, must satisfy ``0 <= index < get_input_dims()``.

        Returns:
            :obj:`str`: Feature string, or ``''`` (after logging an error) if index is out of range.
        """
        features = self.info['index_to_feature']
        feature_len = len(features)
        # Reject negatives explicitly: Python indexing would otherwise wrap -1 to the last feature.
        if 0 <= index < feature_len:
            return features[index]
        logger.error('Feature index %d out of bound. Dataset has %d features', index, feature_len)
        return ''

    def back_annotate(self, fp, prediction, split_id=-1, split_name=None):
        """Back annotated predictions of a split set into file pointer

        Writes one ``"<YYYY-mm-dd HH:MM:SS> <activity>"`` line per time point whose
        prediction is not ``-1``. Nothing is written if the split cannot be resolved
        or the prediction length does not match the split length (an error is logged).

        Args:
            fp (:obj:`file`): File object to the back annotation file.
            prediction (:obj:`numpy.ndarray`): Numpy array containing prediction labels.
            split_id (:obj:`int`): The index of split set to be annotated (required if split_name not specified).
            split_name (:obj:`str`): The name of the split set to be annotated (required if split_id is not specified).
        """
        time_array = self._get_time_array(split_id=split_id, split_name=split_name)
        if time_array is None:
            # _get_time_array has already logged why the split could not be resolved.
            return
        if prediction.shape[0] != len(time_array):
            logger.error('Prediction size mismatch. There are %d time points but %d labels given.',
                         len(time_array), prediction.shape[0])
            return
        for timestamp, label in zip(time_array, prediction):
            # -1 marks a time point without a prediction; skip it.
            if label != -1:
                fp.write('%s %s\n' % (timestamp.strftime('%Y-%m-%d %H:%M:%S'),
                                      self.get_activity_by_index(label)))

    def back_annotate_with_proba(self, fp, prediction_proba, split_id=-1, split_name=None, top_n=-1):
        """Back annotated prediction probabilities of a split set into file pointer

        Writes one line per time point:
        ``"<YYYY-mm-dd HH:MM:SS>, <activity>(<prob>), ..."`` listing classes in
        descending probability order. Rows whose highest value is ``-1`` are skipped.

        Args:
            fp (:obj:`file`): File object to the back annotation file.
            prediction_proba (:obj:`numpy.ndarray`): Numpy array containing probability for each class in shape
                of (num_samples, num_class).
            split_id (:obj:`int`): The index of split set to be annotated (required if split_name not specified).
            split_name (:obj:`str`): The name of the split set to be annotated (required if split_id is not specified).
            top_n (:obj:`int`): Back annotate top n probabilities. ``-1`` means all classes; values larger
                than ``num_class`` are clamped to ``num_class``.
        """
        time_array = self._get_time_array(split_id=split_id, split_name=split_name)
        if time_array is None:
            # _get_time_array has already logged why the split could not be resolved.
            return
        if prediction_proba.shape[0] != len(time_array):
            logger.error('Prediction size mismatch. There are %d time points but %d labels given.',
                         len(time_array), prediction_proba.shape[0])
            return
        if top_n == -1:
            top_n = self.get_output_dims()
        # Never request more entries than there are class columns (avoids IndexError),
        # and treat other negative values as "nothing", like range(negative) did.
        top_n = max(0, min(top_n, prediction_proba.shape[1]))
        for timestamp, row in zip(time_array, prediction_proba):
            # argsort is ascending; reverse it to get class indices by descending probability.
            sorted_index = np.argsort(row)[::-1]
            # A maximum of -1 means the whole row holds the -1 marker; skip such rows.
            if row[sorted_index[0]] != -1:
                fp.write(timestamp.strftime('%Y-%m-%d %H:%M:%S'))
                for class_index in sorted_index[:top_n]:
                    fp.write(', %s(%g)' % (self.get_activity_by_index(class_index),
                                           row[class_index]))
                fp.write('\n')

    def _get_time_array(self, split_id=-1, split_name=None):
        """Get Time Array based for specified splits

        ``split_id`` takes precedence; ``split_name`` is only used when ``split_id == -1``.

        Args:
            split_id (:obj:`int`): The index of split set to be annotated (required if split_name not specified).
            split_name (:obj:`str` or :obj:`tuple`/:obj:`list` of :obj:`str`): The name(s) of the split set(s)
                to be annotated (required if split_id is not specified). Multiple names are concatenated
                in the given order.

        Returns:
            :obj:`list` of :obj:`datetime.datetime`: List of event datetime objects of splits specified,
                or ``None`` (after logging an error) if a split cannot be found.
        """
        split_sets = self.info['split_sets']
        split_timearrays = self.info['split_timearray']

        if split_id != -1:
            # Index 0 is a valid split; only reject indices outside [0, len).
            if 0 <= split_id < len(split_sets):
                return split_timearrays[split_id]
            logger.error('Split set index %d out of bound.', split_id)
            return None

        if isinstance(split_name, (tuple, list)):
            time_array = []
            for each_split in split_name:
                if each_split not in split_sets:
                    # Mirror the single-name branch instead of letting .index() raise ValueError.
                    logger.error('Failed to find split set with name %s.', each_split)
                    return None
                time_array.extend(split_timearrays[split_sets.index(each_split)])
            return time_array

        if split_name in split_sets:
            return split_timearrays[split_sets.index(split_name)]
        logger.error('Failed to find split set with name %s.', split_name)
        return None

    @staticmethod
    def files_exist(dir_name):
        """Check if the CASAS Fuel dataset files exist under dir_name

        Returns:
            :obj:`bool`: True only if both ``data.hdf5`` and ``info.pkl`` are regular files in dir_name.
        """
        data_filename = os.path.join(dir_name, 'data.hdf5')
        info_filename = os.path.join(dir_name, 'info.pkl')
        return os.path.isfile(data_filename) and os.path.isfile(info_filename)