precision_recall_gain_curve()   A

Complexity
    Conditions: 3

Size
    Total Lines: 115
    Code Lines: 24

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric  Value
cc      3
eloc    24
nop     4
dl      0
loc     115
rs      9.304
c       0
b       0
f       0
How to fix: Long Method

Small methods make your code easier to understand, especially when combined with a good name. Conversely, if your method is small, finding a good name for it is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a sign that you should extract the commented part into a new method, using the comment as a starting point for naming it.

Commonly applied refactorings include:

- Extract Method (sketched below)
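
For instance, a minimal Extract Method sketch for the function flagged on this page: the block of precision_recall_gain_curve that is copied from sklearn.metrics._ranking.precision_recall_curve could live behind its own well-named helper. The helper name below is illustrative, not part of the module.

import numpy as np

def _reversed_precision_recall(tps, fps, thresholds):
    """Precision/recall per threshold, reversed so recall is decreasing."""
    precision = tps / (tps + fps)
    precision[np.isnan(precision)] = 0
    recall = tps / tps[-1]
    last_ind = tps.searchsorted(tps[-1])  # stop once full recall is attained
    sl = slice(last_ind, None, -1)
    return np.r_[precision[sl], 1], np.r_[recall[sl], 0], thresholds[sl]

The long method then reduces to a call to this helper followed by the gain conversion, and the comment marking the copied block becomes the helper's name and docstring.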

from functools import partial

import numpy as np
from sklearn.metrics._base import _average_binary_score
from sklearn.metrics._ranking import _binary_clf_curve
from sklearn.utils.multiclass import type_of_target


def area_under_precision_recall_gain_score(
    y_true, y_score, *, average="macro", pos_label=1, sample_weight=None
):
    """Compute average precision (AP) from prediction scores.

    AP summarizes a precision-recall curve as the weighted mean of precisions
    achieved at each threshold, with the increase in recall from the previous
    threshold used as the weight:

    .. math::
        \\text{AP} = \\sum_n (R_n - R_{n-1}) P_n

    where :math:`P_n` and :math:`R_n` are the precision and recall at the nth
    threshold [1]_. This implementation is not interpolated and is different
    from computing the area under the precision-recall curve with the
    trapezoidal rule, which uses linear interpolation and can be too
    optimistic.

    Note: this implementation is restricted to the binary classification task
    or multilabel classification task.

    Read more in the :ref:`User Guide <precision_recall_f_measure_metrics>`.

    Parameters
    ----------
    y_true : ndarray of shape (n_samples,) or (n_samples, n_classes)
        True binary labels or binary label indicators.

    y_score : ndarray of shape (n_samples,) or (n_samples, n_classes)
        Target scores, can either be probability estimates of the positive
        class, confidence values, or non-thresholded measure of decisions
        (as returned by :term:`decision_function` on some classifiers).

    average : {'micro', 'samples', 'weighted', 'macro'} or None, \
            default='macro'
        If ``None``, the scores for each class are returned. Otherwise,
        this determines the type of averaging performed on the data:

        ``'micro'``:
            Calculate metrics globally by considering each element of the label
            indicator matrix as a label.
        ``'macro'``:
            Calculate metrics for each label, and find their unweighted
            mean.  This does not take label imbalance into account.
        ``'weighted'``:
            Calculate metrics for each label, and find their average, weighted
            by support (the number of true instances for each label).
        ``'samples'``:
            Calculate metrics for each instance, and find their average.

        Will be ignored when ``y_true`` is binary.

    pos_label : int or str, default=1
        The label of the positive class. Only applied to binary ``y_true``.
        For multilabel-indicator ``y_true``, ``pos_label`` is fixed to 1.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    Returns
    -------
    average_precision : float

    See Also
    --------
    roc_auc_score : Compute the area under the ROC curve.
    precision_recall_curve : Compute precision-recall pairs for different
        probability thresholds.

    Notes
    -----
    .. versionchanged:: 0.19
      Instead of linearly interpolating between operating points, precisions
      are weighted by the change in recall since the last operating point.

    References
    ----------
    .. [1] `Wikipedia entry for the Average precision
           <https://en.wikipedia.org/w/index.php?title=Information_retrieval&
           oldid=793358396#Average_precision>`_

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.metrics import average_precision_score
    >>> y_true = np.array([0, 0, 1, 1])
    >>> y_scores = np.array([0.1, 0.4, 0.35, 0.8])
    >>> average_precision_score(y_true, y_scores)
    0.83...
    """

    def _binary_uninterpolated_average_precision(
        y_true, y_score, pos_label=1, sample_weight=None
    ):
        precision_gain, recall_gain = precision_recall_gain_curve(
            y_true, y_score, pos_label=pos_label, sample_weight=sample_weight
        )
        # Return the step function integral
        # The following works because the last entry of precision is
        # guaranteed to be 1, as returned by precision_recall_curve
        # TODO: check that the integral is computed correctly
        return -np.sum(np.diff(recall_gain) * np.array(precision_gain)[:-1])

    y_type = type_of_target(y_true)
    if y_type == "multilabel-indicator" and pos_label != 1:
        raise ValueError(
            "Parameter pos_label is fixed to 1 for "
            "multilabel-indicator y_true. Do not set "
            "pos_label or set pos_label to 1."
        )
    elif y_type == "binary":
        # Convert to Python primitive type to avoid NumPy type / Python str
        # comparison. See https://github.com/numpy/numpy/issues/6784
        present_labels = np.unique(y_true).tolist()
        if len(present_labels) == 2 and pos_label not in present_labels:
            raise ValueError(
                f"pos_label={pos_label} is not a valid label. It should be "
                f"one of {present_labels}"
            )
    average_precision = partial(
        _binary_uninterpolated_average_precision, pos_label=pos_label
    )
    # Average a binary metric for multilabel classification.
    average_precision = _average_binary_score(
        average_precision, y_true, y_score, average, sample_weight=sample_weight
    )
    return average_precision


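The inner helper treats the PRG curve as a step function and integrates it with -np.sum(np.diff(x) * y[:-1]). A minimal standalone check of that identity, with illustrative values only:

import numpy as np

# For a decreasing x-axis, -sum(diff(x) * y[:-1]) is the rectangle-rule area
# under y as a function of x, using the first endpoint of each interval.
x = np.array([1.0, 0.5, 0.0])   # recall-gain-like axis, decreasing
y = np.array([0.2, 0.8, 1.0])   # precision-gain-like heights
area = -np.sum(np.diff(x) * y[:-1])
print(area)  # 0.5 * 0.2 + 0.5 * 0.8 = 0.5
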
def precision_recall_gain(precisions, recalls, proportion_of_positives):
    """Convert precision and recall into precision-gain and recall-gain.

    Parameters
    ----------
    precisions : ndarray
        Precision values.
    recalls : ndarray
        Recall values.
    proportion_of_positives : float
        Proportion of positive samples. Termed π in the paper.
    """
    with np.errstate(divide="ignore", invalid="ignore"):
        prec_gain = (precisions - proportion_of_positives) / (
            (1 - proportion_of_positives) * precisions
        )
        rec_gain = (recalls - proportion_of_positives) / (
            (1 - proportion_of_positives) * recalls
        )

    return prec_gain, rec_gain


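For intuition, a quick hand-checked use of this conversion (illustrative values, with the module's precision_recall_gain in scope): with π = 0.5, a precision of 0.5 is the always-positive baseline and gains 0, while a precision of 1.0 gains 1.

import numpy as np

precisions = np.array([0.5, 0.75, 1.0])
recalls = np.array([0.5, 0.75, 1.0])
pg, rg = precision_recall_gain(precisions, recalls, proportion_of_positives=0.5)
print(pg)  # [0.         0.66666667 1.        ]  == (p - 0.5) / (0.5 * p)
print(rg)  # [0.         0.66666667 1.        ]
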

def precision_recall_gain_curve(y_true, probas_pred, pos_label=1, sample_weight=None):
    """Compute precision-gain/recall-gain pairs for different probability thresholds.

    Note: this implementation is restricted to the binary classification task.

    The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of
    true positives and ``fp`` the number of false positives. The precision is
    intuitively the ability of the classifier not to label as positive a sample
    that is negative.

    The recall is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of
    true positives and ``fn`` the number of false negatives. The recall is
    intuitively the ability of the classifier to find all the positive samples.

    The last precision and recall values are 1. and 0. respectively and do not
    have a corresponding threshold. This ensures that the graph starts on the
    y axis.

    Read more in the :ref:`User Guide <precision_recall_f_measure_metrics>`.

    Parameters
    ----------
    y_true : ndarray of shape (n_samples,)
        True binary labels. If labels are not either {-1, 1} or {0, 1}, then
        pos_label should be explicitly given.

    probas_pred : ndarray of shape (n_samples,)
        Estimated probabilities or output of a decision function.

    pos_label : int or str, default=1
        The label of the positive class. Only ``pos_label=1`` is currently
        supported.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights. Not currently supported.

    Returns
    -------
    precision_gain : ndarray of shape (n_thresholds + 1,)
        Precision-gain values, obtained by converting the precision at each
        threshold with ``precision_recall_gain``.

    recall_gain : ndarray of shape (n_thresholds + 1,)
        Recall-gain values, obtained by converting the recall at each
        threshold with ``precision_recall_gain``.

    See Also
    --------
    plot_precision_recall_curve : Plot Precision Recall Curve for binary
        classifiers.
    PrecisionRecallDisplay : Precision Recall visualization.
    average_precision_score : Compute average precision from prediction scores.
    det_curve : Compute error rates for different probability thresholds.
    roc_curve : Compute Receiver operating characteristic (ROC) curve.

    Examples
    --------
    The underlying precision-recall points are the same as those of
    :func:`sklearn.metrics.precision_recall_curve`:

    >>> import numpy as np
    >>> from sklearn.metrics import precision_recall_curve
    >>> y_true = np.array([0, 0, 1, 1])
    >>> y_scores = np.array([0.1, 0.4, 0.35, 0.8])
    >>> precision, recall, thresholds = precision_recall_curve(
    ...     y_true, y_scores)
    >>> precision
    array([0.66666667, 0.5       , 1.        , 1.        ])
    >>> recall
    array([1. , 0.5, 0.5, 0. ])
    >>> thresholds
    array([0.35, 0.4 , 0.8 ])
    """
    if pos_label != 1:
        raise NotImplementedError("Have not implemented non-binary targets")
    if sample_weight is not None:
        raise NotImplementedError

    # Calculate true and false positives per binary classification threshold.
    fps, tps, thresholds = _binary_clf_curve(
        y_true, probas_pred, pos_label=pos_label, sample_weight=sample_weight
    )

    precision = tps / (tps + fps)
    precision[np.isnan(precision)] = 0
    recall = tps / tps[-1]

    # stop when full recall attained
    # and reverse the outputs so recall is decreasing
    last_ind = tps.searchsorted(tps[-1])
    sl = slice(last_ind, None, -1)  # equivalent to the slice [last_ind:None:-1]
    precision, recall, thresholds = (
        np.r_[precision[sl], 1],
        np.r_[recall[sl], 0],
        thresholds[sl],
    )

    # everything above is taken from sklearn.metrics._ranking.precision_recall_curve

    # logic taken from sklearn.metrics._ranking.det_curve
    # fns = tps[-1] - tps
    p_count = tps[-1]
    n_count = fps[-1]
    # π is the proportion of positives among all samples, not the
    # positive-to-negative ratio.
    proportion_of_positives = p_count / (p_count + n_count)

    precision_gains, recall_gains = precision_recall_gain(
        precisions=precision,
        recalls=recall,
        proportion_of_positives=proportion_of_positives,
    )

    return precision_gains, recall_gains


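A short usage sketch on the toy data from the docstring (values hand-derived under the π = p / (p + n) definition above; a recall of 0 maps to -inf, which callers would typically drop or clip before plotting or integrating):

import numpy as np

y_true = np.array([0, 0, 1, 1])
y_scores = np.array([0.1, 0.4, 0.35, 0.8])
pg, rg = precision_recall_gain_curve(y_true, y_scores)
# π = 0.5 here, so the PR points (2/3, 1), (1/2, 1/2), (1, 1/2), (1, 0)
# convert to the gains below.
print(pg)  # [0.5  0.   1.   1. ]
print(rg)  # [  1.   0.   0. -inf]
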
"""
279
Source:
280
https://github.com/meeliskull/prg/blob/master/Python_package/prg/prg.py
281
"""
282
283
284
def precision(tp, fn, fp, tn):
    # Precision from contingency-table counts: tp / (tp + fp).
    with np.errstate(divide="ignore", invalid="ignore"):
        return tp / (tp + fp)


def recall(tp, fn, fp, tn):
    # Recall from contingency-table counts: tp / (tp + fn).
    with np.errstate(divide="ignore", invalid="ignore"):
        return tp / (tp + fn)


def precision_gain(tp, fn, fp, tn):
    """Calculates Precision Gain from the contingency table

    This function calculates Precision Gain from the entries of the contingency
    table: number of true positives (TP), false negatives (FN), false positives
    (FP), and true negatives (TN). More information on Precision-Recall-Gain
    curves and how to cite this work is available at
    http://www.cs.bris.ac.uk/~flach/PRGcurves/.
    """
    n_pos = tp + fn
    n_neg = fp + tn
    with np.errstate(divide="ignore", invalid="ignore"):
        prec_gain = 1.0 - (n_pos / n_neg) * (fp / tp)
    # np.alen was removed from NumPy; np.size handles scalars and arrays alike.
    if np.size(prec_gain) > 1:
        prec_gain[tn + fn == 0] = 0
    elif tn + fn == 0:
        prec_gain = 0
    return prec_gain


def recall_gain(tp, fn, fp, tn):
    """Calculates Recall Gain from the contingency table

    This function calculates Recall Gain from the entries of the contingency
    table: number of true positives (TP), false negatives (FN), false positives
    (FP), and true negatives (TN). More information on Precision-Recall-Gain
    curves and how to cite this work is available at
    http://www.cs.bris.ac.uk/~flach/PRGcurves/.

    Args:
        tp (float) or ([float]): True Positives
        fn (float) or ([float]): False Negatives
        fp (float) or ([float]): False Positives
        tn (float) or ([float]): True Negatives
    Returns:
        (float) or ([float])
    """
    n_pos = tp + fn
    n_neg = fp + tn
    with np.errstate(divide="ignore", invalid="ignore"):
        rg = 1.0 - (n_pos / n_neg) * (fn / tp)
    # np.alen was removed from NumPy; np.size handles scalars and arrays alike.
    if np.size(rg) > 1:
        rg[tn + fn == 0] = 1
    elif tn + fn == 0:
        rg = 1
    return rg
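
As a sanity check that these contingency-table forms agree with the π-based conversion in precision_recall_gain above, a small hand-computed example (illustrative numbers, with the module's functions in scope):

import numpy as np

# tp=3, fn=1, fp=1, tn=3 gives π = 4/8 = 0.5 and precision = recall = 3/4,
# so every formulation should yield a gain of (0.75 - 0.5) / (0.5 * 0.75) = 2/3.
tp, fn, fp, tn = 3.0, 1.0, 1.0, 3.0
print(precision_gain(tp, fn, fp, tn))  # 0.666...
print(recall_gain(tp, fn, fp, tn))     # 0.666...
pg, rg = precision_recall_gain(np.array([0.75]), np.array([0.75]), 0.5)
print(pg, rg)  # [0.66666667] [0.66666667]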