annif.eval.EvaluationBatch._evaluate_samples()   (rating: F)

Complexity

Conditions 25

Size

Total Lines 74
Code Lines 50

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0

Metric  Value
cc      25
eloc    50
nop     4
dl      0
loc     74
rs      0
c       0
b       0
f       0

How to fix

Long Method

Small methods make your code easier to understand, especially when combined with a good name. And when a method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment serving as a starting point for naming it.

Commonly applied refactorings include:

- Extract Method

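Applied to _evaluate_samples(), most of the method's 50 code lines are the literal all_metrics dictionary, and its twelve document/subject/weighted/micro averaged entries are all precision_score, recall_score or f1_score calls that differ only in the average argument. One possible Extract Method step is sketched below; the helper name _build_averaged_metrics() and the grouping are illustrative only, not actual Annif code.

    # sketch only: a possible helper inside EvaluationBatch, not actual Annif code
    def _build_averaged_metrics(self, y_true, y_pred_binary):
        """Build the lazy precision/recall/F1 entries for every averaging mode."""
        averages = {
            "samples": "doc avg",
            "macro": "subj avg",
            "weighted": "weighted subj avg",
            "micro": "microavg",
        }
        metrics = {}
        for average, label in averages.items():
            # the default argument pins the current value of average in each lambda
            metrics[f"Precision ({label})"] = lambda avg=average: precision_score(
                y_true, y_pred_binary, average=avg
            )
            metrics[f"Recall ({label})"] = lambda avg=average: recall_score(
                y_true, y_pred_binary, average=avg
            )
            metrics[f"F1 score ({label})"] = lambda avg=average: f1_score(
                y_true, y_pred_binary, average=avg
            )
        return metrics

_evaluate_samples() would then start from these entries and only add the ranking entries (NDCG, Precision@k, F1@5) and the true/false positive counts itself, leaving each remaining piece small enough to name well.
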
Complexity

Complex code like annif.eval.EvaluationBatch._evaluate_samples() often does a lot of different things. To break down such a method, or the class that contains it, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields and methods that share the same prefixes or suffixes.

Once you have determined which fields and methods belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.

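In EvaluationBatch below, the methods sharing the _result_per_subject prefix, together with output_result_per_subject(), form this kind of cohesive component: between them they only need the subject index, the label language, the y_true/y_pred arrays and an output file. A hypothetical extraction is sketched here; the class name SubjectReportWriter is not part of Annif.

# sketch only: not part of Annif, just where Extract Class could lead
class SubjectReportWriter:
    """Write the per-subject evaluation report for one language."""

    def __init__(self, subject_index, language):
        self._subject_index = subject_index
        self._language = language

    def write(self, y_true, y_pred, results_file):
        # the bodies of output_result_per_subject(),
        # _result_per_subject_header() and _result_per_subject_body()
        # would move here essentially unchanged, and EvaluationBatch
        # would delegate:
        #     SubjectReportWriter(self._subject_index, language).write(
        #         y_true, y_pred, results_file)
        ...

The analyzed source of annif/eval.py follows.
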
"""Evaluation metrics for Annif"""

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

import numpy as np
import scipy.sparse
from sklearn.metrics import f1_score, precision_score, recall_score

from annif.exception import NotSupportedException
from annif.suggestion import SuggestionBatch, filter_suggestion

if TYPE_CHECKING:
    from collections.abc import Iterable, Iterator, Sequence
    from io import TextIOWrapper

    from click.utils import LazyFile
    from scipy.sparse._arrays import csr_array

    from annif.corpus.subject import SubjectIndex, SubjectSet
    from annif.suggestion import SubjectSuggestion


def true_positives(y_true: csr_array, y_pred: csr_array) -> int:
    """calculate the number of true positives using bitwise operations,
    emulating the way sklearn evaluation metric functions work"""
    return int((y_true.multiply(y_pred)).sum())


def false_positives(y_true: csr_array, y_pred: csr_array) -> int:
    """calculate the number of false positives using bitwise operations,
    emulating the way sklearn evaluation metric functions work"""
    return int((y_true < y_pred).sum())


def false_negatives(y_true: csr_array, y_pred: csr_array) -> int:
    """calculate the number of false negatives using bitwise operations,
    emulating the way sklearn evaluation metric functions work"""
    return int((y_true > y_pred).sum())

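# The DCG of the top-k suggested subjects below is the sum over ranks
# i = 1..k of rel_i / log2(i + 1), where rel_i is 1 if the i-th ranked
# suggestion is a gold subject and 0 otherwise. ndcg_score() divides this
# by the DCG of an ideal ranking of the gold subjects and averages the
# ratio over documents (documents without gold subjects count as 1.0).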
def dcg_score(
    y_true: csr_array, y_pred: csr_array, limit: int | None = None
) -> np.float64:
    """return the discounted cumulative gain (DCG) score for the selected
    labels vs. relevant labels"""

    n_pred = y_pred.count_nonzero()
    if limit is not None:
        n_pred = min(limit, n_pred)

    top_k = y_pred.data.argsort()[-n_pred:][::-1]
    order = y_pred.indices[top_k]
    gain = y_true[:, order]
    discount = np.log2(np.arange(1, n_pred + 1) + 1)
    return (gain / discount).sum()


def ndcg_score(y_true: csr_array, y_pred: csr_array, limit: int | None = None) -> float:
    """return the normalized discounted cumulative gain (nDCG) score for the
    selected labels vs. relevant labels"""

    scores = np.ones(y_true.shape[0], dtype=np.float32)
    for i in range(y_true.shape[0]):
        true = y_true[[i]]
        idcg = dcg_score(true, true, limit)
        if idcg > 0:
            pred = y_pred[[i]]
            dcg = dcg_score(true, pred, limit)
            scores[i] = dcg / idcg

    return float(scores.mean())

class EvaluationBatch:
    """A class for evaluating batches of results using all available metrics.
    The evaluate() method is called once per document in the batch or evaluate_many()
    for a list of documents of the batch. Final results can be queried using the
    results() method."""

    def __init__(self, subject_index: SubjectIndex) -> None:
        self._subject_index = subject_index
        self._suggestion_arrays = []
        self._gold_subject_arrays = []

    def evaluate_many(
        self,
        suggestion_batch: (
            list[list[SubjectSuggestion]] | SuggestionBatch | list[Iterator]
        ),
        gold_subject_batch: Sequence[SubjectSet],
    ) -> None:
        if not isinstance(suggestion_batch, SuggestionBatch):
            suggestion_batch = SuggestionBatch.from_sequence(
                suggestion_batch, self._subject_index
            )
        self._suggestion_arrays.append(suggestion_batch.array)

        # convert gold_subject_batch to sparse matrix
        ar = scipy.sparse.dok_array(
            (len(gold_subject_batch), len(self._subject_index)), dtype=bool
        )
        for idx, subject_set in enumerate(gold_subject_batch):
            for subject_id in subject_set:
                ar[idx, subject_id] = True
        self._gold_subject_arrays.append(ar.tocsr())

    def _evaluate_samples(
        self,
        y_true: csr_array,
        y_pred: csr_array,
        metrics: Iterable[str] = [],
    ) -> dict[str, float]:
        y_pred_binary = y_pred > 0.0

        # define the available metrics as lazy lambda functions
        # so we can execute only the ones actually requested
        all_metrics = {
            "Precision (doc avg)": lambda: precision_score(
                y_true, y_pred_binary, average="samples"
            ),
            "Recall (doc avg)": lambda: recall_score(
                y_true, y_pred_binary, average="samples"
            ),
            "F1 score (doc avg)": lambda: f1_score(
                y_true, y_pred_binary, average="samples"
            ),
            "Precision (subj avg)": lambda: precision_score(
                y_true, y_pred_binary, average="macro"
            ),
            "Recall (subj avg)": lambda: recall_score(
                y_true, y_pred_binary, average="macro"
            ),
            "F1 score (subj avg)": lambda: f1_score(
                y_true, y_pred_binary, average="macro"
            ),
            "Precision (weighted subj avg)": lambda: precision_score(
                y_true, y_pred_binary, average="weighted"
            ),
            "Recall (weighted subj avg)": lambda: recall_score(
                y_true, y_pred_binary, average="weighted"
            ),
            "F1 score (weighted subj avg)": lambda: f1_score(
                y_true, y_pred_binary, average="weighted"
            ),
            "Precision (microavg)": lambda: precision_score(
                y_true, y_pred_binary, average="micro"
            ),
            "Recall (microavg)": lambda: recall_score(
                y_true, y_pred_binary, average="micro"
            ),
            "F1 score (microavg)": lambda: f1_score(
                y_true, y_pred_binary, average="micro"
            ),
            "F1@5": lambda: f1_score(
                y_true, filter_suggestion(y_pred, 5) > 0.0, average="samples"
            ),
            "NDCG": lambda: ndcg_score(y_true, y_pred),
            "NDCG@5": lambda: ndcg_score(y_true, y_pred, limit=5),
            "NDCG@10": lambda: ndcg_score(y_true, y_pred, limit=10),
            "Precision@1": lambda: precision_score(
                y_true, filter_suggestion(y_pred, 1) > 0.0, average="samples"
            ),
            "Precision@3": lambda: precision_score(
                y_true, filter_suggestion(y_pred, 3) > 0.0, average="samples"
            ),
            "Precision@5": lambda: precision_score(
                y_true, filter_suggestion(y_pred, 5) > 0.0, average="samples"
            ),
            "True positives": lambda: true_positives(y_true, y_pred_binary),
            "False positives": lambda: false_positives(y_true, y_pred_binary),
            "False negatives": lambda: false_negatives(y_true, y_pred_binary),
        }

        if not metrics:
            metrics = all_metrics.keys()

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            return {metric: all_metrics[metric]() for metric in metrics}

    def _result_per_subject_header(
        self, results_file: LazyFile | TextIOWrapper
    ) -> None:
        print(
            "\t".join(
                [
                    "URI",
                    "Label",
                    "Support",
                    "True_positives",
                    "False_positives",
                    "False_negatives",
                    "Precision",
                    "Recall",
                    "F1_score",
                ]
            ),
            file=results_file,
        )

    def _result_per_subject_body(
        self, zipped_results: zip, results_file: LazyFile | TextIOWrapper
    ) -> None:
        for row in zipped_results:
            print("\t".join((str(e) for e in row)), file=results_file)

    def output_result_per_subject(
        self,
        y_true: csr_array,
        y_pred: csr_array,
        results_file: TextIOWrapper | LazyFile,
        language: str,
    ) -> None:
        """Write results per subject (non-aggregated)
        to outputfile results_file, using labels in the given language"""

        y_pred = y_pred.T > 0.0
        y_true = y_true.T

        true_pos = y_true.multiply(y_pred).sum(axis=1)
        false_pos = (y_true < y_pred).sum(axis=1)
        false_neg = (y_true > y_pred).sum(axis=1)

        with np.errstate(invalid="ignore"):
            precision = np.nan_to_num(true_pos / (true_pos + false_pos))
            recall = np.nan_to_num(true_pos / (true_pos + false_neg))
            f1_score = np.nan_to_num(2 * (precision * recall) / (precision + recall))

        zipped = zip(
            [subj.uri for subj in self._subject_index],  # URI
            [subj.labels[language] for subj in self._subject_index],  # Label
            y_true.sum(axis=1),  # Support
            true_pos,  # True positives
            false_pos,  # False positives
            false_neg,  # False negatives
            precision,  # Precision
            recall,  # Recall
            f1_score,  # F1 score
        )
        self._result_per_subject_header(results_file)
        self._result_per_subject_body(zipped, results_file)

    def results(
        self,
        metrics: Iterable[str] = [],
        results_file: LazyFile | TextIOWrapper | None = None,
        language: str | None = None,
    ) -> dict[str, float]:
        """evaluate a set of selected subjects against a gold standard using
        different metrics. If metrics is empty, use all available metrics.
        If results_file (file object) given, write results per subject to it
        with labels expressed in the given language."""

        if not self._suggestion_arrays:
            raise NotSupportedException("cannot evaluate empty corpus")

        y_pred = scipy.sparse.csr_array(scipy.sparse.vstack(self._suggestion_arrays))
        y_true = scipy.sparse.csr_array(scipy.sparse.vstack(self._gold_subject_arrays))

        results = self._evaluate_samples(y_true, y_pred, metrics)
        results["Documents evaluated"] = int(y_true.shape[0])

        if results_file:
            self.output_result_per_subject(y_true, y_pred, results_file, language)
        return results
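
results() stacks the arrays collected by evaluate_many() into two csr_arrays and hands them to _evaluate_samples() together with the requested metric names. The following is a minimal, self-contained sketch with made-up toy data, not taken from Annif's tests or documentation; it calls the private _evaluate_samples() directly and passes None as the subject index only because this particular method never touches it.

import numpy as np
import scipy.sparse

from annif.eval import EvaluationBatch, ndcg_score

# two documents, four subjects: gold subjects as booleans, suggestion scores as floats
y_true = scipy.sparse.csr_array(
    np.array([[True, False, True, False],
              [False, False, True, False]])
)
y_pred = scipy.sparse.csr_array(
    np.array([[0.9, 0.4, 0.0, 0.1],
              [0.0, 0.2, 0.7, 0.0]], dtype=np.float32)
)

# the subject index is only needed by evaluate_many(), not by _evaluate_samples()
batch = EvaluationBatch(subject_index=None)
scores = batch._evaluate_samples(
    y_true, y_pred, metrics=["Precision (doc avg)", "Recall (doc avg)", "NDCG"]
)
for name, value in scores.items():
    print(f"{name}: {value:.4f}")

# the module-level helpers can also be used on their own
print("nDCG@5:", ndcg_score(y_true, y_pred, limit=5))

In normal use the same arrays are built up by calling evaluate_many() with SubjectSuggestion lists and SubjectSet objects from a loaded vocabulary, and results() with an empty metrics argument computes every available metric.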