Code Duplication    Length = 61-63 lines in 2 locations

responsibly/fairness/interventions/threshold.py 2 locations

@@ 263-325 (lines=63) @@
260
    return indices
261
262
263
def find_independence_thresholds(roc_curves, base_rates, proportions,
                                 cost_matrix):
    """Compute thresholds that achieve independence and minimize cost.

    A single acceptance rate is searched for in the interval [0, 1];
    for every candidate value the matching ROC index of each group is
    looked up with ``get_acceptance_rate_indices`` and the per-group
    costs, weighted by the group proportions, are summed.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: A 4-tuple of: thresholds by attribute, (FPR, TPR) by
             attribute, the objective value at the selected acceptance
             rate (the *negated* proportion-weighted cost, exactly as
             returned by the inner objective) and the selected
             acceptance rate.
    :rtype: tuple

    """

    def total_cost_function(acceptance_rate_value):
        """Return the negated proportion-weighted cost over all groups."""
        # TODO: move demo here + multiple cost
        #       + refactor - use threshold to calculate
        #         acceptance_rate_value
        indices = get_acceptance_rate_indices(roc_curves, base_rates,
                                              acceptance_rate_value)

        total_cost = 0

        for group, roc in roc_curves.items():
            index = indices[group]

            # Each ROC entry is indexed as (fpr values, tpr values, ...).
            fpr = roc[0][index]
            tpr = roc[1][index]

            group_cost = _cost_function(fpr, tpr,
                                        base_rates[group],
                                        cost_matrix)

            # Weight the group's cost by its share of the population.
            group_cost *= proportions[group]

            total_cost += group_cost

        # NOTE(review): the sign is flipped before handing the objective
        # to the ternary search - presumably because the search looks
        # for a maximum; confirm against _ternary_search_float.
        return -total_cost

    acceptance_rate_min_cost = _ternary_search_float(total_cost_function,
                                                     0, 1, TRINARY_SEARCH_TOL)

    # Objective value (negated cost) at the selected acceptance rate.
    cost = total_cost_function(acceptance_rate_min_cost)

    threshold_indices = get_acceptance_rate_indices(roc_curves, base_rates,
                                                    acceptance_rate_min_cost)

    # NOTE(review): ``thresholds`` is indexed by position only, not by
    # group - presumably all ROC curves share one threshold grid; verify
    # against _extract_threshold.
    thresholds = _extract_threshold(roc_curves)

    cutoffs = {group: thresholds[threshold_index]
               for group, threshold_index
               in threshold_indices.items()}

    fpr_tpr = {group: (roc[0][threshold_indices[group]],
                       roc[1][threshold_indices[group]])
               for group, roc in roc_curves.items()}

    return cutoffs, fpr_tpr, cost, acceptance_rate_min_cost
326
327
328
def get_fnr_indices(roc_curves, fnr_value):
@@ 343-403 (lines=61) @@
340
    return indices
341
342
343
def find_fnr_thresholds(roc_curves, base_rates, proportions,
                        cost_matrix):
    """Compute thresholds that achieve equal FNRs and minimize cost.

    Also known as **equal opportunity**.

    A single FNR value is searched for in the interval [0, 1]; for
    every candidate value the matching ROC index of each group is
    looked up with ``get_fnr_indices`` and the per-group costs,
    weighted by the group proportions, are summed.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: A 4-tuple of: thresholds by attribute, (FPR, TPR) by
             attribute, the objective value at the selected FNR (the
             *negated* proportion-weighted cost, exactly as returned by
             the inner objective) and the selected FNR value.
    :rtype: tuple

    """

    def total_cost_function(fnr_value):
        """Return the negated proportion-weighted cost over all groups."""
        # TODO: move demo here + multiple cost
        indices = get_fnr_indices(roc_curves, fnr_value)

        total_cost = 0

        for group, roc in roc_curves.items():
            index = indices[group]

            # Each ROC entry is indexed as (fpr values, tpr values, ...).
            fpr = roc[0][index]
            tpr = roc[1][index]

            group_cost = _cost_function(fpr, tpr,
                                        base_rates[group],
                                        cost_matrix)

            # Weight the group's cost by its share of the population.
            group_cost *= proportions[group]

            total_cost += group_cost

        # NOTE(review): the sign is flipped before handing the objective
        # to the ternary search - presumably because the search looks
        # for a maximum; confirm against _ternary_search_float.
        return -total_cost

    fnr_value_min_cost = _ternary_search_float(total_cost_function,
                                               0, 1,
                                               TRINARY_SEARCH_TOL)

    threshold_indices = get_fnr_indices(roc_curves, fnr_value_min_cost)

    # Objective value (negated cost) at the selected FNR.
    cost = total_cost_function(fnr_value_min_cost)

    fpr_tpr = {group: (roc[0][threshold_indices[group]],
                       roc[1][threshold_indices[group]])
               for group, roc in roc_curves.items()}

    # NOTE(review): ``thresholds`` is indexed by position only, not by
    # group - presumably all ROC curves share one threshold grid; verify
    # against _extract_threshold.
    thresholds = _extract_threshold(roc_curves)
    cutoffs = {group: thresholds[threshold_index]
               for group, threshold_index
               in threshold_indices.items()}

    return cutoffs, fpr_tpr, cost, fnr_value_min_cost
404
405
406
def _find_feasible_roc(roc_curves):