_compare_per_class_metrics()   F
last analyzed

Complexity

Conditions 28

Size

Total Lines 112

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 28
c 1
b 0
f 0
dl 0
loc 112
rs 2

How to fix   Long Method    Complexity   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

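Commonly applied refactorings include Extract Method and, when many parameters or temporary variables are present, Replace Temp with Query, Introduce Parameter Object, and Preserve Whole Object.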

Complexity

Complex functions like _compare_per_class_metrics() often do a lot of different things. To break such a function down, we need to identify a cohesive group of statements within it. A common approach to find such a group is to look for variables and helper calls that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
import sys
2
import logging
3
import numpy as np
4
import matplotlib.pyplot as plt
5
from matplotlib.ticker import MultipleLocator
6
from matplotlib.patches import Rectangle
7
8
logger = logging.getLogger(__name__)
9
10
# Display names of the recall score columns, in column order (used as legend text).
recall_scoring_labels = ['Correct', 'Fragmenting', 'Underfill-B', 'Underfill-E', 'Deletion']
11
# Display names of the score columns used for every metric other than 'recall'
# (see _get_metric_label_dict), in column order.
fpr_scoring_labels = ['Correct', 'Merging', 'Overfill-B', 'Overfill-E', 'Insertion']
12
# Maps each key of a per-class recall score row to its column index; the index
# lines up with recall_scoring_labels (e.g. 'C' -> 0 -> 'Correct').
recall_scoring_indices = {'C': 0, 'D': 4, 'F': 1, 'U': 2, 'u': 3}
13
# Same mapping for the non-recall score rows; the index lines up with
# fpr_scoring_labels (e.g. 'I' -> 4 -> 'Insertion').
fpr_scoring_indices = {'C': 0, 'I': 4, 'M': 1, 'O': 2, 'o': 3}
14
15
16
def draw_per_class_recall(classes, class_colors, recall_array, filename=None):
    """Draw a stacked bar chart of the per-class recall breakdown.

    Every class gets one bar; the bar is split into the categories listed in
    ``recall_scoring_labels`` and each row is normalised so the bar sums to 1.

    Args:
        classes (:obj:`list` of :obj:`str`): Class names, one bar per class.
        class_colors (:obj:`list`): Colors handed to ``plt.bar``, one per class.
        recall_array: Iterable of per-class rows, each indexable by the keys of
            ``recall_scoring_indices``.
        filename (:obj:`str`): File to save the plot to. None displays the plot
            on screen with pyplot.
    """
    num_labels = len(recall_scoring_labels)
    # The builtin float replaces the np.float alias, which was removed in NumPy 1.24.
    recall_np = np.empty((len(classes), num_labels), dtype=float)
    for i, row in enumerate(recall_array):
        for key, column in recall_scoring_indices.items():
            recall_np[i, column] = row[key]
    # NOTE(review): a class whose counts are all zero yields NaN here.
    recall_np /= np.sum(recall_np, axis=1, keepdims=True)

    ind = np.arange(len(classes))
    width = 0.35
    bottom = np.zeros((len(classes),))
    bar_array = []
    for i in range(num_labels):
        # Later categories are drawn progressively more transparent.
        bars = plt.bar(ind, recall_np[:, i], width,
                       alpha=(1 - 1 / num_labels * i),
                       color=class_colors, bottom=bottom)
        bar_array.append(bars[0])
        bottom += recall_np[:, i]
    plt.ylabel('Percentage')
    plt.xlabel('Classes')
    plt.xticks(ind, classes, rotation='vertical')
    plt.legend(bar_array, recall_scoring_labels)
    # Honour ``filename`` the same way _compare_per_class_metrics does.
    if filename is None:
        plt.show()
    else:
        plt.savefig(filename, bbox_inches='tight')
40
41
42
def _get_bg_class_id(classes, background_class):
43
    # Verify Background Class first
44
    if background_class is not None:
45
        bg_class_id = classes.index(background_class)
46
    else:
47
        bg_class_id = -1
48
    return bg_class_id
49
50
51
def _get_metric_label_dict(metric_name='recall'):
    """Return the ``(labels, indices)`` pair for the requested metric.

    ``'recall'`` selects the recall tables; any other name selects the
    ``fpr_*`` tables.
    """
    if metric_name == 'recall':
        return recall_scoring_labels, recall_scoring_indices
    return fpr_scoring_labels, fpr_scoring_indices
59
60
61
def _gether_per_class_metrics(methods, classes, metric_arrays, as_percent, metric_labels, metric_indices):
62
    """Prepare metrics for bar plot
63
    """
64
    # Gather data for bar plot
65
    plot_metric_arrays = []
66
    for j in range(len(methods)):
67
        cur_metric = np.empty((len(classes), len(metric_labels)),
68
                              dtype=np.float)
69
        for i, row in enumerate(metric_arrays[j]):
70
            for key in metric_indices:
71
                cur_metric[i, metric_indices[key]] = row[key]
72
        # As percent
73
        if as_percent:
74
            cur_metric /= np.sum(cur_metric, axis=1, keepdims=True)
75
        # Append the metric for current methods
76
        plot_metric_arrays.append(cur_metric)
77
    return plot_metric_arrays
78
79
80
def _compare_per_class_metrics(methods, classes, class_colors, metric_arrays,
                               group_by='methods', filename=None, background_class=None,
                               as_percent=True, metric_name='recall'):
    """Compare per-class metrics between methods using a stacked bar graph.

    Args:
        methods (:obj:`list` of :obj:`str`): Names of the methods to plot.
        classes (:obj:`list` of :obj:`str`): Target classes.
        class_colors (:obj:`list`): Color of each class, parallel to ``classes``.
        metric_arrays (:obj:`list`): Per-method score rows, as accepted by
            ``_gether_per_class_metrics``.
        group_by (:obj:`str`): ``'methods'`` places the bars of all methods for
            one class next to each other; any other value places the bars of
            all classes for one method next to each other.
        filename (:obj:`str`): File to save the plot to. None displays it.
        background_class (:obj:`str`): Class left out of the plot, or None.
        as_percent (:obj:`bool`): Plot row-normalised values instead of counts.
        metric_name (:obj:`str`): ``'recall'`` or another metric name; selects
            the label set through ``_get_metric_label_dict``.
    """
    metric_labels, metric_indices = _get_metric_label_dict(metric_name=metric_name)
    bg_class_id = _get_bg_class_id(classes, background_class)
    plot_metric_arrays = _gether_per_class_metrics(methods, classes, metric_arrays, as_percent,
                                                   metric_labels, metric_indices)
    num_labels = len(metric_labels)
    num_methods = len(methods)
    # Class indices that are actually drawn (the background class is skipped).
    kept_classes = [c for c in range(len(classes)) if c != bg_class_id]
    num_kept = len(kept_classes)

    # One row of plot_data per bar, in left-to-right drawing order.
    xtick_labels = []
    bar_colors = []
    plot_data = np.empty((num_methods * num_kept, num_labels))
    if group_by == 'methods':
        num_base_axis = num_methods
        num_sec_axis = num_kept
        for slot, class_id in enumerate(kept_classes):
            for method_id, method in enumerate(methods):
                bar_colors.append(class_colors[class_id])
                xtick_labels.append(method)
                plot_data[slot * num_base_axis + method_id, :] = \
                    plot_metric_arrays[method_id][class_id, :]
    else:
        num_base_axis = num_kept
        num_sec_axis = num_methods
        for method_id, method in enumerate(methods):
            xtick_labels.append(method)
            for slot, class_id in enumerate(kept_classes):
                bar_colors.append(class_colors[class_id])
                plot_data[method_id * num_base_axis + slot, :] = \
                    plot_metric_arrays[method_id][class_id, :]

    # Each group occupies one unit on the x axis; one bar-width gap is left
    # at the start of every group.
    width = 1 / (num_base_axis + 1)
    ind = [group + pos * width + width
           for group in range(num_sec_axis)
           for pos in range(num_base_axis)]
    bottom = np.zeros((num_base_axis * num_sec_axis,))

    # Grid spacing of the y axis.
    if as_percent:
        minor_locator_value = 0.05
        major_locator_value = 0.2
    else:
        max_value = np.max(plot_data.sum(axis=1)) + 20
        minor_locator_value = int(max_value / 20)
        major_locator_value = int(max_value / 5)

    # Tick positions: one per bar when grouped by methods, else one per group.
    if group_by == 'methods':
        xlabel_ind = [x + width / 2 for x in ind]
        xlabel_rotation = 'vertical'
    else:
        xlabel_ind = [x + 0.5 for x in range(num_methods)]
        xlabel_rotation = 'horizontal'

    fig, ax = plt.subplots()
    ax.yaxis.set_minor_locator(MultipleLocator(minor_locator_value))
    ax.yaxis.set_major_locator(MultipleLocator(major_locator_value))
    ax.yaxis.grid(which="major", color='0.65', linestyle='-', linewidth=1)
    ax.yaxis.grid(which="minor", color='0.45', linestyle=' ', linewidth=1)

    # Stack the score categories; later categories are more transparent.
    for i in range(num_labels):
        ax.bar(ind, plot_data[:, i], width,
               alpha=(1 - 1 / num_labels * i),
               color=bar_colors, bottom=bottom)
        bottom += plot_data[:, i]
    plt.ylabel('Percentage' if as_percent else 'Count')
    plt.xlabel('Classes')
    plt.xticks(xlabel_ind, xtick_labels, rotation=xlabel_rotation, fontsize=6)

    # Legend: grey swatches explain the transparency levels, colored swatches
    # identify the classes.
    patches = []
    legend_labels = []
    for i in range(num_labels):
        patches.append(Rectangle((0, 0), 0, 0, color='0.3', alpha=(1 - 1 / num_labels * i)))
        legend_labels.append(metric_labels[i])
    for class_id in kept_classes:
        patches.append(Rectangle((0, 0), 0, 0, color=class_colors[class_id]))
        legend_labels.append(classes[class_id])
    plt.legend(patches, legend_labels, loc='center left', borderaxespad=0, bbox_to_anchor=(1.05, 0.5),
               prop={'size': 8})
    plt.tight_layout()
    plt.title('Event-based Activity Analysis - %s' % metric_name)
    if filename is None:
        plt.show()
    else:
        plt.savefig(filename, bbox_inches='tight')
        # pyplot keeps every figure alive until it is closed; release it so
        # repeated calls do not accumulate figures.
        plt.close(fig)
192
193
194
def compare_per_class_recall(methods, classes, class_colors, recall_arrays,
                             group_by='methods', filename=None, background_class=None,
                             as_percent=True):
    """Draw an event-based comparison between methods on the recall metric.

    Args:
        methods (:obj:`list` of :obj:`str`): Names of the methods to be plotted.
        classes (:obj:`list` of :obj:`str`): Target classes.
        class_colors (:obj:`list` of :obj:`str`): Color of each class, in the same order as ``classes``.
        recall_arrays (:obj:`list`): Recall scores, one entry per method.
        group_by (:obj:`str`): Group the bars by 'methods' first or by 'classes' first. Defaults to 'methods'.
        filename (:obj:`str`): File to save the plot to. None displays the plot on screen with pyplot.
        background_class (:obj:`str`): Class whose statistics are omitted from the plot, or None.
        as_percent (:obj:`bool`): Whether to convert the accumulated values to percentages.
    """
    _compare_per_class_metrics(
        methods, classes, class_colors, recall_arrays,
        metric_name='recall',
        group_by=group_by,
        filename=filename,
        background_class=background_class,
        as_percent=as_percent,
    )
213
214
215
def compare_per_class_precision(methods, classes, class_colors, precision_arrays,
                                group_by='methods', filename=None, background_class=None,
                                as_percent=True):
    """Draw an event-based comparison between methods on the precision metric.

    Args:
        methods (:obj:`list` of :obj:`str`): Names of the methods to be plotted.
        classes (:obj:`list` of :obj:`str`): Target classes.
        class_colors (:obj:`list` of :obj:`str`): Color of each class, in the same order as ``classes``.
        precision_arrays (:obj:`list`): Precision scores, one entry per method.
        group_by (:obj:`str`): Group the bars by 'methods' first or by 'classes' first. Defaults to 'methods'.
        filename (:obj:`str`): File to save the plot to. None displays the plot on screen with pyplot.
        background_class (:obj:`str`): Class whose statistics are omitted from the plot, or None.
        as_percent (:obj:`bool`): Whether to convert the accumulated values to percentages.
    """
    _compare_per_class_metrics(
        methods, classes, class_colors, precision_arrays,
        metric_name='precision',
        group_by=group_by,
        filename=filename,
        background_class=background_class,
        as_percent=as_percent,
    )
234
235
236
def draw_timeliness_hist(classes, class_colors, truth, prediction, truth_scoring, prediction_scoring, time_list,
                         background_class):
    """Plot one histogram of start-boundary offsets per non-background class.

    Args:
        classes (:obj:`list` of :obj:`str`): Target classes.
        class_colors (:obj:`list`): Color of each class, parallel to ``classes``.
        truth: Ground-truth label sequence handed to ``_get_timeliness_measures``.
        prediction: Predicted label sequence handed to ``_get_timeliness_measures``.
        truth_scoring: Unused; kept for interface compatibility.
        prediction_scoring: Unused; kept for interface compatibility.
        time_list: Time stamp of every sample.
        background_class (:obj:`str`): Class left out of the plot, or None.
    """
    # Only the start offsets are plotted; the stop offsets are discarded.
    start_mismatch, _ = _get_timeliness_measures(classes, truth, prediction, time_list)
    bg_id = _get_bg_class_id(classes, background_class)

    stack_to_plot = []
    stack_of_colors = []
    stack_of_labels = []
    for i, label in enumerate(classes):
        if i != bg_id:
            stack_to_plot.append(start_mismatch[i])
            stack_of_colors.append(class_colors[i])
            stack_of_labels.append(label)

    # Use the number of classes actually collected: the former hard-coded
    # ``num_classes - 1`` dropped the last class when there was no background.
    num_plots = len(stack_to_plot)
    bins = np.linspace(-300, 300, 100)
    plt.figure()
    patches = [Rectangle((0, 0), 0, 0, color=color) for color in stack_of_colors]
    for i in range(num_plots):
        plt.subplot(num_plots, 1, i + 1)
        plt.hist(stack_to_plot[i], bins=bins, alpha=0.7, color=stack_of_colors[i], label=stack_of_labels[i], lw=0)
    plt.legend(patches, stack_of_labels, loc='center left', borderaxespad=0, bbox_to_anchor=(1.05, 0.5),
               prop={'size': 8})
    plt.show()
266
267
268
def _find_overlap_seg(seg_list, id):
269
    for seg_id in range(len(seg_list)):
270
        if seg_list[seg_id][1] < id:
271
            continue
272
        elif seg_list[seg_id][0] > id:
273
            return -1
274
        else:
275
            return seg_id
276
    return -1
277
278
279
def _find_seg_start_within(seg_list, start, stop):
280
    for seg_id in range(len(seg_list)):
281
        if seg_list[seg_id][0] < start:
282
            continue
283
        elif seg_list[seg_id][0] > stop:
284
            return -1
285
        else:
286
            return seg_id
287
    return -1
288
289
290
def _find_seg_end_within(seg_list, start, stop):
291
    found_seg_id = -1
292
    for seg_id in range(len(seg_list)):
293
        if seg_list[seg_id][1] < start:
294
            continue
295
        elif seg_list[seg_id][0] > stop:
296
            return found_seg_id
297
        else:
298
            found_seg_id = seg_id
299
    return found_seg_id
300
301
302 View Code Duplication
def _get_timeoffset_measures(classes, truth, prediction, time_list):
    """Collect start/stop boundary time offsets between truth and prediction.

    For every class the truth and prediction label sequences are cut into
    ``(start, stop)`` index segments, then each truth segment boundary is
    matched against the prediction segments.

    Args:
        classes: Sequence of classes; labels are compared with the class index.
        truth: Ground-truth label per sample (needs ``shape[0]`` and indexing).
        prediction: Predicted label per sample.
        time_list: Time stamp per sample; differences must provide
            ``total_seconds()``.

    Returns:
        tuple: ``(start_mismatch, stop_mismatch)``, each a list with one list
        of offsets in seconds per class. Offsets of 18000 s or more in
        magnitude are discarded.
    """
    num_classes = len(classes)
    start_mismatch = [[] for _ in range(num_classes)]
    stop_mismatch = [[] for _ in range(num_classes)]
    for j in range(num_classes):
        pred_segs = []
        truth_segs = []
        prev_pred = False
        prev_truth = False
        tseg_start = 0
        tseg_stop = 0
        pseg_start = 0
        pseg_stop = 0
        # Cut both sequences into segments of class j.
        # NOTE(review): a segment covering only sample 0 is dropped by the
        # ``!= 0`` guard, and a segment still open at the last sample is never
        # appended — confirm whether that is intended.
        for i in range(truth.shape[0]):
            cur_truth = (int(truth[i]) == j)
            cur_pred = (int(prediction[i]) == j)
            if cur_truth != prev_truth:
                if cur_truth:
                    tseg_start = i
                elif tseg_stop != 0:
                    truth_segs.append((tseg_start, tseg_stop))
            tseg_stop = i
            if cur_pred != prev_pred:
                if cur_pred:
                    pseg_start = i
                elif pseg_stop != 0:
                    pred_segs.append((pseg_start, pseg_stop))
            pseg_stop = i
            prev_truth = cur_truth
            prev_pred = cur_pred
        # Start boundaries.
        for ts, (tseg_start, tseg_stop) in enumerate(truth_segs):
            ps = _find_overlap_seg(pred_segs, tseg_start)
            if ps == -1:
                # Potential underfill or deletion: the prediction starts late.
                ps = _find_seg_start_within(pred_segs, tseg_start, tseg_stop)
                if ps != -1:
                    pseg_start = pred_segs[ps][0]
                    offset = (time_list[tseg_start] - time_list[pseg_start]).total_seconds()
                    if abs(offset) < 18000:
                        start_mismatch[j].append(offset)
            else:
                pseg_start = pred_segs[ps][0]
                # Skip when the prediction also reaches back into the previous
                # truth segment. ``ts > 0`` (was ``ts > 1``) so the second
                # segment is checked against the first one as well.
                if ts > 0 and truth_segs[ts - 1][1] >= pseg_start:
                    continue
                # Overfill: the prediction starts early.
                offset = (time_list[tseg_start] - time_list[pseg_start]).total_seconds()
                if abs(offset) < 18000:
                    start_mismatch[j].append(offset)
        # Stop boundaries.
        for ts, (tseg_start, tseg_stop) in enumerate(truth_segs):
            ps = _find_overlap_seg(pred_segs, tseg_stop)
            if ps == -1:
                # Potential underfill or deletion: the prediction stops early.
                ps = _find_seg_end_within(pred_segs, tseg_start, tseg_stop)
                if ps != -1:
                    pseg_stop = pred_segs[ps][1]
                    offset = (time_list[tseg_stop] - time_list[pseg_stop]).total_seconds()
                    if tseg_stop != pseg_stop and abs(offset) < 18000:
                        stop_mismatch[j].append(offset)
            else:
                pseg_stop = pred_segs[ps][1]
                # Skip when the prediction also reaches into the next truth
                # segment. The guard protects index ``ts + 1``; the former
                # ``ts - 1`` compared against the previous segment instead.
                if ts < len(truth_segs) - 1 and truth_segs[ts + 1][0] <= pseg_stop:
                    continue
                # Overfill: the prediction stops late.
                offset = (time_list[tseg_stop] - time_list[pseg_stop]).total_seconds()
                if abs(offset) < 18000:
                    stop_mismatch[j].append(offset)
    return start_mismatch, stop_mismatch
382
383
384 View Code Duplication
def _get_timeliness_measures(classes, truth, prediction, time_list):
    """Collect non-zero start/stop boundary time offsets between truth and prediction.

    For every class the truth and prediction label sequences are cut into
    ``(start, stop)`` index segments, then each truth segment boundary is
    matched against the prediction segments. Boundaries that coincide exactly
    are not recorded.

    Args:
        classes: Sequence of classes; labels are compared with the class index.
        truth: Ground-truth label per sample (needs ``shape[0]`` and indexing).
        prediction: Predicted label per sample.
        time_list: Time stamp per sample; differences must provide
            ``total_seconds()``.

    Returns:
        tuple: ``(start_mismatch, stop_mismatch)``, each a list with one list
        of offsets in seconds per class. Offsets of 18000 s or more in
        magnitude are discarded.
    """
    num_classes = len(classes)
    start_mismatch = [[] for _ in range(num_classes)]
    stop_mismatch = [[] for _ in range(num_classes)]
    for j in range(num_classes):
        pred_segs = []
        truth_segs = []
        prev_pred = False
        prev_truth = False
        tseg_start = 0
        tseg_stop = 0
        pseg_start = 0
        pseg_stop = 0
        # Cut both sequences into segments of class j.
        # NOTE(review): a segment covering only sample 0 is dropped by the
        # ``!= 0`` guard, and a segment still open at the last sample is never
        # appended — confirm whether that is intended.
        for i in range(truth.shape[0]):
            cur_truth = (int(truth[i]) == j)
            cur_pred = (int(prediction[i]) == j)
            if cur_truth != prev_truth:
                if cur_truth:
                    tseg_start = i
                elif tseg_stop != 0:
                    truth_segs.append((tseg_start, tseg_stop))
            tseg_stop = i
            if cur_pred != prev_pred:
                if cur_pred:
                    pseg_start = i
                elif pseg_stop != 0:
                    pred_segs.append((pseg_start, pseg_stop))
            pseg_stop = i
            prev_truth = cur_truth
            prev_pred = cur_pred
        # Start boundaries.
        for ts, (tseg_start, tseg_stop) in enumerate(truth_segs):
            ps = _find_overlap_seg(pred_segs, tseg_start)
            if ps == -1:
                # Potential underfill or deletion: the prediction starts late.
                ps = _find_seg_start_within(pred_segs, tseg_start, tseg_stop)
                if ps != -1:
                    pseg_start = pred_segs[ps][0]
                    offset = (time_list[tseg_start] - time_list[pseg_start]).total_seconds()
                    if tseg_start != pseg_start and abs(offset) < 18000:
                        start_mismatch[j].append(offset)
            else:
                pseg_start = pred_segs[ps][0]
                # Skip when the prediction also reaches back into the previous
                # truth segment. ``ts > 0`` (was ``ts > 1``) so the second
                # segment is checked against the first one as well.
                if ts > 0 and truth_segs[ts - 1][1] >= pseg_start:
                    continue
                # Overfill: the prediction starts early.
                offset = (time_list[tseg_start] - time_list[pseg_start]).total_seconds()
                if tseg_start != pseg_start and abs(offset) < 18000:
                    start_mismatch[j].append(offset)
        # Stop boundaries.
        for ts, (tseg_start, tseg_stop) in enumerate(truth_segs):
            ps = _find_overlap_seg(pred_segs, tseg_stop)
            if ps == -1:
                # Potential underfill or deletion: the prediction stops early.
                ps = _find_seg_end_within(pred_segs, tseg_start, tseg_stop)
                if ps != -1:
                    pseg_stop = pred_segs[ps][1]
                    offset = (time_list[tseg_stop] - time_list[pseg_stop]).total_seconds()
                    if tseg_stop != pseg_stop and abs(offset) < 18000:
                        stop_mismatch[j].append(offset)
            else:
                pseg_stop = pred_segs[ps][1]
                # Skip when the prediction also reaches into the next truth
                # segment. The guard protects index ``ts + 1``; the former
                # ``ts - 1`` compared against the previous segment instead.
                if ts < len(truth_segs) - 1 and truth_segs[ts + 1][0] <= pseg_stop:
                    continue
                # Overfill: the prediction stops late.
                offset = (time_list[tseg_stop] - time_list[pseg_stop]).total_seconds()
                if tseg_stop != pseg_stop and abs(offset) < 18000:
                    stop_mismatch[j].append(offset)
    return start_mismatch, stop_mismatch
464
465
466
def _get_timeliness_measures_depricated(classes, truth, prediction, truth_scoring, prediction_scoring, time_list):
467
    num_classes = len(classes)
468
    start_mismatch = [list([]) for i in range(num_classes)]
469
    stop_mismatch = [list([]) for i in range(num_classes)]
470
    # For each Underfill, Overfill
471
    prev_truth = -1
472
    for i in range(truth.shape[0]):
473
        cur_truth = int(truth[i])
474
        # Overfill/Underfill only occur at the boundary of any activity event, so look for the boundary first
475
        if cur_truth != prev_truth:
476
            truth_time = time_list[i]
477
            # Check the start boundary
478
            if truth[i] == prediction[i]:
479
                # If current prediction is correct, then it can only be overfill of current truth label.
480
                j = i - 1
481
                while j >= 0 and prediction_scoring[j] == 'O':
482
                    j -= 1
483
                # If there is no overfill for cur_truth, and the current truth and prediction are the same,
484
                # then there is no start_boundary mismatch.
485
                start_mismatch[cur_truth].append((time_list[j + 1] - truth_time).total_seconds())
486
            else:
487
                # If current prediction is incorrect, then it can only be underfill of current truth label at start
488
                # boundary.
489
                j = i
490
                while j < truth.shape[0] and truth_scoring[j] == 'U':
491
                    j += 1
492
                if j != i and j < truth.shape[0]:
493
                    start_mismatch[cur_truth].append((time_list[j-1] - truth_time).total_seconds())
494
            # Check the stop boundary
495
            if i > 0:
496
                if prediction[i-1] == truth[i-1]:
497
                    # Previous prediction is correct, then it can only be overfill of previous truth.
498
                    # If there is no overfill, the stop boundary is accurate
499
                    j = i
500
                    while prediction_scoring[j] == 'o':
501
                        j += 1
502
                    stop_mismatch[prev_truth].append((time_list[j-1] - truth_time).total_seconds())
503
                else:
504
                    # Check Underfill for prev_truth (at the stop boundary)
505
                    j = i - 1
506
                    while j >= 0 and truth_scoring[j] == 'u':
507
                        j -= 1
508
                    if j != i - 1:
509
                        stop_mismatch[prev_truth].append((time_list[j + 1] - truth_time).total_seconds())
510
            if prev_truth != -1:
511
                if len(stop_mismatch[prev_truth]) > 0 and abs(stop_mismatch[prev_truth][-1]) > 1800:
512
                    logger.warning('Stop mismatch is over half an hour: %s at %d (%s) - %f' %
513
                                   (classes[prev_truth], i, time_list[i],
514
                                    stop_mismatch[prev_truth][-1]))
515
                if len(start_mismatch[cur_truth]) > 0 and abs(start_mismatch[cur_truth][-1]) > 1800:
516
                    logger.warning('Start mismatch is over half an hour: %s at %d (%s) - %f' %
517
                                   (classes[cur_truth], i, time_list[i],
518
                                    start_mismatch[cur_truth][-1]))
519
        # Update prev truth
520
        prev_truth = cur_truth
521
    # Sort all arrays
522
    for i in range(num_classes):
523
        start_mismatch[i].sort()
524
        stop_mismatch[i].sort()
525
    # Return
526
    return start_mismatch, stop_mismatch
527
528
529
def generate_latex_table(methods, classes, recall_metrics, precision_matrics,
                          background_class=None, filename=None,
                          as_percent=True, metric_name='recall'):
    """Write a LaTeX table body with recall and precision counts per method and class.

    Args:
        methods (:obj:`list` of :obj:`str`): Method names, one table section each.
        classes (:obj:`list` of :obj:`str`): Target classes, one row each.
        recall_metrics: Per-method recall score rows.
        precision_matrics: Per-method precision score rows.
        background_class (:obj:`str`): Only validated against ``classes``; the
            background class still gets a row.
        filename (:obj:`str`): File to write. None writes to ``sys.stdout``.
        as_percent: Unused; kept for interface compatibility.
        metric_name: Unused; kept for interface compatibility.

    Raises:
        ValueError: If ``background_class`` is given but not in ``classes``.
    """
    # Kept for its validation side effect; the index itself is not needed.
    _get_bg_class_id(classes, background_class)
    metric_labels, metric_indices = _get_metric_label_dict(metric_name='recall')
    rmp = _gether_per_class_metrics(methods, classes, recall_metrics, True,
                                    metric_labels, metric_indices)
    rmr = _gether_per_class_metrics(methods, classes, recall_metrics, False,
                                    metric_labels, metric_indices)
    metric_labels, metric_indices = _get_metric_label_dict(metric_name='precision')
    pmp = _gether_per_class_metrics(methods, classes, precision_matrics, True,
                                    metric_labels, metric_indices)
    pmr = _gether_per_class_metrics(methods, classes, precision_matrics, False,
                                    metric_labels, metric_indices)

    row_format = ('%s & '
                  '%d & %d (%.2f) & %d (%.2f)  & '
                  '%d & %d (%.2f) & %d (%.2f)  \\\\ \n')
    chunks = [
        '\\multirow{2}{*}{Models} & \\multirow{2}{*}{Activities} & '
        '\\multirow{2}{*}{Total Truth} & \\multicolumn{2}{|c|}{Recall} & '
        '\\multirow{2}{*}{Total Prediction} & \\multicolumn{2}{|c|}{Precision}  \\\\ \\hline\n',
        '& & & C only & U included & & C only & O included \\\\ \\hline \n',
    ]
    for i, method in enumerate(methods):
        chunks.append('\\multirow{%d}{*}{%s} & ' % (len(classes), method.replace('_', '\\_')))
        for j, target in enumerate(classes):
            if j != 0:
                chunks.append('& ')
            # NOTE(review): the "included" figures add columns 0, 1 and 2 of
            # each score array — confirm these are the intended columns.
            chunks.append(row_format % (
                target.replace('_', '\\_'),
                rmr[i][j, :].sum(), rmr[i][j, 0], rmp[i][j, 0],
                rmr[i][j, 0] + rmr[i][j, 1] + rmr[i][j, 2], rmp[i][j, 0] + rmp[i][j, 1] + rmp[i][j, 2],
                pmr[i][j, :].sum(), pmr[i][j, 0], pmp[i][j, 0],
                pmr[i][j, 0] + pmr[i][j, 1] + pmr[i][j, 2], pmp[i][j, 0] + pmp[i][j, 1] + pmp[i][j, 2],
            ))
        chunks.append('\\hline\n')

    table = ''.join(chunks)
    if filename is None:
        # Do not close sys.stdout: later prints in the process would fail.
        sys.stdout.write(table)
    else:
        with open(filename, 'w') as f:
            f.write(table)
568
569
570
def _seg_table_top_indices(values):
    """Return the indices of the (up to) two largest entries of ``values``."""
    top = min(2, len(values))
    if top == 0:
        return set()
    return set(np.asarray(values).argpartition(-top)[-top:].tolist())


def _seg_table_section(title, methods, classes, bg_class_id, raw, pct):
    """Build the LaTeX rows of one metric section (one row per non-background class)."""
    rows = [i for i in range(len(classes)) if i != bg_class_id]
    chunks = []
    for position, i in enumerate(rows):
        if position == 0:
            # The section title spans every row that is actually written.
            chunks.append('\\multirow{%d}{*}{%s} & ' % (len(rows), title))
        else:
            chunks.append(' & ')
        chunks.append('%s ' % classes[i].replace('_', '\\_'))
        # The two best methods of the row are set in bold.
        best = _seg_table_top_indices([pct[j][i, 0] for j in range(len(methods))])
        for j in range(len(methods)):
            if j in best:
                chunks.append('& \\textbf{%d/%.2f\\%%} ' % (raw[j][i, 0], pct[j][i, 0] * 100))
            else:
                chunks.append('& %d/%.2f\\%% ' % (raw[j][i, 0], pct[j][i, 0] * 100))
        chunks.append('\\\\ \n')
    chunks.append('\\hline \n')
    return chunks


def generate_seg_latex_table(methods, classes, recall_metrics, precision_matrics,
                             background_class=None, filename=None):
    """Write a LaTeX table comparing column-0 recall and precision across methods.

    Each cell shows ``count/percentage`` taken from column 0 of the raw and
    the row-normalised score arrays; the background class gets no row.

    Args:
        methods (:obj:`list` of :obj:`str`): Method names, one column each.
        classes (:obj:`list` of :obj:`str`): Target classes, one row each.
        recall_metrics: Per-method recall score rows.
        precision_matrics: Per-method precision score rows.
        background_class (:obj:`str`): Class left out of the table, or None.
        filename (:obj:`str`): File to write. None writes to ``sys.stdout``.
    """
    bg_class_id = _get_bg_class_id(classes, background_class)
    metric_labels, metric_indices = _get_metric_label_dict(metric_name='recall')
    rmp = _gether_per_class_metrics(methods, classes, recall_metrics, True,
                                    metric_labels, metric_indices)
    rmr = _gether_per_class_metrics(methods, classes, recall_metrics, False,
                                    metric_labels, metric_indices)
    metric_labels, metric_indices = _get_metric_label_dict(metric_name='precision')
    pmp = _gether_per_class_metrics(methods, classes, precision_matrics, True,
                                    metric_labels, metric_indices)
    pmr = _gether_per_class_metrics(methods, classes, precision_matrics, False,
                                    metric_labels, metric_indices)

    chunks = ['Metric & Activities']
    for method in methods:
        chunks.append('& %s' % method.replace('_', '\\_'))
    chunks.append('\\\\ \\hline \n')
    chunks.extend(_seg_table_section('Recall', methods, classes, bg_class_id, rmr, rmp))
    chunks.extend(_seg_table_section('Precision', methods, classes, bg_class_id, pmr, pmp))

    table = ''.join(chunks)
    if filename is None:
        sys.stdout.write(table)
    else:
        # The context manager closes the file, which was previously left open.
        with open(filename, 'w') as f:
            f.write(table)
625
626
627
def _top_two_indices(values):
    """Return the indices of the (up to) two largest entries of ``values``.

    ``ndarray.argpartition(-2)`` raises when fewer than two entries are
    present, so short inputs simply return every index.
    """
    values = np.asarray(values)
    if values.size < 2:
        return np.arange(values.size)
    return values.argpartition(-2)[-2:]


def generate_event_recall_table(methods, classes, recall_metrics,
                                background_class=None, filename=None):
    """Write LaTeX table rows of per-class event recall for each method.

    One row is emitted per non-background class, followed by a
    "Recall (micro)" summary row. In every row the two best methods are
    typeset in bold.

    Args:
        methods: Sequence of method names (strings); used as column headers.
        classes: Sequence of class/activity names (strings).
        recall_metrics: Per-method recall metrics, forwarded unchanged to
            ``_gether_per_class_metrics``.
        background_class: Forwarded to ``_get_bg_class_id``; the class it
            resolves to is left out of the table.
        filename: Path of the file to write. When ``None`` the rows are
            written to ``sys.stdout``.
    """
    bg_class_id = _get_bg_class_id(classes, background_class)
    metric_labels, metric_indices = _get_metric_label_dict(metric_name='recall')
    # NOTE(review): the True/False flag presumably selects percentages (rmp)
    # versus raw counts (rmr) -- confirm against _gether_per_class_metrics.
    rmp = _gether_per_class_metrics(methods, classes, recall_metrics, True,
                                    metric_labels, metric_indices)
    rmr = _gether_per_class_metrics(methods, classes, recall_metrics, False,
                                    metric_labels, metric_indices)
    method_ids = range(len(methods))

    f = sys.stdout if filename is None else open(filename, 'w')
    try:
        # NOTE(review): data rows start with an extra leading '&' compared
        # to this header row -- confirm the intended column layout.
        f.write('Activities')
        for method in methods:
            f.write('& %s' % method.replace('_', '\\_'))
        f.write('\\\\ \\hline \n')

        for i, activity in enumerate(classes):
            if i == bg_class_id:
                continue
            f.write(' & ')
            f.write('%s ' % activity.replace('_', '\\_'))
            # Column 0 is presumably the 'Correct' score (see the
            # module-level recall_scoring_indices) -- TODO confirm.
            scores = np.array([rmp[j][i, 0] for j in method_ids])
            best = _top_two_indices(scores)
            for j in method_ids:
                if j in best:
                    f.write('& \\textbf{%.2f\\%%} ' % (rmp[j][i, 0] * 100))
                else:
                    f.write('& %.2f\\%% ' % (rmp[j][i, 0] * 100))
            f.write('\\\\ \n')
        f.write('\\hline \n')

        f.write('Recall (micro) &')
        # Totals over all classes except the background class.
        total_correct = np.array(
            [np.sum(rmr[j][:, 0]) - rmr[j][bg_class_id, 0]
             for j in method_ids])
        total_events = np.array(
            [total_correct[j] + np.sum(rmr[j][:, 4]) - rmr[j][bg_class_id, 4]
             for j in method_ids])
        # NOTE(review): bolding ranks methods by raw correct count, not by
        # the printed ratio -- confirm this is intended.
        best = _top_two_indices(total_correct)
        for j in method_ids:
            if j in best:
                f.write('& \\textbf{%.2f\\%%} '
                        % (total_correct[j] / total_events[j] * 100))
            else:
                f.write('& %.2f\\%% '
                        % (total_correct[j] / total_events[j] * 100))
        f.write('\\\\ \n')
        f.write('\\hline \n')
    finally:
        # Only close handles this function opened; never close stdout.
        if f is not sys.stdout:
            f.close()
    logger.debug('Total Events: %s', total_events)
670
671
672
def generate_timeliness_table(methods, classes, result_array,
                              background_class, filename=None):
    """Write LaTeX table rows summarising timeliness per class and method.

    For every non-background class three rows are written: the average
    absolute mismatch, the count/percentage of mismatches of at most 60,
    and the count/percentage of mismatches above 60.

    Args:
        methods: Sequence of method names (strings); used as column headers.
        classes: Sequence of class/activity names (strings).
        result_array: Per-method results; elements 0, 1 and 4 of each entry
            are forwarded to ``_get_timeliness_measures``.
        background_class: Forwarded to ``_get_bg_class_id``; the class it
            resolves to is left out of the table.
        filename: Path of the file to write. When ``None`` the rows are
            written to ``sys.stdout``.
    """
    bg_class_id = _get_bg_class_id(classes, background_class)
    class_ids = range(len(classes))
    timeliness_values = []
    for i, _ in enumerate(methods):
        start_mismatch, stop_mismatch = _get_timeliness_measures(
            classes, result_array[i][0], result_array[i][1],
            result_array[i][4])
        # NOTE(review): '+' presumably concatenates per-class lists of start
        # and stop mismatches -- confirm against _get_timeliness_measures.
        timeliness_values.append(
            [np.abs(np.array(start_mismatch[j] + stop_mismatch[j]))
             for j in class_ids])

    # (row label, True for "<= 60" / False for "> 60", row terminator)
    count_rows = (
        ('<60s ', True, '\\\\ \n'),
        ('>60s ', False, '\\\\ \\hline \n'),
    )

    # Average, <60, >60
    f = sys.stdout if filename is None else open(filename, 'w')
    try:
        f.write('Activities & Metrics ')
        for method in methods:
            f.write('& %s' % method.replace('_', '\\_'))
        f.write('\\\\ \\hline \n')

        for i, activity in enumerate(classes):
            if i == bg_class_id:
                continue
            f.write('\\multirow{3}{*}{%s} & ' % activity.replace('_', '\\_'))
            f.write('Average ')
            for j, _ in enumerate(methods):
                values = timeliness_values[j][i]
                # No samples for this class: report 0 instead of averaging
                # an empty array.
                if len(values) == 0:
                    average_time = 0.
                else:
                    average_time = np.average(values)
                f.write('& %.2f s' % average_time)
            f.write('\\\\ \n')

            for label, within_limit, row_end in count_rows:
                f.write(' & ')
                f.write(label)
                for j, _ in enumerate(methods):
                    values = timeliness_values[j][i]
                    if within_limit:
                        number = (values <= 60).sum()
                    else:
                        number = (values > 60).sum()
                    if len(values) == 0:
                        percentage = 0.0
                    else:
                        percentage = float(number) / len(values) * 100
                    f.write('& %d/%.2f\\%% ' % (number, percentage))
                f.write(row_end)
    finally:
        # Only close handles this function opened; never close stdout.
        if f is not sys.stdout:
            f.close()
722
723
724
def generate_timeliness_within60_table(methods, classes, result_array,
                                       background_class, filename=None):
    """Write LaTeX rows with the share of time offsets of at most 60.

    One row is written per non-background class; each cell is the
    percentage of absolute offsets for that class and method that are
    ``<= 60``.

    Args:
        methods: Sequence of method names (strings); used as column headers.
        classes: Sequence of class/activity names (strings).
        result_array: Per-method results; elements 0, 1 and 4 of each entry
            are forwarded to ``_get_timeoffset_measures``.
        background_class: Forwarded to ``_get_bg_class_id``; the class it
            resolves to is left out of the table.
        filename: Path of the file to write. When ``None`` the rows are
            written to ``sys.stdout``.
    """
    bg_class_id = _get_bg_class_id(classes, background_class)
    class_ids = range(len(classes))
    timeliness_values = []
    for i, _ in enumerate(methods):
        start_mismatch, stop_mismatch = _get_timeoffset_measures(
            classes, result_array[i][0], result_array[i][1],
            result_array[i][4])
        # NOTE(review): '+' presumably concatenates per-class lists of start
        # and stop offsets -- confirm against _get_timeoffset_measures.
        timeliness_values.append(
            [np.abs(np.array(start_mismatch[j] + stop_mismatch[j]))
             for j in class_ids])

    f = sys.stdout if filename is None else open(filename, 'w')
    try:
        f.write('\\textbf{Activities} ')
        for method in methods:
            f.write('& \\textbf{%s} ' % method.replace('_', ' '))
        f.write('\\\\ \\midrule \n')

        for i, activity in enumerate(classes):
            if i == bg_class_id:
                continue
            # Each cell below supplies its own leading '&'; writing another
            # one here would add a column the header row does not have.
            f.write('%s ' % activity.replace('_', ' '))
            for j, _ in enumerate(methods):
                values = timeliness_values[j][i]
                number = (values <= 60).sum()
                if len(values) == 0:
                    percentage = 0.0
                else:
                    percentage = float(number) / len(values) * 100
                f.write('& %.2f\\%% ' % percentage)
            f.write('\\\\ \n')
        f.write('\\bottomrule\n')
    finally:
        # Only close handles this function opened; never close stdout.
        if f is not sys.stdout:
            f.close()
754
755
756 View Code Duplication
def generate_timeliness_avg_table(methods, classes, result_array,
                                  background_class, filename=None):
    """Write LaTeX rows with the average absolute timeliness mismatch.

    One row is written per non-background class; each cell is the mean of
    the absolute mismatches for that class and method, or 0 when there are
    none.

    Args:
        methods: Sequence of method names (strings); used as column headers.
        classes: Sequence of class/activity names (strings).
        result_array: Per-method results; elements 0, 1 and 4 of each entry
            are forwarded to ``_get_timeliness_measures``.
        background_class: Forwarded to ``_get_bg_class_id``; the class it
            resolves to is left out of the table.
        filename: Path of the file to write. When ``None`` the rows are
            written to ``sys.stdout``.
    """
    bg_class_id = _get_bg_class_id(classes, background_class)
    class_ids = range(len(classes))
    timeliness_values = []
    for i, _ in enumerate(methods):
        start_mismatch, stop_mismatch = _get_timeliness_measures(
            classes, result_array[i][0], result_array[i][1],
            result_array[i][4])
        # NOTE(review): '+' presumably concatenates per-class lists of start
        # and stop mismatches -- confirm against _get_timeliness_measures.
        timeliness_values.append(
            [np.abs(np.array(start_mismatch[j] + stop_mismatch[j]))
             for j in class_ids])

    f = sys.stdout if filename is None else open(filename, 'w')
    try:
        f.write('\\textbf{Activities} ')
        for method in methods:
            f.write('& \\textbf{%s} ' % method.replace('_', ' '))
        f.write('\\\\ \\midrule \n')

        for i, activity in enumerate(classes):
            if i == bg_class_id:
                continue
            f.write('%s ' % activity.replace('_', ' '))
            for j, _ in enumerate(methods):
                values = timeliness_values[j][i]
                # No samples for this class: report 0 instead of averaging
                # an empty array.
                if len(values) == 0:
                    average_time = 0.
                else:
                    average_time = np.average(values)
                f.write('& %.1f' % average_time)
            f.write('\\\\ \n')
        f.write('\\bottomrule \n')
    finally:
        # Only close handles this function opened; never close stdout.
        if f is not sys.stdout:
            f.close()
786
787
788 View Code Duplication
def generate_offset_per_table(methods, classes, result_array,
                              background_class, filename=None):
    """Write LaTeX rows with non-zero offset counts per class and method.

    One row is written per non-background class; each cell has the form
    ``<non-zero offsets>/<half the number of offsets>``.

    Args:
        methods: Sequence of method names (strings); used as column headers.
        classes: Sequence of class/activity names (strings).
        result_array: Per-method results; elements 0, 1 and 4 of each entry
            are forwarded to ``_get_timeoffset_measures``.
        background_class: Forwarded to ``_get_bg_class_id``; the class it
            resolves to is left out of the table.
        filename: Path of the file to write. When ``None`` the rows are
            written to ``sys.stdout``.
    """
    bg_class_id = _get_bg_class_id(classes, background_class)
    class_ids = range(len(classes))
    timeliness_values = []
    for i, _ in enumerate(methods):
        start_mismatch, stop_mismatch = _get_timeoffset_measures(
            classes, result_array[i][0], result_array[i][1],
            result_array[i][4])
        # NOTE(review): '+' presumably concatenates per-class lists of start
        # and stop offsets -- confirm against _get_timeoffset_measures.
        timeliness_values.append(
            [np.abs(np.array(start_mismatch[j] + stop_mismatch[j]))
             for j in class_ids])

    f = sys.stdout if filename is None else open(filename, 'w')
    try:
        f.write('\\textbf{Activities} ')
        for method in methods:
            f.write('& \\textbf{%s} ' % method.replace('_', ' '))
        f.write('\\\\ \\midrule \n')

        for i, activity in enumerate(classes):
            if i == bg_class_id:
                continue
            f.write('%s ' % activity.replace('_', ' '))
            for j, _ in enumerate(methods):
                values = timeliness_values[j][i]
                # Integer division keeps the count an int on Python 3; the
                # printed value matches what '%d' produced from the float.
                # NOTE(review): the non-zero count is taken over all
                # offsets, so it may exceed this halved total -- confirm.
                total_num = len(values) // 2
                nonzero_num = np.count_nonzero(values)
                f.write('& %d/%d' % (nonzero_num, total_num))
            f.write('\\\\ \n')
        f.write('\\bottomrule \n')
    finally:
        # Only close handles this function opened; never close stdout.
        if f is not sys.stdout:
            f.close()
816
817