@@ 263-325 (lines=63) @@ | ||
260 | return indices |
|
261 | ||
262 | ||
def find_independence_thresholds(roc_curves, base_rates, proportions,
                                 cost_matrix):
    """Compute thresholds that achieve independence and minimize cost.

    Independence is enforced by using one acceptance-rate value for every
    group. That shared value is searched over ``[0, 1]``.
    :func:`get_acceptance_rate_indices` maps each candidate value to one
    ROC point per group. The objective is the proportion-weighted sum of
    the per-group costs.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: ``(cutoffs, fpr_tpr, cost, acceptance_rate)``. These are:
             thresholds by attribute; ``(FPR, TPR)`` pairs by attribute;
             the cost value at the chosen acceptance rate; and that
             shared acceptance rate.
             Note that ``cost`` is the *negated* proportion-weighted
             total cost, i.e. the same objective that is searched over.
    :rtype: tuple
    """

    def cost_at_indices(indices):
        # Negated, proportion-weighted total cost of selecting ROC point
        # ``indices[group]`` for each group. roc[0] / roc[1] are read as
        # the FPR / TPR sequences of a group's curve.
        # The sign is flipped so the best point has the *largest*
        # objective, presumably what ``_ternary_search_float`` looks for
        # (confirm against its definition).
        total_cost = 0
        for group, roc in roc_curves.items():
            index = indices[group]
            fpr = roc[0][index]
            tpr = roc[1][index]
            group_cost = _cost_function(fpr, tpr,
                                        base_rates[group],
                                        cost_matrix)
            total_cost += group_cost * proportions[group]
        return -total_cost

    def total_cost_function(acceptance_rate_value):
        # TODO: move demo here + multiple cost
        # + refactor - use threshold to calculate acceptance_rate_value
        indices = get_acceptance_rate_indices(roc_curves, base_rates,
                                              acceptance_rate_value)
        return cost_at_indices(indices)

    acceptance_rate_min_cost = _ternary_search_float(total_cost_function,
                                                     0, 1, TRINARY_SEARCH_TOL)

    # Resolve the per-group ROC indices for the optimum once. Reuse them
    # for the cost, the cutoffs and the (FPR, TPR) pairs.
    threshold_indices = get_acceptance_rate_indices(roc_curves, base_rates,
                                                    acceptance_rate_min_cost)
    cost = cost_at_indices(threshold_indices)

    # One threshold sequence is indexed for every group. This is
    # presumably because all ROC curves share a single threshold grid;
    # confirm in ``_extract_threshold``.
    thresholds = _extract_threshold(roc_curves)
    cutoffs = {group: thresholds[threshold_index]
               for group, threshold_index in threshold_indices.items()}

    fpr_tpr = {group: (roc[0][threshold_indices[group]],
                       roc[1][threshold_indices[group]])
               for group, roc in roc_curves.items()}

    return cutoffs, fpr_tpr, cost, acceptance_rate_min_cost
|
326 | ||
327 | ||
328 | def get_fnr_indices(roc_curves, fnr_value): |
|
@@ 343-403 (lines=61) @@ | ||
340 | return indices |
|
341 | ||
342 | ||
def find_fnr_thresholds(roc_curves, base_rates, proportions,
                        cost_matrix):
    """Compute thresholds that achieve equal FNRs and minimize cost.

    Also known as **equal opportunity**.

    One false-negative-rate value, shared by every group, is searched
    over ``[0, 1]``. :func:`get_fnr_indices` maps each candidate value to
    one ROC point per group. The objective is the proportion-weighted sum
    of the per-group costs.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: ``(cutoffs, fpr_tpr, cost, fnr_value)``. These are:
             thresholds by attribute; ``(FPR, TPR)`` pairs by attribute;
             the cost value at the chosen FNR; and that shared FNR value.
             Note that ``cost`` is the *negated* proportion-weighted
             total cost, i.e. the same objective that is searched over.
    :rtype: tuple
    """

    def cost_at_indices(indices):
        # Negated, proportion-weighted total cost of selecting ROC point
        # ``indices[group]`` for each group. roc[0] / roc[1] are read as
        # the FPR / TPR sequences of a group's curve.
        # The sign is flipped so the best point has the *largest*
        # objective, presumably what ``_ternary_search_float`` looks for
        # (confirm against its definition).
        total_cost = 0
        for group, roc in roc_curves.items():
            index = indices[group]
            fpr = roc[0][index]
            tpr = roc[1][index]
            group_cost = _cost_function(fpr, tpr,
                                        base_rates[group],
                                        cost_matrix)
            total_cost += group_cost * proportions[group]
        return -total_cost

    def total_cost_function(fnr_value):
        # TODO: move demo here + multiple cost
        indices = get_fnr_indices(roc_curves, fnr_value)
        return cost_at_indices(indices)

    fnr_value_min_cost = _ternary_search_float(total_cost_function,
                                               0, 1,
                                               TRINARY_SEARCH_TOL)

    # Resolve the per-group ROC indices for the optimum once. Reuse them
    # for the cost, the (FPR, TPR) pairs and the cutoffs.
    threshold_indices = get_fnr_indices(roc_curves, fnr_value_min_cost)
    cost = cost_at_indices(threshold_indices)

    fpr_tpr = {group: (roc[0][threshold_indices[group]],
                       roc[1][threshold_indices[group]])
               for group, roc in roc_curves.items()}

    # One threshold sequence is indexed for every group. This is
    # presumably because all ROC curves share a single threshold grid;
    # confirm in ``_extract_threshold``.
    thresholds = _extract_threshold(roc_curves)
    cutoffs = {group: thresholds[threshold_index]
               for group, threshold_index in threshold_indices.items()}

    return cutoffs, fpr_tpr, cost, fnr_value_min_cost
|
404 | ||
405 | ||
406 | def _find_feasible_roc(roc_curves): |