Completed
Push — master ( 23d8f0...154b9f )
by Tinghui
51s
created

LearningResult.__init__()   A

Complexity

Conditions 1

Size

Total Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 1
c 1
b 0
f 0
dl 0
loc 11
rs 9.4285
1
import time
2
import logging
3
import xlsxwriter
4
import collections
5
from . import overall_performance_index, per_class_performance_index, get_performance_array
6
from ..logging import logging_name
7
from ..CASAS.fuel import CASASFuel
8
9
# Module-level logger. Named after the module's dotted import path rather than
# its file path, so it takes part in the package's logger hierarchy and can be
# configured through the parent package logger.
logger = logging.getLogger(__name__)
10
11
12
class LearningResult:
    """Store the results of a learning run.

    A run may be a single-shot evaluation or a time-based analysis made of
    several records (for example one per week or per day). Each record keeps
    the model snapshot together with its own confusion matrix and the
    performance arrays derived from it. The instance additionally keeps the
    element-wise sum of all record confusion matrices and the performance
    arrays derived from that sum.

    Parameters:
        name (:obj:`str`): Name of the learning run
        data (:obj:`str`): Name or description of the dataset. It is passed as
            ``dir_name`` to ``CASASFuel`` by :meth:`export_to_xlsx` when no
            ``home_info`` is supplied.
        mode (:obj:`str`): valid choices are `single_shot`, `by_week` or `by_day`

    Attributes:
        name (:obj:`str`): Name of the learning run
        data (:obj:`str`): Path to the h5py dataset directory
        mode (:obj:`str`): valid choices are `single_shot`, `by_week` or `by_day`
        created_time (:obj:`float`): created time since Epoch in seconds
        modified_time (:obj:`float`): record modified time since Epoch in seconds
        overall_performance (:class:`numpy.array`): overall performance of the learning
        per_class_performance (:class:`numpy.array`): overall per-class performance of the learning
        confusion_matrix (:class:`numpy.array`): overall (accumulated) confusion matrix
        records (:obj:`collections.OrderedDict`): Ordered dictionary storing all records
    """

    def __init__(self, name='', data='', mode='single_shot'):
        cur_time = time.time()
        self.name = name
        self.data = data
        self.mode = mode
        # Both timestamps start out identical.
        self.created_time = cur_time
        self.modified_time = cur_time
        # Aggregates stay None until the first record is added.
        self.overall_performance = None
        self.per_class_performance = None
        self.confusion_matrix = None
        self.records = collections.OrderedDict()

    def get_num_records(self):
        """Get the number of result records in the current instance.

        Returns:
            :obj:`int`: number of records; 0 when ``records`` is ``None`` or empty.
        """
        return len(self.records) if self.records else 0

    def get_record_keys(self):
        """Get the list of keys to all the records.

        Returns:
            :obj:`list`: record keys in insertion order; an empty list when
            ``records`` is ``None``.
        """
        if self.records is None:
            return []
        # Always hand back a real list so both branches return the same type.
        return list(self.records)

    def add_record(self, model, key='single_shot', confusion_matrix=None):
        """Add a learning milestone record.

        The confusion matrix of the record is added to the accumulated
        ``confusion_matrix`` of this instance, and the overall performance
        arrays are recomputed from the accumulated matrix. If the shape of the
        new matrix does not match the accumulated one, an error is logged and
        the accumulated matrix is left unchanged; the record itself is still
        stored.

        Note that adding a record under a key that already exists replaces the
        stored record but still adds its matrix to the accumulated total.

        Args:
            model (:obj:`object`): snap shot of learning model parameters
            key (:obj:`str`): key string to represent current record
            confusion_matrix (:obj:`numpy.array`): Confusion Matrix (required)

        Raises:
            ValueError: if ``confusion_matrix`` is ``None``.
        """
        if confusion_matrix is None:
            raise ValueError('add_record() requires a confusion matrix, got None')
        if self.records is None:
            self.records = collections.OrderedDict()
        if self.get_num_records() == 0 or self.confusion_matrix is None:
            # Copy so that accumulating later records never mutates the
            # matrix stored inside the first record.
            self.confusion_matrix = confusion_matrix.copy()
        elif confusion_matrix.shape != self.confusion_matrix.shape:
            logger.error(
                '%s: confusion matrix shape mismatch. Original shape %s. New shape %s',
                logging_name(self), str(self.confusion_matrix.shape), str(confusion_matrix.shape)
            )
        else:
            self.confusion_matrix += confusion_matrix
        # Aggregate performance comes from the accumulated matrix ...
        self.overall_performance, self.per_class_performance = get_performance_array(self.confusion_matrix)
        # ... while the record keeps the performance of its own matrix only.
        overall_performance, per_class_performance = get_performance_array(confusion_matrix)
        self.records[key] = {
            'model': model,
            'confusion_matrix': confusion_matrix,
            'per_class_performance': per_class_performance,
            'overall_performance': overall_performance
        }

    def get_record_by_key(self, key):
        """Get the result record stored under a specific key.

        Args:
            key: key of the record, as passed to :meth:`add_record`.

        Returns:
            :obj:`dict`: the stored record, or ``None`` (after logging an
            error) when no record exists for ``key``.
        """
        if self.records is not None and key in self.records:
            return self.records[key]
        logger.error('%s: Cannot find record %s', logging_name(self), key)
        return None

    def export_to_xlsx(self, filename, home_info=None, dataset_label='b1'):
        """Export the accumulated and per-record results to an XLSX workbook.

        The workbook contains an ``overall`` sheet (overall performance,
        per-class performance and accumulated confusion matrix), a ``weekly``
        sheet (one row of overall performance per record) and one sheet per
        record holding its per-class performance.

        Args:
            filename (:obj:`str`): path to the file
            home_info (:class:`pyActLearn.CASAS.fuel.CASASFuel`): dataset information.
                When omitted, it is built with ``CASASFuel(dir_name=self.data)``.
            dataset_label (:obj:`str`): value written in the ``dataset`` column
                of the ``weekly`` sheet. Defaults to ``'b1'``.

        Raises:
            ValueError: if no record has been added yet, so there is no
                accumulated confusion matrix to export.
        """
        if self.confusion_matrix is None:
            raise ValueError('export_to_xlsx() requires at least one record to be added first')
        if home_info is None:
            home_info = CASASFuel(dir_name=self.data)
        num_classes = self.confusion_matrix.shape[0]
        # Resolve each activity label once; they are reused by several sheets.
        labels = [home_info.get_activity_by_index(i) for i in range(num_classes)]
        records = self.records if self.records is not None else {}
        # The context manager closes the workbook even if a write fails.
        with xlsxwriter.Workbook(filename) as workbook:
            self._write_overall_sheet(workbook.add_worksheet('overall'), labels)
            self._write_weekly_sheet(workbook.add_worksheet('weekly'), records, dataset_label)
            for record_id, record in records.items():
                # Worksheet names must be strings.
                self._write_record_sheet(workbook.add_worksheet(str(record_id)), record, labels)

    def _write_overall_sheet(self, sheet, labels):
        """Write the accumulated performance and confusion matrix to ``sheet``."""
        num_classes = len(labels)
        num_performance = len(per_class_performance_index)
        # Overall performance: title row, metric names, metric values.
        sheet.merge_range(0, 0, 0, len(overall_performance_index) - 1, 'Overall Performance')
        for col, metric in enumerate(overall_performance_index):
            sheet.write(1, col, str(metric))
            sheet.write(2, col, self.overall_performance[col])
        # Per-class performance table, one row per activity.
        sheet.merge_range(4, 0, 4, num_performance, 'Per-Class Performance')
        sheet.write(5, 0, 'Activities')
        for col, metric in enumerate(per_class_performance_index):
            sheet.write(5, col + 1, str(metric))
        for row, label in enumerate(labels):
            sheet.write(row + 6, 0, label)
            for col in range(num_performance):
                sheet.write(row + 6, col + 1, self.per_class_performance[row][col])
        # Confusion matrix, placed below the per-class table.
        matrix_top = 8 + num_classes
        sheet.merge_range(matrix_top, 0, matrix_top, num_classes, 'Confusion Matrix')
        for idx, label in enumerate(labels):
            sheet.write(matrix_top + 1, idx + 1, label)
            sheet.write(matrix_top + 2 + idx, 0, label)
        for row in range(num_classes):
            for col in range(num_classes):
                sheet.write(matrix_top + 2 + row, col + 1, self.confusion_matrix[row][col])

    def _write_weekly_sheet(self, sheet, records, dataset_label):
        """Write one row of overall performance per record to ``sheet``."""
        titles = ['dataset', '#week'] + list(overall_performance_index)
        for col, title in enumerate(titles):
            sheet.write(0, col, str(title))
        for row, (record_id, record) in enumerate(records.items(), start=1):
            sheet.write(row, 0, dataset_label)
            sheet.write(row, 1, record_id)
            overall = record['overall_performance']
            for col in range(len(overall_performance_index)):
                # Values are written as text with five decimal places.
                sheet.write(row, col + 2, '%.5f' % overall[col])

    def _write_record_sheet(self, sheet, record, labels):
        """Write the per-class performance of a single record to ``sheet``."""
        titles = ['activities'] + list(per_class_performance_index)
        for col, title in enumerate(titles):
            sheet.write(0, col, str(title))
        per_class = record['per_class_performance']
        for row, label in enumerate(labels):
            sheet.write(row + 1, 0, label)
            for col in range(len(per_class_performance_index)):
                sheet.write(row + 1, col + 1, per_class[row][col])
161