Test Failed
Push — master ( 6aede2...b09ca6 )
by Richard
01:14
created

precision_recall_gain_curve()   A

Complexity

Conditions 3

Size

Total Lines 115
Code Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 24
nop 4
dl 0
loc 115
rs 9.304
c 0
b 0
f 0

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

1
from functools import partial
2
3
import numpy as np
4
5
from sklearn.metrics import precision_recall_curve
6
from sklearn.metrics._ranking import _binary_clf_curve
7
from sklearn.utils.multiclass import type_of_target
8
from sklearn.metrics._base import _average_binary_score
9
10
11
def area_under_precision_recall_gain_score(
    y_true, y_score, *, average="macro", pos_label=1, sample_weight=None
):
    """Compute the area under the precision-recall-gain (PRG) curve.

    The precision-gain / recall-gain pairs are obtained from
    ``precision_recall_gain_curve`` and integrated as a step function:
    every step between two consecutive operating points contributes its
    width in recall gain multiplied by the precision gain of the
    higher-recall endpoint. Only the part of the curve with recall gain in
    ``[0, 1]`` is integrated; operating points whose recall gain is negative
    or non-finite are clipped to 0 and therefore add no area.

    Note: this implementation is restricted to the binary classification task
    or multilabel classification task.

    Parameters
    ----------
    y_true : ndarray of shape (n_samples,) or (n_samples, n_classes)
        True binary labels or binary label indicators.

    y_score : ndarray of shape (n_samples,) or (n_samples, n_classes)
        Target scores, can either be probability estimates of the positive
        class, confidence values, or non-thresholded measure of decisions
        (as returned by :term:`decision_function` on some classifiers).

    average : {'micro', 'samples', 'weighted', 'macro'} or None, default='macro'
        Averaging mode, forwarded unchanged to scikit-learn's
        ``_average_binary_score``:

        ``'micro'``:
            Calculate metrics globally by considering each element of the label
            indicator matrix as a label.
        ``'macro'``:
            Calculate metrics for each label, and find their unweighted
            mean.  This does not take label imbalance into account.
        ``'weighted'``:
            Calculate metrics for each label, and find their average, weighted
            by support (the number of true instances for each label).
        ``'samples'``:
            Calculate metrics for each instance, and find their average.

        Will be ignored when ``y_true`` is binary.

    pos_label : int or str, default=1
        The label of the positive class. Only applied to binary ``y_true``.
        For multilabel-indicator ``y_true``, ``pos_label`` is fixed to 1.
        NOTE: ``precision_recall_gain_curve`` in this module raises
        ``NotImplementedError`` for any value other than 1.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights. NOTE: ``precision_recall_gain_curve`` in this module
        raises ``NotImplementedError`` when this is not ``None``.

    Returns
    -------
    area_under_prg : float
        Step-function area under the PRG curve (presumably an array of
        per-class scores when ``average=None``, as decided by
        ``_average_binary_score`` -- confirm against scikit-learn).

    Raises
    ------
    ValueError
        If ``y_true`` is multilabel-indicator and ``pos_label`` is not 1, or
        if ``y_true`` is binary with two present labels and ``pos_label`` is
        not one of them.

    See Also
    --------
    precision_recall_gain_curve : Compute precision-gain / recall-gain pairs
        for different thresholds.
    precision_recall_gain : Convert precision and recall into their gains.
    """

    def _binary_area_under_prg(y_true, y_score, pos_label=1, sample_weight=None):
        precision_gain, recall_gain = precision_recall_gain_curve(
            y_true, y_score, pos_label=pos_label, sample_weight=sample_weight
        )
        precision_gain = np.asarray(precision_gain, dtype=float)
        # The curve ends with a synthetic (precision=1, recall=0) point whose
        # recall gain is -inf. Integrating over it made the area inf/nan, so
        # the integration range is limited to recall gain >= 0.
        recall_gain = np.clip(np.asarray(recall_gain, dtype=float), 0.0, None)

        # Recall gain is ordered from high to low, hence the sign flip.
        step_widths = -np.diff(recall_gain)
        step_heights = precision_gain[:-1]

        # Zero-width steps are skipped explicitly: their precision gain may
        # be -inf (precision == 0) and ``0 * inf`` would poison the sum with
        # nan. Steps with nan width are skipped by the same comparison.
        contributes = step_widths > 0
        return np.sum(step_widths[contributes] * step_heights[contributes])

    y_type = type_of_target(y_true)
    if y_type == "multilabel-indicator" and pos_label != 1:
        raise ValueError(
            "Parameter pos_label is fixed to 1 for "
            "multilabel-indicator y_true. Do not set "
            "pos_label or set pos_label to 1."
        )
    elif y_type == "binary":
        # Convert to Python primitive type to avoid NumPy type / Python str
        # comparison. See https://github.com/numpy/numpy/issues/6784
        present_labels = np.unique(y_true).tolist()
        if len(present_labels) == 2 and pos_label not in present_labels:
            raise ValueError(
                f"pos_label={pos_label} is not a valid label. It should be "
                f"one of {present_labels}"
            )

    binary_metric = partial(_binary_area_under_prg, pos_label=pos_label)
    # Average a binary metric for multilabel classification.
    return _average_binary_score(
        binary_metric, y_true, y_score, average, sample_weight=sample_weight
    )
