Completed
Push — master ( 19696c...444ad5 )
by Tinghui
01:03
created

compare_per_class_recall()   A

Complexity

Conditions 1

Size

Total Lines 19

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 1
c 1
b 0
f 0
dl 0
loc 19
rs 9.4285
1
import logging
2
import numpy as np
3
import matplotlib.pyplot as plt
4
from matplotlib.ticker import MultipleLocator
5
from matplotlib.patches import Rectangle
6
7
# Module-level logger used by the timeliness helpers for mismatch warnings.
logger = logging.getLogger(__name__)

# Display names of the stacked bar segments, listed in column order.
recall_scoring_labels = [
    'Correct',
    'Fragmenting',
    'Underfill-B',
    'Underfill-E',
    'Deletion',
]
fpr_scoring_labels = [
    'Correct',
    'Merging',
    'Overfill-B',
    'Overfill-E',
    'Insertion',
]

# Single-letter scoring code -> column index into the matching label list.
recall_scoring_indices = {'C': 0, 'D': 4, 'F': 1, 'U': 2, 'u': 3}
fpr_scoring_indices = {'C': 0, 'I': 4, 'M': 1, 'O': 2, 'o': 3}
13
14
15
def draw_per_class_recall(classes, class_colors, recall_array, filename=None):
    """Draw a stacked bar chart of the per-class recall breakdown.

    Each class gets one bar; the bar is split into the segments listed in
    ``recall_scoring_labels`` and every row is normalised to sum to one.

    Args:
        classes (:obj:`list` of :obj:`str`): List of target classes, one bar each.
        class_colors (:obj:`list` of :obj:`str`): Colors for the corresponding classes.
        recall_array: Iterable of per-class records; each record is indexed by
            the keys of ``recall_scoring_indices`` ('C', 'D', 'F', 'U', 'u').
        filename (:obj:`str`): The filename to save the plot. None to display
            on screen with pyplot.
    """
    num_labels = len(recall_scoring_labels)
    # ``np.float`` was removed in NumPy 1.24; the builtin ``float`` is the same
    # dtype. Zero-initialise so rows never filled below are not garbage.
    recall_np = np.zeros((len(classes), num_labels), dtype=float)
    for i, row in enumerate(recall_array):
        for key, column in recall_scoring_indices.items():
            recall_np[i, column] = row[key]
    # Normalise each row; rows with a zero total are left at zero instead of
    # turning into NaN.
    totals = recall_np.sum(axis=1, keepdims=True)
    np.divide(recall_np, totals, out=recall_np, where=totals != 0)

    ind = np.arange(len(classes))
    width = 0.35
    bottom = np.zeros((len(classes),))
    bar_array = []
    for i in range(num_labels):
        bars = plt.bar(ind, recall_np[:, i], width,
                       alpha=(1 - 1 / num_labels * i),
                       color=class_colors, bottom=bottom)
        # One handle per segment is enough for the legend.
        bar_array.append(bars[0])
        bottom += recall_np[:, i]
    plt.ylabel('Percentage')
    plt.xlabel('Classes')
    plt.xticks(ind, classes, rotation='vertical')
    plt.legend(bar_array, recall_scoring_labels)
    # Honour ``filename`` the same way the comparison plots in this module do.
    if filename is None:
        plt.show()
    else:
        plt.savefig(filename, bbox_inches='tight')
39
40
41
def _get_bg_class_id(classes, background_class):
    """Return the index of ``background_class`` within ``classes``.

    Args:
        classes: Sequence of class names.
        background_class: Name of the background class, or None when there is
            no background class.

    Returns:
        int: Position of ``background_class`` in ``classes``, or -1 when
        ``background_class`` is None.

    Raises:
        ValueError: If ``background_class`` is given but not in ``classes``.
    """
    if background_class is None:
        return -1
    try:
        # ``list(...)`` lets any iterable of names be searched, not only
        # objects that provide their own ``index`` method.
        return list(classes).index(background_class)
    except ValueError:
        raise ValueError('Background class %r is not in the list of classes'
                         % (background_class,))
