| Metric      | Value |
|-------------|-------|
| Conditions  | 3     |
| Total Lines | 115   |
| Code Lines  | 24    |
| Lines       | 0     |
| Ratio       | 0 %   |
| Changes     | 0     |
Small methods make your code easier to understand, especially when combined with a good name. Moreover, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a good sign that the commented part should be extracted into a new method, using the comment as a starting point for its name.
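A minimal before/after sketch of this extract-method move (the names here are invented for illustration, not taken from the code under review):

```python
# Before: a comment labels a block that wants to be its own method.
def summarize(orders):
    # compute the total price including tax
    total = sum(order.price for order in orders) * 1.19
    print(f"Total: {total:.2f}")


# After: the comment becomes the method name.
def total_price_including_tax(orders, tax_rate=0.19):
    return sum(order.price for order in orders) * (1 + tax_rate)


def summarize(orders):
    print(f"Total: {total_price_including_tax(orders):.2f}")
```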
Commonly applied refactorings include:

- Extract Method: move a coherent block of the body into its own well-named method.
- Replace Temp with Query: replace a temporary variable with a call to an extracted method.

If many parameters/temporary variables are present:

- Introduce Parameter Object or Replace Method with Method Object; in Python, pre-binding arguments with `functools.partial` can play a similar role (see the sketch below this list).
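As a hedged sketch of the `functools.partial` option (the function and argument names are hypothetical, not from this codebase):

```python
from functools import partial

# Hypothetical metric that takes several rarely-varied parameters.
def clf_curve(y_true, y_score, pos_label=1, threshold=0.5):
    hits = sum(
        1 for t, s in zip(y_true, y_score) if (t == pos_label) == (s >= threshold)
    )
    return hits / len(y_true)

# Bind the parameters that stay fixed across call sites once, so the
# remaining calls pass only the data that actually varies.
strict_curve = partial(clf_curve, pos_label=1, threshold=0.8)
print(strict_curve([0, 1, 1], [0.1, 0.9, 0.7]))  # 0.666...
```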
```python
from functools import partial  # not used in this excerpt

import numpy as np
from sklearn.metrics._ranking import _binary_clf_curve  # private sklearn helper


def precision_recall_gain_curve(y_true, probas_pred, pos_label=1, sample_weight=None):
    """Compute precision-recall gain pairs for different probability thresholds.

    Note: this implementation is restricted to the binary classification task.

    The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of
    true positives and ``fp`` the number of false positives. The precision is
    intuitively the ability of the classifier not to label as positive a sample
    that is negative.

    The recall is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of
    true positives and ``fn`` the number of false negatives. The recall is
    intuitively the ability of the classifier to find all the positive samples.

    The last precision and recall values are 1. and 0. respectively and do not
    have a corresponding threshold. This ensures that the graph starts on the
    y axis.

    Read more in the :ref:`User Guide <precision_recall_f_measure_metrics>`.

    Parameters
    ----------
    y_true : ndarray of shape (n_samples,)
        True binary labels. If labels are not either {-1, 1} or {0, 1}, then
        pos_label should be explicitly given.

    probas_pred : ndarray of shape (n_samples,)
        Estimated probabilities or output of a decision function.

    pos_label : int, default=1
        The label of the positive class. Only the default value of 1 is
        currently supported.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights. Not currently supported.

    Returns
    -------
    precision_gains : ndarray of shape (n_thresholds + 1,)
        Precision gain values, derived from the precision of predictions
        with score >= thresholds[i].

    recall_gains : ndarray of shape (n_thresholds + 1,)
        Recall gain values, derived from the recall of predictions with
        score >= thresholds[i].

    See Also
    --------
    plot_precision_recall_curve : Plot Precision Recall Curve for binary
        classifiers.
    PrecisionRecallDisplay : Precision Recall visualization.
    average_precision_score : Compute average precision from prediction scores.
    det_curve : Compute error rates for different probability thresholds.
    roc_curve : Compute Receiver operating characteristic (ROC) curve.

    Examples
    --------
    >>> import numpy as np
    >>> from precision_recall_gain import precision_recall_gain_curve
    >>> y_true = np.array([0, 0, 1, 1])
    >>> y_scores = np.array([0.1, 0.4, 0.35, 0.8])
    >>> precision_gains, recall_gains = precision_recall_gain_curve(
    ...     y_true, y_scores)
    """
    if pos_label != 1:
        raise NotImplementedError(
            "pos_label values other than the default 1 are not supported"
        )
    if sample_weight is not None:
        raise NotImplementedError("sample_weight is not supported")

    # Calculate true and false positives per binary classification threshold.
    fps, tps, thresholds = _binary_clf_curve(
        y_true, probas_pred, pos_label=pos_label, sample_weight=sample_weight
    )

    precision = tps / (tps + fps)
    precision[np.isnan(precision)] = 0
    recall = tps / tps[-1]

    # Stop when full recall is attained, and reverse the outputs so that
    # recall is decreasing.
    last_ind = tps.searchsorted(tps[-1])
    sl = slice(last_ind, None, -1)  # equivalent to [last_ind::-1]
    precision, recall, thresholds = (
        np.r_[precision[sl], 1],
        np.r_[recall[sl], 0],
        thresholds[sl],
    )

    # Everything above is taken from
    # sklearn.metrics._ranking.precision_recall_curve.

    # Logic taken from sklearn.metrics._ranking.det_curve:
    # fns = tps[-1] - tps
    p_count = tps[-1]
    n_count = fps[-1]
    # pi: the proportion of positive samples among all samples.
    proportion_of_positives = p_count / (p_count + n_count)

    # ``precision_recall_gain`` is assumed to be defined elsewhere in
    # this module; a sketch is given below.
    precision_gains, recall_gains = precision_recall_gain(
        precisions=precision,
        recalls=recall,
        proportion_of_positives=proportion_of_positives,
    )

    return precision_gains, recall_gains
```
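The `precision_recall_gain` helper is referenced above but not shown. A minimal sketch of what it might compute, assuming it implements the gain transform from Flach and Kull's "Precision-Recall-Gain Curves: PR Analysis Done Right" (NeurIPS 2015), which maps a precision or recall value `x` to `(x - pi) / ((1 - pi) * x)` with `pi` the proportion of positives:

```python
import numpy as np


def precision_recall_gain(precisions, recalls, proportion_of_positives):
    """Hypothetical sketch of the gain transform (Flach & Kull, 2015)."""
    pi = proportion_of_positives
    with np.errstate(divide="ignore", invalid="ignore"):
        # Values below pi map to negative gains; a recall of 0 maps to -inf.
        precision_gains = (precisions - pi) / ((1 - pi) * precisions)
        recall_gains = (recalls - pi) / ((1 - pi) * recalls)
    return precision_gains, recall_gains
```

With the example data from the docstring (`pi = 1/2`), a precision of 2/3 maps to a gain of 1/2 and a recall of 1 maps to a gain of 1; the trailing recall of 0 maps to `-inf`, which plotting code would typically clip or drop.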