Completed: Push to master ( 71853e...1f70c3 ) by Osma
16s queued 13s

EvaluationBatch._result_per_subject_header()   rating: A

Complexity
    Conditions      1

Size
    Total Lines     11
    Code Lines      11

Duplication
    Lines           0
    Ratio           0 %

Importance
    Changes         0

Metric    Value
eloc      11
dl        0
loc       11
rs        9.85
c         0
b         0
f         0
cc        1
nop       2
"""Evaluation metrics for Annif"""

import collections
import statistics
import warnings
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.metrics import label_ranking_average_precision_score
from annif.exception import NotSupportedException


def filter_pred_top_k(preds, limit):
    """filter a 2D prediction vector, retaining only the top K suggestions
    for each individual prediction; the rest will be set to zeros"""

    masks = []
    for pred in preds:
        mask = np.zeros_like(pred, dtype=bool)
        top_k = np.argsort(pred)[::-1][:limit]
        mask[top_k] = True
        masks.append(mask)
    return preds * np.array(masks)
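# Illustration (hypothetical example, not part of this module): with a 2x4
# score matrix and limit=2, only the two highest scores per row survive and
# the rest are zeroed out:
#   >>> preds = np.array([[0.1, 0.9, 0.3, 0.0], [0.5, 0.2, 0.7, 0.4]])
#   >>> filter_pred_top_k(preds, 2)
#   array([[0. , 0.9, 0.3, 0. ],
#          [0.5, 0. , 0.7, 0. ]])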


def true_positives(y_true, y_pred):
    """calculate the number of true positives using bitwise operations,
    emulating the way sklearn evaluation metric functions work"""
    return (y_true & y_pred).sum()


def false_positives(y_true, y_pred):
    """calculate the number of false positives using bitwise operations,
    emulating the way sklearn evaluation metric functions work"""
    return (~y_true & y_pred).sum()


def false_negatives(y_true, y_pred):
    """calculate the number of false negatives using bitwise operations,
    emulating the way sklearn evaluation metric functions work"""
    return (y_true & ~y_pred).sum()
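# Illustration (hypothetical example, not part of this module): for indicator
# vectors t = np.array([True, False, True, False]) and
# p = np.array([True, True, False, False]), the helpers above give
# true_positives(t, p) == 1, false_positives(t, p) == 1 and
# false_negatives(t, p) == 1.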


def precision_at_k_score(y_true, y_pred, limit):
    """calculate the precision at K, i.e. the fraction of relevant items
    among the top K predicted ones"""
    scores = []
    for true, pred in zip(y_true, y_pred):
        order = pred.argsort()[::-1]
        orderlimit = min(limit, np.count_nonzero(pred))
        order = order[:orderlimit]
        gain = true[order]
        if orderlimit > 0:
            scores.append(gain.sum() / orderlimit)
        else:
            scores.append(0.0)
    return statistics.mean(scores)
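# Illustration (hypothetical example, not part of this module): for a single
# document with y_true = [[1, 0, 1, 0]] and scores
# y_pred = [[0.8, 0.7, 0.1, 0.0]], the two top-scoring suggestions are
# subjects 0 and 1, of which only subject 0 is relevant, so
# precision_at_k_score(np.array(y_true), np.array(y_pred), limit=2) == 0.5.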


def dcg_score(y_true, y_pred, limit=None):
    """return the discounted cumulative gain (DCG) score for the selected
    labels vs. relevant labels"""
    order = y_pred.argsort()[::-1]
    n_pred = np.count_nonzero(y_pred)
    if limit is not None:
        n_pred = min(limit, n_pred)
    order = order[:n_pred]
    gain = y_true[order]
    discount = np.log2(np.arange(order.size) + 2)

    return (gain / discount).sum()
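# Note (hypothetical worked example, not part of this module): the value
# computed above is DCG = sum over ranks i of rel_i / log2(i + 1). For
# y_true = np.array([1, 0, 1, 0]) and y_pred = np.array([0.8, 0.7, 0.1, 0.0])
# (three nonzero scores, no limit):
#   DCG = 1/log2(2) + 0/log2(3) + 1/log2(4) = 1.0 + 0.0 + 0.5 = 1.5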


def ndcg_score(y_true, y_pred, limit=None):
    """return the normalized discounted cumulative gain (nDCG) score for the
    selected labels vs. relevant labels"""
    scores = []
    for true, pred in zip(y_true, y_pred):
        idcg = dcg_score(true, true, limit)
        dcg = dcg_score(true, pred, limit)
        if idcg > 0:
            scores.append(dcg / idcg)
        else:
            scores.append(1.0)  # perfect score for no relevant hits case
    return statistics.mean(scores)
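# Note (hypothetical worked example, not part of this module): continuing the
# example above, the ideal ranking puts both relevant subjects first, so
# IDCG = 1/log2(2) + 1/log2(3) ≈ 1.631 and nDCG = 1.5 / 1.631 ≈ 0.92 for that
# document; the returned score is the mean over all documents.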


class EvaluationBatch:
    """A class for evaluating batches of results using all available metrics.
    The evaluate() method is called once per document in the batch.
    Final results can be queried using the results() method."""

    def __init__(self, subject_index):
        self._subject_index = subject_index
        self._samples = []

    def evaluate(self, hits, gold_subjects):
        self._samples.append((hits, gold_subjects))

    def _evaluate_samples(self, y_true, y_pred, metrics='all'):
        y_pred_binary = y_pred > 0.0
        results = collections.OrderedDict()
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')

            results['Precision (doc avg)'] = precision_score(
                y_true, y_pred_binary, average='samples')
            results['Recall (doc avg)'] = recall_score(
                y_true, y_pred_binary, average='samples')
            results['F1 score (doc avg)'] = f1_score(
                y_true, y_pred_binary, average='samples')
            if metrics == 'all':
                results['Precision (subj avg)'] = precision_score(
                    y_true, y_pred_binary, average='macro')
                results['Recall (subj avg)'] = recall_score(
                    y_true, y_pred_binary, average='macro')
                results['F1 score (subj avg)'] = f1_score(
                    y_true, y_pred_binary, average='macro')
                results['Precision (weighted subj avg)'] = precision_score(
                    y_true, y_pred_binary, average='weighted')
                results['Recall (weighted subj avg)'] = recall_score(
                    y_true, y_pred_binary, average='weighted')
                results['F1 score (weighted subj avg)'] = f1_score(
                    y_true, y_pred_binary, average='weighted')
                results['Precision (microavg)'] = precision_score(
                    y_true, y_pred_binary, average='micro')
                results['Recall (microavg)'] = recall_score(
                    y_true, y_pred_binary, average='micro')
                results['F1 score (microavg)'] = f1_score(
                    y_true, y_pred_binary, average='micro')
            results['F1@5'] = f1_score(
                y_true, filter_pred_top_k(y_pred, 5) > 0.0, average='samples')
            results['NDCG'] = ndcg_score(y_true, y_pred)
            results['NDCG@5'] = ndcg_score(y_true, y_pred, limit=5)
            results['NDCG@10'] = ndcg_score(y_true, y_pred, limit=10)
            if metrics == 'all':
                results['Precision@1'] = precision_at_k_score(
                    y_true, y_pred, limit=1)
                results['Precision@3'] = precision_at_k_score(
                    y_true, y_pred, limit=3)
                results['Precision@5'] = precision_at_k_score(
                    y_true, y_pred, limit=5)
                results['LRAP'] = label_ranking_average_precision_score(
                    y_true, y_pred)
                results['True positives'] = true_positives(
                    y_true, y_pred_binary)
                results['False positives'] = false_positives(
                    y_true, y_pred_binary)
                results['False negatives'] = false_negatives(
                    y_true, y_pred_binary)

        return results
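    # Note (hypothetical example, not part of this module): the 'doc avg',
    # 'subj avg' and 'microavg' results above correspond to sklearn's
    # average='samples', 'macro' and 'micro' modes. For instance, with
    # y_true = [[1, 0, 1], [0, 1, 0]] and y_pred_binary = [[1, 1, 0], [0, 1, 0]],
    # document-averaged precision is (1/2 + 1/1) / 2 = 0.75, while
    # micro-averaged precision pools all predictions and gives 2/3.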

    def _result_per_subject_header(self, results_file):
        print('\t'.join(['URI',
                         'Label',
                         'Support',
                         'True_positives',
                         'False_positives',
                         'False_negatives',
                         'Precision',
                         'Recall',
                         'F1_score']),
              file=results_file)

    def _result_per_subject_body(self, zipped_results, results_file):
        for row in zipped_results:
            print('\t'.join((str(e) for e in row)), file=results_file)
Issue introduced by this change: The variable e does not seem to be defined in case the for loop above is not entered. Are you sure this can never be the case?

    def output_result_per_subject(self, y_true, y_pred, results_file):
        """Write results per subject (non-aggregated)
        to the output file results_file"""

        y_pred = y_pred.T > 0.0
        y_true = y_true.T > 0.0

        true_pos = (y_true & y_pred)
        false_pos = (~y_true & y_pred)
        false_neg = (y_true & ~y_pred)

        r = len(y_true)

        zipped = zip(self._subject_index._uris,               # URI
                     self._subject_index._labels,             # Label
                     np.sum((true_pos + false_neg), axis=1),  # Support
                     np.sum(true_pos, axis=1),                # True_positives
                     np.sum(false_pos, axis=1),               # False_positives
                     np.sum(false_neg, axis=1),               # False_negatives
                     [precision_score(y_true[i], y_pred[i], zero_division=0)
                      for i in range(r)],                     # Precision
                     [recall_score(y_true[i], y_pred[i], zero_division=0)
                      for i in range(r)],                     # Recall
                     [f1_score(y_true[i], y_pred[i], zero_division=0)
                      for i in range(r)])                     # F1
        self._result_per_subject_header(results_file)
        self._result_per_subject_body(zipped, results_file)
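    # Illustration (hypothetical example, not part of this module): each
    # subject becomes one tab-separated line in results_file, for example
    # (with a made-up URI and label, floats shown rounded):
    #   http://example.org/s1  Example subject  3  2  0  1  1.0  0.667  0.8
    # where Support = TP + FN and Precision, Recall and F1_score are computed
    # per subject with zero_division=0.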

    def results(self, metrics='all', results_file=None):
        """evaluate a set of selected subjects against a gold standard using
        different metrics. The set of metrics can be either 'all' or 'simple'.
        If results_file (file object) is given, write results per subject to it"""

        if not self._samples:
            raise NotSupportedException("cannot evaluate empty corpus")

        y_true = np.array([gold_subjects.as_vector(self._subject_index)
                           for hits, gold_subjects in self._samples])
        y_pred = np.array([hits.vector
                           for hits, gold_subjects in self._samples],
                          dtype=np.float32)

        results = self._evaluate_samples(
            y_true, y_pred, metrics)
        results['Documents evaluated'] = y_true.shape[0]

        if results_file:
            self.output_result_per_subject(y_true, y_pred, results_file)
        return results
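Usage sketch (hypothetical, not from the analysed commit): the snippet below drives EvaluationBatch end to end with minimal stand-in objects. In Annif the hits, gold_subjects and subject_index arguments come from other modules; here they are replaced by made-up stubs that only provide the attributes this module actually reads (hits.vector, gold_subjects.as_vector(), subject_index._uris and ._labels).

import numpy as np
from types import SimpleNamespace


class FakeGoldSubjects:
    # hypothetical stand-in for an Annif gold-standard subject set
    def __init__(self, vector):
        self._vector = vector

    def as_vector(self, subject_index):
        return self._vector


# hypothetical two-subject vocabulary
subject_index = SimpleNamespace(
    _uris=['http://example.org/s1', 'http://example.org/s2'],
    _labels=['subject one', 'subject two'])

batch = EvaluationBatch(subject_index)
# one evaluated document: subject s1 suggested with score 0.9, s1 is also the gold subject
batch.evaluate(SimpleNamespace(vector=np.array([0.9, 0.0])),
               FakeGoldSubjects(np.array([1, 0])))
print(batch.results(metrics='simple'))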