48
49
50
def _get_metric_label_dict(metric_name='recall'):
    """Pick the segment labels and code-to-column mapping for a metric.

    Args:
        metric_name (:obj:`str`): ``'recall'`` selects the recall tables; any
            other value selects the false-positive (fpr) tables.

    Returns:
        tuple: ``(metric_labels, metric_indices)``.
    """
    use_recall = metric_name == 'recall'
    labels = recall_scoring_labels if use_recall else fpr_scoring_labels
    indices = recall_scoring_indices if use_recall else fpr_scoring_indices
    return labels, indices
58
59
60
def _gether_per_class_metrics(methods, classes, metric_arrays, as_percent, metric_labels, metric_indices):
    """Prepare one (classes x segments) matrix per method for the bar plot.

    Args:
        methods: List of method names; only its length is used here.
        classes: List of target classes; sets the number of rows.
        metric_arrays: One entry per method. Each entry is an iterable of
            per-class records indexed by the keys of ``metric_indices``.
        as_percent (:obj:`bool`): Normalise every row to sum to one.
        metric_labels: Segment labels; sets the number of columns.
        metric_indices (:obj:`dict`): Scoring code -> column index.

    Returns:
        :obj:`list` of :obj:`numpy.ndarray`: One float array of shape
        ``(len(classes), len(metric_labels))`` per method.
    """
    plot_metric_arrays = []
    for method_id in range(len(methods)):
        # ``np.float`` was removed in NumPy 1.24; the builtin ``float`` is the
        # same dtype. Zero-initialise so unfilled rows are not garbage.
        cur_metric = np.zeros((len(classes), len(metric_labels)), dtype=float)
        for class_id, row in enumerate(metric_arrays[method_id]):
            for key, column in metric_indices.items():
                cur_metric[class_id, column] = row[key]
        if as_percent:
            # Rows whose total is zero stay at zero instead of becoming NaN.
            row_totals = cur_metric.sum(axis=1, keepdims=True)
            np.divide(cur_metric, row_totals, out=cur_metric, where=row_totals != 0)
        plot_metric_arrays.append(cur_metric)
    return plot_metric_arrays