138
139
140
def precision_recall_gain(precisions, recalls, proportion_of_positives):
    """Convert precision and recall values into precision gain and recall gain.

    Both inputs go through the same element-wise transform
    ``(x - pi) / ((1 - pi) * x)``, where ``pi`` is
    ``proportion_of_positives``.

    Parameters
    ----------
    precisions : ndarray
        Precision values to convert.
    recalls : ndarray
        Recall values to convert.
    proportion_of_positives : float
        Proportion of positives. Termed π in the paper.

    Returns
    -------
    prec_gain : ndarray
        Precision gain for each entry of ``precisions``.
    rec_gain : ndarray
        Recall gain for each entry of ``recalls``.

    Notes
    -----
    NumPy divide-by-zero and invalid-value warnings are silenced during the
    computation, so entries equal to 0 come back as ``-inf`` or ``nan``
    instead of emitting a warning.
    """
    pi = proportion_of_positives

    def _as_gain(values):
        # Same formula for precision and recall.
        return (values - pi) / ((1 - pi) * values)

    with np.errstate(divide="ignore", invalid="ignore"):
        gains = (_as_gain(precisions), _as_gain(recalls))

    return gains
161
162
163
def precision_recall_gain_curve(y_true, probas_pred, pos_label=1, sample_weight=None):
    """Compute precision-gain / recall-gain pairs for different thresholds.

    Note: this implementation is restricted to the binary classification task.

    Precision and recall are first computed for every threshold following
    the logic of ``sklearn.metrics.precision_recall_curve``: the precision
    is ``tp / (tp + fp)`` and the recall is ``tp / (tp + fn)``, the outputs
    are ordered by decreasing recall, and a final point with precision 1 and
    recall 0 is appended. Both arrays are then converted into gains by
    ``precision_recall_gain`` using the proportion of positives
    ``pi = n_positives / (n_positives + n_negatives)``.

    Parameters
    ----------
    y_true : ndarray of shape (n_samples,)
        True binary labels.

    probas_pred : ndarray of shape (n_samples,)
        Estimated probabilities or output of a decision function.

    pos_label : int or str, default=1
        The label of the positive class. Only ``1`` is currently supported.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights. Only ``None`` is currently supported.

    Returns
    -------
    precision_gains : ndarray of shape (n_thresholds + 1,)
        Precision gain at each operating point, ordered by decreasing
        recall. The last element belongs to the appended
        (precision=1, recall=0) point.

    recall_gains : ndarray of shape (n_thresholds + 1,)
        Recall gain at each operating point, in the same order. The entry
        for the appended recall of 0 is not finite (division by zero inside
        ``precision_recall_gain``).

    Raises
    ------
    NotImplementedError
        If ``pos_label`` is not 1 or ``sample_weight`` is not ``None``.

    See Also
    --------
    precision_recall_gain : Convert precision and recall into their gains.
    sklearn.metrics.precision_recall_curve : Precision-recall pairs for
        different probability thresholds.
    """
    if pos_label != 1:
        raise NotImplementedError("Have not implemented non-binary targets")
    if sample_weight is not None:
        raise NotImplementedError

    # Calculate true and false positives per binary classification threshold.
    # The thresholds themselves are not needed for the gain curve.
    fps, tps, _ = _binary_clf_curve(
        y_true, probas_pred, pos_label=pos_label, sample_weight=sample_weight
    )

    precision = tps / (tps + fps)
    precision[np.isnan(precision)] = 0
    recall = tps / tps[-1]

    # Stop when full recall is attained and reverse the outputs so recall is
    # decreasing.
    last_ind = tps.searchsorted(tps[-1])
    sl = slice(last_ind, None, -1)  # equivalent to slice [last_ind:None:-1]
    precision = np.r_[precision[sl], 1]
    recall = np.r_[recall[sl], 0]

    # Everything above is taken from
    # sklearn.metrics._ranking.precision_recall_curve.

    # The last cumulative counts are the totals of positives and negatives
    # (same logic as sklearn.metrics._ranking.det_curve).
    p_count = tps[-1]
    n_count = fps[-1]
    # pi is the share of positives among ALL samples, not the
    # positives-to-negatives ratio.
    proportion_of_positives = p_count / (p_count + n_count)

    precision_gains, recall_gains = precision_recall_gain(
        precisions=precision,
        recalls=recall,
        proportion_of_positives=proportion_of_positives,
    )

    return precision_gains, recall_gains
278