Code Duplication    Length = 58-59 lines in 2 locations

ethically/fairness/interventions/threshold.py 2 locations

@@ 295-353 (lines=59) @@
292
    return indices
293
294
295
def find_fnr_thresholds(roc_curves, base_rates, proportions,
                        cost_matrix):
    """Compute thresholds that achieve equal FNRs and minimize cost.

    Also known as **equal opportunity**.

    One false-negative rate (FNR) value is shared by all groups. For a
    candidate value, :func:`get_fnr_indices` picks an index into every
    group's ROC curve. The objective is the *negated* sum over groups of
    :func:`_cost_function` at those indices, each weighted by the
    group's proportion. The objective is passed to
    :func:`_ternary_search_float` with the arguments ``0, 1, 1e-3``
    (presumably the search bounds and tolerance — confirm against that
    helper).

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: Thresholds by attribute, (FPR, TPR) by attribute,
             the objective value at the selected FNR (the negated
             proportion-weighted cost sum), and the selected FNR value.
    :rtype: tuple
    """

    def objective_at(indices):
        # Negated, proportion-weighted sum of per-group costs at the given
        # per-group ROC indices (roc[0] is read as FPR, roc[1] as TPR).
        total_cost = 0

        for group, roc in roc_curves.items():
            index = indices[group]
            group_cost = _cost_function(roc[0][index], roc[1][index],
                                        base_rates[group],
                                        cost_matrix)
            total_cost += group_cost * proportions[group]

        return -total_cost

    def total_cost_function(fnr_value):
        # TODO: move demo here + multiple cost
        return objective_at(get_fnr_indices(roc_curves, fnr_value))

    fnr_value_min_cost = _ternary_search_float(total_cost_function,
                                               0, 1, 1e-3)

    # Resolve the selected FNR to per-group ROC indices once, and reuse
    # them for the reported cost, FPR/TPR pairs and thresholds.
    threshold_indices = get_fnr_indices(roc_curves, fnr_value_min_cost)

    cost = objective_at(threshold_indices)

    fpr_tpr = {group: (roc[0][threshold_indices[group]],
                       roc[1][threshold_indices[group]])
               for group, roc in roc_curves.items()}

    # A single threshold sequence is indexed with every group's index.
    # NOTE(review): presumably all groups' ROC curves share one threshold
    # grid — confirm against _extract_threshold.
    thresholds = _extract_threshold(roc_curves)
    cutoffs = {group: thresholds[threshold_index]
               for group, threshold_index in threshold_indices.items()}

    return cutoffs, fpr_tpr, cost, fnr_value_min_cost
354
355
356
def _find_feasible_roc(roc_curves):
@@ 223-280 (lines=58) @@
220
    return indices
221
222
223
def find_independence_thresholds(roc_curves, base_rates, proportions,
                                 cost_matrix):
    """Compute thresholds that achieve independence and minimize cost.

    One acceptance-rate value is shared by all groups. For a candidate
    value, :func:`get_acceptance_rate_indices` picks an index into every
    group's ROC curve. The objective is the *negated* sum over groups of
    :func:`_cost_function` at those indices, each weighted by the
    group's proportion. The objective is passed to
    :func:`_ternary_search_float` with the arguments ``0, 1, 1e-3``
    (presumably the search bounds and tolerance — confirm against that
    helper).

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: Thresholds by attribute, (FPR, TPR) by attribute, and the
             selected acceptance rate. No cost value is returned.
    :rtype: tuple
    """

    def objective_at(indices):
        # Negated, proportion-weighted sum of per-group costs at the given
        # per-group ROC indices (roc[0] is read as FPR, roc[1] as TPR).
        total_cost = 0

        for group, roc in roc_curves.items():
            index = indices[group]
            group_cost = _cost_function(roc[0][index], roc[1][index],
                                        base_rates[group],
                                        cost_matrix)
            total_cost += group_cost * proportions[group]

        return -total_cost

    def total_cost_function(acceptance_rate_value):
        # TODO: move demo here + multiple cost
        return objective_at(
            get_acceptance_rate_indices(roc_curves, base_rates,
                                        acceptance_rate_value))

    acceptance_rate_min_cost = _ternary_search_float(total_cost_function,
                                                     0, 1, 1e-3)
    threshold_indices = get_acceptance_rate_indices(roc_curves, base_rates,
                                                    acceptance_rate_min_cost)

    # A single threshold sequence is indexed with every group's index.
    # NOTE(review): presumably all groups' ROC curves share one threshold
    # grid — confirm against _extract_threshold.
    thresholds = _extract_threshold(roc_curves)
    cutoffs = {group: thresholds[threshold_index]
               for group, threshold_index in threshold_indices.items()}

    fpr_tpr = {group: (roc[0][threshold_indices[group]],
                       roc[1][threshold_indices[group]])
               for group, roc in roc_curves.items()}

    return cutoffs, fpr_tpr, acceptance_rate_min_cost
281
282
283
def get_fnr_indices(roc_curves, fnr_value):