77
78
79
def _compare_per_class_metrics(methods, classes, class_colors, metric_arrays,
                               group_by='methods', filename=None, background_class=None,
                               as_percent=True, metric_name='recall'):
    """Compare per-class metrics between methods using a stacked bar graph.

    Args:
        methods (:obj:`list` of :obj:`str`): Names of the methods to be plotted.
        classes (:obj:`list` of :obj:`str`): List of target classes.
        class_colors (:obj:`list` of :obj:`str`): Colors for the corresponding classes.
        metric_arrays (:obj:`list`): Per-method metric records, one entry per method.
        group_by (:obj:`str`): ``'methods'`` puts the bars of all methods for one
            class next to each other; any other value puts the bars of all
            classes for one method next to each other.
        filename (:obj:`str`): The filename to save the plot. None to display
            on screen with pyplot.
        background_class (:obj:`str`): Class to omit from the plot, or None.
        as_percent (:obj:`bool`): Normalise every bar to sum to one.
        metric_name (:obj:`str`): ``'recall'`` selects the recall segments; any
            other value selects the false-positive segments.
    """
    metric_labels, metric_indices = _get_metric_label_dict(metric_name=metric_name)
    bg_class_id = _get_bg_class_id(classes, background_class)
    plot_metric_arrays = _gether_per_class_metrics(methods, classes, metric_arrays, as_percent,
                                                   metric_labels, metric_indices)
    num_methods = len(methods)
    num_labels = len(metric_labels)
    # Classes that are actually drawn. ``bg_class_id`` is -1 when there is no
    # background class, so nothing is filtered out in that case.
    shown_classes = [c for c in range(len(classes)) if c != bg_class_id]
    num_shown = len(shown_classes)

    # Prepare data, bar colors and x tick labels; one row of plot_data per bar.
    xtick_labels = []
    bar_colors = []
    plot_data = np.empty((num_methods * num_shown, num_labels))
    if group_by == 'methods':
        num_base_axis = num_methods
        num_sec_axis = num_shown
        for slot, class_id in enumerate(shown_classes):
            for method_id in range(num_methods):
                bar_colors.append(class_colors[class_id])
                xtick_labels.append(methods[method_id])
                plot_data[slot * num_base_axis + method_id, :] = \
                    plot_metric_arrays[method_id][class_id, :]
    else:
        num_base_axis = num_shown
        num_sec_axis = num_methods
        for method_id in range(num_methods):
            xtick_labels.append(methods[method_id])
            for slot, class_id in enumerate(shown_classes):
                bar_colors.append(class_colors[class_id])
                plot_data[method_id * num_base_axis + slot, :] = \
                    plot_metric_arrays[method_id][class_id, :]

    # Bar width and location: each group occupies one unit on the x axis and
    # leaves one bar-width of padding.
    width = 1 / (num_base_axis + 1)
    ind = [group + slot * width + width
           for group in range(num_sec_axis)
           for slot in range(num_base_axis)]
    bottom = np.zeros((num_base_axis * num_sec_axis,))

    # Major and minor grid spacing for the y axis.
    if as_percent:
        minor_locator_value = 0.05
        major_locator_value = 0.2
    else:
        max_value = np.max(plot_data.sum(axis=1)) + 20
        minor_locator_value = int(max_value / 20)
        major_locator_value = int(max_value / 5)

    # X tick locations: one per bar when grouped by methods, one per method otherwise.
    if group_by == 'methods':
        xlabel_ind = [x + width / 2 for x in ind]
        xlabel_rotation = 'vertical'
    else:
        xlabel_ind = [x + 0.5 for x in range(num_methods)]
        xlabel_rotation = 'horizontal'

    # Set up the figure and the y axis grid.
    fig, ax = plt.subplots()
    ax.yaxis.set_minor_locator(MultipleLocator(minor_locator_value))
    ax.yaxis.set_major_locator(MultipleLocator(major_locator_value))
    ax.yaxis.grid(which="major", color='0.65', linestyle='-', linewidth=1)
    ax.yaxis.grid(which="minor", color='0.45', linestyle=' ', linewidth=1)

    # Stack the segments; later segments are drawn more transparent.
    for i in range(num_labels):
        ax.bar(ind, plot_data[:, i], width,
               alpha=(1 - 1 / num_labels * i),
               color=bar_colors, bottom=bottom)
        bottom += plot_data[:, i]
    if as_percent:
        plt.ylabel('Percentage')
    else:
        plt.ylabel('Count')
    plt.xlabel('Classes')
    plt.xticks(xlabel_ind, xtick_labels, rotation=xlabel_rotation, fontsize=6)

    # Legend: grey swatches explain segment transparency, colored swatches the classes.
    patches = []
    legend_labels = []
    for i in range(num_labels):
        patches.append(Rectangle((0, 0), 0, 0, color='0.3', alpha=(1 - 1 / num_labels * i)))
        legend_labels.append(metric_labels[i])
    for class_id in shown_classes:
        patches.append(Rectangle((0, 0), 0, 0, color=class_colors[class_id]))
        legend_labels.append(classes[class_id])
    plt.legend(patches, legend_labels, loc='center left', borderaxespad=0, bbox_to_anchor=(1.05, 0.5),
               prop={'size': 8})
    plt.tight_layout()
    plt.title('Event-based Activity Analysis - %s' % metric_name)
    if filename is None:
        plt.show()
    else:
        plt.savefig(filename, bbox_inches='tight')
        # Release the figure; otherwise every saved plot stays open in pyplot.
        plt.close(fig)
