Completed — Push to master ( e6ecc2...19696c ) by Tinghui, created at 01:04

LearningResult.add_record() — rated B

    Complexity:   Conditions 3
    Size:         Total Lines 26
    Duplication:  Lines 0, Ratio 0 %
    Importance:   Changes 1, Bugs 0, Features 0

    Metric   Value
    cc       3
    c        1
    b        0
    f        0
    dl       0
    loc      26
    rs       8.8571
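The per-function figures above come from the hosted analyzer; for reference, a minimal sketch of computing comparable numbers locally with the `radon` package (an assumption — the report itself may be produced by a different tool, and the file path below is hypothetical):

# Reproduce cyclomatic complexity, raw size metrics and a maintainability index locally.
from radon.complexity import cc_visit
from radon.raw import analyze
from radon.metrics import mi_visit

with open('pyActLearn/performance/record.py') as f:
    source = f.read()

for block in cc_visit(source):
    # Per-function cyclomatic complexity, comparable to the `cc` / "Conditions" figure above.
    print('%s: cc=%d' % (block.name, block.complexity))

raw = analyze(source)
print('loc=%d, lloc=%d, comments=%d' % (raw.loc, raw.lloc, raw.comments))
print('maintainability index: %.2f' % mi_visit(source, multi=True))

The inspected source file is listed below.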
import h5py
import pickle
import logging
import xlsxwriter
import collections
from collections import OrderedDict
from datetime import datetime
import numpy as np
from . import overall_performance_index, per_class_performance_index, get_performance_array, get_confusion_matrix
from .event import score_segment
from ..logging import logging_name
from ..CASAS.fuel import CASASFuel

logger = logging.getLogger(__name__)


class LearningResult:
    """LearningResult stores the results of a learning run.

    The run may be a single-shot evaluation or a time-based (e.g. by-week or by-day) analysis. The structure
    holds the parameters of the run as well as the evaluation results, so that they are easy to plot later.

    The parameters have to be set at creation time, such as the total number of events, the splits and the
    class descriptions. The predictions and the event-based scoring, however, can be added and updated at
    run-time, so that partial results survive a failure during the run.

    Parameters:
        name (:obj:`str`): Name of the learning run.
        description (:obj:`str`): Description of the learning result.
        classes (:obj:`list` of :obj:`str`): Descriptions of the target classes.
        num_events (:obj:`int`): Total number of entries in the test set.
        bg_class (:obj:`str`): Name of the class that is considered background.
        splits (:obj:`list` of :obj:`tuple`): List of (name, length) pairs giving the name and size of each
            split, in order.

    Attributes:
        name (:obj:`str`): Name of the learning run.
        description (:obj:`str`): Description of the learning result.
        classes (:obj:`list` of :obj:`str`): Descriptions of the target classes.
        created_time (:class:`datetime.datetime`): Time the result was created.
        modified_time (:class:`datetime.datetime`): Time the result was last modified.
        performance (:obj:`dict`): Overall performance metrics, filled by
            :meth:`calculate_overall_performance` and :meth:`event_based_scoring`.
        splits (:obj:`collections.OrderedDict`): Per-split records keyed by split name; each record holds the
            start/stop indices, the model path and the per-split performance metrics.
        truth (:class:`numpy.ndarray`): Ground-truth class indices for all events.
        prediction (:class:`numpy.ndarray`): Predicted class indices for all events.
        time (:class:`numpy.ndarray`): Timestamps of all events (``datetime64[ns]``).
        num_events (:obj:`int`): Total number of entries in the test set.
        bg_class_id (:obj:`int`): Index of the background class, or -1 if none.
    """
    def __init__(self, name, classes, num_events, bg_class=None, splits=None, description=''):
        cur_time = datetime.now()
        self.name = name
        self.description = description
        self.classes = classes
        self.created_time = cur_time
        self.modified_time = cur_time
        self.performance = {}
        self.splits = OrderedDict()
        if splits is not None:
            index = 0
            for split_name, length in splits:
                self.splits[split_name] = {
                    'start': index,
                    'stop': index + length,
                    'model_path': ''
                }
                index += length
        else:
            self.splits['None'] = {
                'start': 0,
                'stop': num_events,
                'model_path': ''
            }
        self.truth = np.empty(shape=(num_events,), dtype=int)
        self.prediction = np.empty(shape=(num_events,), dtype=int)
        self.time = np.empty(shape=(num_events,), dtype='datetime64[ns]')
        self.num_events = num_events
        if bg_class is None:
            self.bg_class_id = -1
        elif bg_class in self.classes:
            self.bg_class_id = self.classes.index(bg_class)
        else:
            raise ValueError('Background class %s not in the target classes list.' % bg_class)

    def record_result(self, model_file, time, truth, prediction, split=None):
        """Record the result of a split

        Args:
            model_file (:obj:`str`): Path to the file that stores the model parameters.
            time (:obj:`list` of :obj:`datetime`): Corresponding datetime of each event.
            truth (:obj:`numpy.ndarray`): Array that holds the ground truth for the targeted split.
            prediction (:obj:`numpy.ndarray`): Array that holds the prediction for the targeted split.
            split (:obj:`str`): Name of the split the record is for.
        """
        split_name = str(split)
        if split_name not in self.splits:
            raise ValueError('Split %s not found in the result.' % split)
        start_pos = self.splits[split_name]['start']
        stop_pos = self.splits[split_name]['stop']
        self.truth[start_pos:stop_pos] = truth.astype(dtype=int)
        self.prediction[start_pos:stop_pos] = prediction.astype(dtype=int)
        self.time[start_pos:stop_pos] = time
        self.splits[split_name]['model_path'] = model_file
        # Calculate performance metrics for the split
        confusion_matrix = get_confusion_matrix(len(self.classes),
                                                self.truth[start_pos:stop_pos],
                                                self.prediction[start_pos:stop_pos])
        self.splits[split_name]['confusion_matrix'] = confusion_matrix
        # With the confusion matrix, the traditional multi-class performance metrics can be calculated.
        overall_performance, per_class_performance = get_performance_array(confusion_matrix)
        self.splits[split_name]['overall_performance'] = overall_performance
        self.splits[split_name]['per_class_performance'] = per_class_performance
        # Note: event-based scoring can be done after all splits are logged in.

    def get_record_of_split(self, split):
        """Get the record corresponding to a specific split

        Args:
            split (:obj:`str`): Name of the split.

        Returns:
            :obj:`dict`: The record of the split, or None if the split is not found.
        """
        if split in self.splits:
            return self.splits[split]
        else:
            logger.error('Cannot find split %s.' % split)
            return None

    def get_time_list(self):
        """Return the event timestamps as a list of :class:`datetime.datetime` objects."""
        # self.time holds datetime64[ns] values; astype(datetime) yields nanoseconds since Epoch.
        time_list = [datetime.utcfromtimestamp(item.astype(datetime) * 1e-9) for item in self.time]
        return time_list

    def event_based_scoring(self):
        """Event-based segment scoring
        """
        self.performance['event_scoring'] = score_segment(self.truth, self.prediction, bg_label=self.bg_class_id)

    def calculate_overall_performance(self):
        """Calculate overall performance across all splits
        """
        confusion_matrix = get_confusion_matrix(len(self.classes), self.truth, self.prediction)
        overall_performance, per_class_performance = get_performance_array(confusion_matrix)
        self.performance['confusion_matrix'] = confusion_matrix
        self.performance['overall_performance'] = overall_performance
        self.performance['per_class_performance'] = per_class_performance

    def save_to_file(self, filename):
        """Pickle to file
        """
        with open(filename, 'wb') as f:
            pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)

    @staticmethod
    def load_from_file(filename):
        """Load LearningResult from file

        Args:
            filename (:obj:`str`): Path to the file that stores the result.

        Returns:
            :class:`pyActLearn.performance.record.LearningResult`: LearningResult object.
        """
        with open(filename, 'rb') as f:
            result = pickle.load(f)
        return result

    def export_to_xlsx(self, filename, home_info=None):
        """Export the result to an XLSX workbook

        Args:
            filename (:obj:`str`): Path to the output file.
            home_info (:class:`pyActLearn.CASAS.fuel.CASASFuel`): Dataset information used to translate class
                indices into activity labels. If None, the class descriptions are used instead.
        """
        workbook = xlsxwriter.Workbook(filename)
        num_performance = len(per_class_performance_index)
        num_classes = len(self.classes)
        # Make sure the overall metrics are available before writing the summary sheet.
        if 'overall_performance' not in self.performance:
            self.calculate_overall_performance()
        overall_performance = self.performance['overall_performance']
        per_class_performance = self.performance['per_class_performance']
        confusion_matrix = self.performance['confusion_matrix']

        def _label(index):
            # Prefer the activity label from the dataset information; fall back to the class description.
            if home_info is not None:
                return home_info.get_activity_by_index(index)
            return self.classes[index]

        # Overall Performance Summary
        overall_sheet = workbook.add_worksheet('overall')
        overall_sheet.merge_range(0, 0, 0, len(overall_performance_index) - 1, 'Overall Performance')
        for c in range(len(overall_performance_index)):
            overall_sheet.write(1, c, str(overall_performance_index[c]))
            overall_sheet.write(2, c, overall_performance[c])
        overall_sheet.merge_range(4, 0, 4, len(per_class_performance_index), 'Per-Class Performance')
        overall_sheet.write(5, 0, 'Activities')
        for c in range(len(per_class_performance_index)):
            overall_sheet.write(5, c + 1, str(per_class_performance_index[c]))
        for r in range(num_classes):
            overall_sheet.write(r + 6, 0, _label(r))
            for c in range(num_performance):
                overall_sheet.write(r + 6, c + 1, per_class_performance[r][c])
        overall_sheet.merge_range(8 + num_classes, 0, 8 + num_classes, num_classes, 'Confusion Matrix')
        for i in range(num_classes):
            label = _label(i)
            overall_sheet.write(9 + num_classes, i + 1, label)
            overall_sheet.write(10 + num_classes + i, 0, label)
        for r in range(num_classes):
            for c in range(num_classes):
                overall_sheet.write(10 + num_classes + r, c + 1, confusion_matrix[r][c])

        # Weekly Performance Summary (one row per split)
        weekly_sheet = workbook.add_worksheet('weekly')
        weekly_list_title = ['dataset', '#week'] + overall_performance_index
        for c in range(len(weekly_list_title)):
            weekly_sheet.write(0, c, str(weekly_list_title[c]))
        r = 1
        for record_id in self.splits:
            weekly_sheet.write(r, 0, 'b1')
            weekly_sheet.write(r, 1, record_id)
            for c in range(len(overall_performance_index)):
                weekly_sheet.write(r, c + 2, '%.5f' % self.splits[record_id]['overall_performance'][c])
            r += 1
        # Per Week Per Class Summary (one sheet per split)
        dataset_list_title = ['activities'] + per_class_performance_index
        for record_id in self.splits:
            cur_sheet = workbook.add_worksheet(record_id)
            for c in range(len(dataset_list_title)):
                cur_sheet.write(0, c, str(dataset_list_title[c]))
            for r in range(num_classes):
                cur_sheet.write(r + 1, 0, _label(r))
                for c in range(num_performance):
                    cur_sheet.write(r + 1, c + 1, self.splits[record_id]['per_class_performance'][r][c])
        workbook.close()

    def export_annotation(self, filename):
        """Export the predicted labels as back-annotation to file

        Args:
            filename (:obj:`str`): Path to the output file.
        """
        with open(filename, 'w') as f:
            for i in range(self.num_events):
                f.write('%s %s\n' % (
                    datetime.utcfromtimestamp(self.time[i].astype(datetime) * 1e-9).strftime('%Y-%m-%d %H:%M:%S'),
                    self.classes[self.prediction[i]]
                ))
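
For reference, a minimal end-to-end usage sketch of the class above. It is illustrative only: the run name, class list, split names, arrays and file names are made up; only the LearningResult API itself comes from the listing.

import numpy as np
from datetime import datetime, timedelta
from pyActLearn.performance.record import LearningResult

classes = ['Sleep', 'Cook', 'Other_Activity']
splits = [('week_1', 50), ('week_2', 50)]      # (name, length) pairs; 100 events in total
result = LearningResult(name='demo_run', classes=classes, num_events=100,
                        bg_class='Other_Activity', splits=splits)

start = datetime(2020, 1, 1)
offset = 0
for split_name, length in splits:
    # One timestamp per event, and random truth/prediction arrays for illustration.
    time = np.array([start + timedelta(minutes=offset + i) for i in range(length)],
                    dtype='datetime64[ns]')
    truth = np.random.randint(0, len(classes), size=length)
    prediction = np.random.randint(0, len(classes), size=length)
    result.record_result('models/%s.ckpt' % split_name, time, truth, prediction, split=split_name)
    offset += length

result.calculate_overall_performance()    # fills result.performance['overall_performance'], ...
result.event_based_scoring()              # fills result.performance['event_scoring']
result.save_to_file('demo_run.pkl')
restored = LearningResult.load_from_file('demo_run.pkl')
restored.export_to_xlsx('demo_run.xlsx')  # falls back to the class descriptions when home_info is None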
237