Completed
Push — master ( 2384cb...f17ea4 )
by Tinghui
01:12
created

CASASFuel.back_annotate()   F

Complexity

Conditions 11

Size

Total Lines 40

Duplication

Lines 20
Ratio 50 %

Importance

Changes 3
Bugs 0 Features 0
Metric Value
cc 11
c 3
b 0
f 0
dl 20
loc 40
rs 3.1764

How to fix   Complexity   

Complexity

Complex methods like CASASFuel.back_annotate() often do a lot of different things. To break such a method down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
import os
2
import pickle
3
import logging
4
import numpy as np
5
from fuel.datasets import H5PYDataset
6
7
logger = logging.getLogger(__name__)
8
9
10
class CASASFuel(object):
    """Retrieve CASAS smart home data as a fuel dataset object.

    Args:
        dir_name (:obj:`str`):
            Directory path that contains the HDF5 dataset file (``data.hdf5``)
            and the complementary dataset information pkl file (``info.pkl``).

    Attributes:
        data_filename (:obj:`str`): Path to the ``data.hdf5`` dataset file.
        info (:obj:`dict`): Complementary dataset information stored in dict format.
            Keys read by this class: ``split_sets``, ``split_timearray``,
            ``index_to_feature`` and ``index_to_activity``.
    """
    def __init__(self, dir_name):
        logger.debug('Load Casas H5PYDataset from %s', dir_name)
        self.data_filename = os.path.join(dir_name, 'data.hdf5')
        info_filename = os.path.join(dir_name, 'info.pkl')
        if os.path.isfile(info_filename):
            # Context manager guarantees the handle is closed even if
            # unpickling raises (the original leaked the handle in that case).
            with open(info_filename, 'rb') as f:
                self.info = pickle.load(f)
        else:
            # Best-effort behavior kept from the original: log and continue;
            # self.info stays unset and later accesses will fail loudly.
            logger.error('Cannot find info.pkl from current H5PYDataset directory %s', dir_name)

    def get_dataset(self, which_sets, load_in_memory=False, **kwargs):
        """Return fuel dataset object specified by which_sets tuple and load it in memory

        Args:
            which_sets (:obj:`tuple` of :obj:`str`): containing the name of splits to load.
                Valid values are determined by the ``info.pkl`` loaded.
                You can get the list of split set names by :meth:`get_set_list()`.
                Usually, if the dataset is split by weeks, the split name is in the form of ``week <num>``.
                If the dataset is split by days, the split name is in the form of ``day <num>``.
            load_in_memory (:obj:`bool`, Optional): Default to False.
                Whether to load the data in main memory.

        Returns:
            :class:`fuel.datasets.base.Dataset`: A Fuel dataset object created by
                :class:`fuel.datasets.h5py.H5PYDataset`
        """
        # Check if sets exist as split name in metadata; unknown names are
        # logged (H5PYDataset will raise on them anyway).
        for set_name in which_sets:
            if set_name not in self.info['split_sets']:
                logger.error('set %s not found in splits', set_name)
        # Load specified splits and return
        return H5PYDataset(file_or_path=self.data_filename,
                           which_sets=which_sets,
                           load_in_memory=load_in_memory, **kwargs)

    def get_set_list(self):
        """Get the split set list

        Returns:
            :obj:`tuple` of :obj:`str`: A list of split set names
        """
        return self.info['split_sets']

    def get_input_dims(self):
        """Get the dimension of features

        Returns:
            :obj:`int`: the input feature length
        """
        return len(self.info['index_to_feature'])

    def get_output_dims(self):
        """Get the dimension of target indices

        Returns:
            :obj:`int`: the number of target classes
        """
        return len(self.info['index_to_activity'])

    def get_activity_by_index(self, index):
        """Get activity name by index

        Args:
            index (:obj:`int`): Activity index

        Returns:
            :obj:`str`: Activity label, or ``''`` if index is out of bound.
        """
        activity_len = len(self.info['index_to_activity'])
        # Lower bound included so negative indices are reported instead of
        # silently wrapping via Python negative indexing.
        if 0 <= index < activity_len:
            return self.info['index_to_activity'][index]
        logger.error('Activity index %d out of bound. Dataset has %d activities', index, activity_len)
        return ''

    def get_feature_by_index(self, index):
        """Get feature string by index

        Args:
            index (:obj:`int`): Feature index

        Returns:
            :obj:`str`: Feature string, or ``''`` if index is out of bound.
        """
        feature_len = len(self.info['index_to_feature'])
        if 0 <= index < feature_len:
            return self.info['index_to_feature'][index]
        logger.error('Feature index %d out of bound. Dataset has %d features', index, feature_len)
        return ''

    def _resolve_time_array(self, split_id, split_name):
        """Resolve the time array of a split given its index or name(s).

        Shared by :meth:`back_annotate` and :meth:`back_annotate_with_proba`
        (this logic was previously duplicated verbatim in both).

        Args:
            split_id (:obj:`int`): Split index, or -1 to resolve by name.
            split_name (:obj:`str` or :obj:`tuple` of :obj:`str`): Split
                name(s) to resolve when ``split_id`` is -1.

        Returns:
            :obj:`list` or None: Concatenated time array for the requested
            split(s), or None (after logging an error) if resolution fails.
        """
        split_sets = self.info['split_sets']
        if split_id != -1:
            # NOTE: lower bound is inclusive — the original `0 < split_id`
            # wrongly rejected the first split (index 0).
            if 0 <= split_id < len(split_sets):
                return self.info['split_timearray'][split_id]
            logger.error('Split set index %d out of bound.', split_id)
            return None
        if isinstance(split_name, tuple):
            time_array = []
            for each_split in split_name:
                # Validate before .index() to avoid an uncaught ValueError
                # (the original raised on unknown names in the tuple path).
                if each_split not in split_sets:
                    logger.error('Failed to find split set with name %s.', each_split)
                    return None
                time_array += self.info['split_timearray'][split_sets.index(each_split)]
            return time_array
        if split_name in split_sets:
            return self.info['split_timearray'][split_sets.index(split_name)]
        logger.error('Failed to find split set with name %s.', split_name)
        return None

    def back_annotate(self, fp, prediction, split_id=-1, split_name=None):
        """Back annotate predictions of a split set into file pointer

        Args:
            fp (:obj:`file`): File object to the back annotation file.
            prediction (:obj:`numpy.ndarray`): Numpy array containing prediction labels.
            split_id (:obj:`int`): The index of split set to be annotated (required if split_name not specified).
            split_name (:obj:`str` or :obj:`tuple` of :obj:`str`): The name(s) of the split set to be
                annotated (required if split_id is not specified).
        """
        time_array = self._resolve_time_array(split_id, split_name)
        if time_array is None:
            return
        # Check length of prediction and time array
        if prediction.shape[0] != len(time_array):
            logger.error('Prediction size miss-match. There are %d time points with only %d labels given.',
                         len(time_array), prediction.shape[0])
            return
        # Perform back annotation; label -1 marks "no prediction" and is skipped.
        for timestamp, label in zip(time_array, prediction):
            if label != -1:
                fp.write('%s %s\n' % (timestamp.strftime('%Y-%m-%d %H:%M:%S'),
                                      self.get_activity_by_index(label)))

    def back_annotate_with_proba(self, fp, prediction_proba, split_id=-1, split_name=None, top_n=-1):
        """Back annotate prediction probabilities of a split set into file pointer

        Args:
            fp (:obj:`file`): File object to the back annotation file.
            prediction_proba (:obj:`numpy.ndarray`): Numpy array containing probability for each class in shape
                of (num_samples, num_class).
            split_id (:obj:`int`): The index of split set to be annotated (required if split_name not specified).
            split_name (:obj:`str` or :obj:`tuple` of :obj:`str`): The name(s) of the split set to be
                annotated (required if split_id is not specified).
            top_n (:obj:`int`): Back annotate top n probabilities (-1 annotates all classes).
        """
        time_array = self._resolve_time_array(split_id, split_name)
        if time_array is None:
            return
        # Check length of prediction and time array
        if prediction_proba.shape[0] != len(time_array):
            logger.error('Prediction size miss-match. There are %d time points with only %d labels given.',
                         len(time_array), prediction_proba.shape[0])
            return
        if top_n == -1:
            top_n = self.get_output_dims()
        # Perform back annotation: for each sample write the timestamp followed
        # by the top_n activities ranked by descending probability.
        for timestamp, proba_row in zip(time_array, prediction_proba):
            sorted_index = np.argsort(proba_row)[::-1]
            # A top probability of -1 marks "no prediction" for the sample.
            if proba_row[sorted_index[0]] != -1:
                fp.write('%s' % timestamp.strftime('%Y-%m-%d %H:%M:%S'))
                for rank in range(top_n):
                    fp.write(', %s(%g)' % (self.get_activity_by_index(sorted_index[rank]),
                                           proba_row[sorted_index[rank]]))
                fp.write('\n')

    @staticmethod
    def files_exist(dir_name):
        """Check if the CASAS Fuel dataset files exist under dir_name

        Args:
            dir_name (:obj:`str`): Directory to check.

        Returns:
            :obj:`bool`: True if both ``data.hdf5`` and ``info.pkl`` exist.
        """
        data_filename = os.path.join(dir_name, 'data.hdf5')
        info_filename = os.path.join(dir_name, 'info.pkl')
        return os.path.isfile(data_filename) and os.path.isfile(info_filename)