191
192
193
def compare_per_class_recall(methods, classes, class_colors, recall_arrays,
                             group_by='methods', filename=None, background_class=None,
                             as_percent=True):
    """Draw an event-based comparison between methods on the recall metric.

    Args:
        methods (:obj:`list` of :obj:`str`): List of names of different methods to be plotted.
        classes (:obj:`list` of :obj:`str`): List of target classes.
        class_colors (:obj:`list` of :obj:`str`): List of colors for the corresponding classes in the
            ``classes`` list.
        recall_arrays (:obj:`list` of :obj:`numpy.ndarray`): List of recall arrays, one per method.
        group_by (:obj:`str`): Group the bars by 'methods' first or by 'classes' first. Defaults to 'methods'.
        filename (:obj:`str`): The filename to save the plot. None to display on screen with pyplot.
        background_class (:obj:`str`): Background class. The statistics of ``background_class`` are omitted
            from the plot.
        as_percent (:obj:`bool`): Whether or not to convert the accumulated values to percentages.
    """
    _compare_per_class_metrics(
        methods, classes, class_colors, recall_arrays,
        group_by=group_by,
        filename=filename,
        background_class=background_class,
        as_percent=as_percent,
        metric_name='recall',
    )
212
213
214
def compare_per_class_precision(methods, classes, class_colors, precision_arrays,
                                group_by='methods', filename=None, background_class=None,
                                as_percent=True):
    """Draw an event-based comparison between methods on the precision metric.

    Args:
        methods (:obj:`list` of :obj:`str`): List of names of different methods to be plotted.
        classes (:obj:`list` of :obj:`str`): List of target classes.
        class_colors (:obj:`list` of :obj:`str`): List of colors for the corresponding classes in the
            ``classes`` list.
        precision_arrays (:obj:`list` of :obj:`numpy.ndarray`): List of precision arrays, one per method.
        group_by (:obj:`str`): Group the bars by 'methods' first or by 'classes' first. Defaults to 'methods'.
        filename (:obj:`str`): The filename to save the plot. None to display on screen with pyplot.
        background_class (:obj:`str`): Background class. The statistics of ``background_class`` are omitted
            from the plot.
        as_percent (:obj:`bool`): Whether or not to convert the accumulated values to percentages.
    """
    _compare_per_class_metrics(methods, classes, class_colors, precision_arrays,
                               group_by=group_by, filename=filename, background_class=background_class,
                               as_percent=as_percent, metric_name='precision')
233
234
235
def draw_timeliness_hist(classes, class_colors, truth, prediction, truth_scoring, prediction_scoring, time_list,
                         background_class):
    """Show one histogram per class of the start-boundary timing mismatch.

    Args:
        classes (:obj:`list` of :obj:`str`): List of target classes.
        class_colors (:obj:`list` of :obj:`str`): Colors for the corresponding classes.
        truth: Ground-truth class index per sample.
        prediction: Predicted class index per sample.
        truth_scoring: Per-sample scoring codes of the ground truth.
        prediction_scoring: Per-sample scoring codes of the prediction.
        time_list: Per-sample time stamps.
        background_class (:obj:`str`): Class to omit from the plot, or None to
            plot every class.
    """
    # Only the start-boundary mismatches are drawn; the stop-boundary list is
    # computed by the helper but not used here.
    start_mismatch, _ = _get_timeliness_measures(classes, truth, prediction,
                                                 truth_scoring, prediction_scoring, time_list)
    bg_id = _get_bg_class_id(classes, background_class)
    stack_to_plot = []
    stack_of_colors = []
    stack_of_labels = []
    for i in range(len(classes)):
        if i != bg_id:
            stack_to_plot.append(start_mismatch[i])
            stack_of_colors.append(class_colors[i])
            stack_of_labels.append(classes[i])
    # Size everything by the classes actually kept, so no class is dropped
    # when there is no background class to omit.
    num_plots = len(stack_to_plot)
    bins = np.linspace(-300, 300, 100)
    plt.figure()
    patches = [Rectangle((0, 0), 0, 0, color=color) for color in stack_of_colors]
    for i in range(num_plots):
        plt.subplot(num_plots, 1, i + 1)
        plt.hist(stack_to_plot[i], bins=bins, alpha=0.7, color=stack_of_colors[i], label=stack_of_labels[i], lw=0)
    plt.legend(patches, stack_of_labels, loc='center left', borderaxespad=0, bbox_to_anchor=(1.05, 0.5),
               prop={'size': 8})
    plt.show()
265
266
267
def _get_timeliness_measures(classes, truth, prediction, truth_scoring, prediction_scoring, time_list):
    """Measure how far predicted event boundaries are from the true ones.

    The sequence is scanned for positions where the ground-truth label
    changes. At each such boundary the extent of the adjacent overfill
    ('O' / 'o' in ``prediction_scoring``) or underfill ('U' / 'u' in
    ``truth_scoring``) run is converted into a time offset in seconds.

    Args:
        classes: List of target classes; sets the number of result lists.
        truth: Array of ground-truth class indices (needs ``shape[0]``).
        prediction: Array of predicted class indices.
        truth_scoring: Per-sample scoring codes of the ground truth.
        prediction_scoring: Per-sample scoring codes of the prediction.
        time_list: Per-sample time stamps; differences must provide
            ``total_seconds()``.

    Returns:
        tuple: ``(start_mismatch, stop_mismatch)``, each a list with one
        sorted list of offsets (seconds) per class.
    """
    num_classes = len(classes)
    num_samples = truth.shape[0]
    start_mismatch = [[] for _ in range(num_classes)]
    stop_mismatch = [[] for _ in range(num_classes)]
    prev_truth = -1
    for i in range(num_samples):
        cur_truth = int(truth[i])
        # Overfill/underfill only occur at the boundary of an activity event,
        # so only positions where the truth label changes are examined.
        if cur_truth != prev_truth:
            truth_time = time_list[i]
            # --- Start boundary of the event that begins at i ---
            if truth[i] == prediction[i]:
                # Prediction is correct here, so the only possible error is an
                # overfill that started before i; walk back over it. Without
                # overfill the recorded offset is zero.
                j = i - 1
                while j >= 0 and prediction_scoring[j] == 'O':
                    j -= 1
                start_mismatch[cur_truth].append((time_list[j + 1] - truth_time).total_seconds())
            else:
                # Prediction is wrong here, so the only possible error is an
                # underfill at the start; walk forward over it.
                j = i
                while j < num_samples and truth_scoring[j] == 'U':
                    j += 1
                if j != i and j < num_samples:
                    start_mismatch[cur_truth].append((time_list[j - 1] - truth_time).total_seconds())
            # --- Stop boundary of the event that ended at i - 1 ---
            if i > 0:
                if prediction[i - 1] == truth[i - 1]:
                    # Previous prediction is correct, so look for an overfill
                    # of the previous event extending past the boundary. The
                    # bound on j keeps the scan inside the sequence when the
                    # overfill runs to the very end.
                    j = i
                    while j < num_samples and prediction_scoring[j] == 'o':
                        j += 1
                    stop_mismatch[prev_truth].append((time_list[j - 1] - truth_time).total_seconds())
                else:
                    # Look for an underfill of the previous event before the boundary.
                    j = i - 1
                    while j >= 0 and truth_scoring[j] == 'u':
                        j -= 1
                    if j != i - 1:
                        stop_mismatch[prev_truth].append((time_list[j + 1] - truth_time).total_seconds())
            # Flag offsets whose magnitude exceeds 1800 seconds.
            if prev_truth != -1:
                if len(stop_mismatch[prev_truth]) > 0 and abs(stop_mismatch[prev_truth][-1]) > 1800:
                    logger.warning('Stop mismatch is over half an hour: %s at %d (%s) - %f',
                                   classes[prev_truth], i, time_list[i],
                                   stop_mismatch[prev_truth][-1])
                if len(start_mismatch[cur_truth]) > 0 and abs(start_mismatch[cur_truth][-1]) > 1800:
                    logger.warning('Start mismatch is over half an hour: %s at %d (%s) - %f',
                                   classes[cur_truth], i, time_list[i],
                                   start_mismatch[cur_truth][-1])
        prev_truth = cur_truth
    # Sort every per-class list in place.
    for i in range(num_classes):
        start_mismatch[i].sort()
        stop_mismatch[i].sort()
    return start_mismatch, stop_mismatch
328