Code duplication: 58–59 duplicated lines in 2 locations

ethically/fairness/interventions/threshold.py (2 locations)

@@ 296-354 (lines=59) @@
293
    return indices
294
295
296
def find_fnr_thresholds(roc_curves, base_rates, proportions,
                        cost_matrix):
    """Compute thresholds that achieve equal FNRs and minimize cost.

    Also known as **equal opportunity**.

    A single FNR value shared by all groups is searched for in
    ``[0, 1]`` with :func:`_ternary_search_float`. Each group's
    operating point is then the ROC index that :func:`get_fnr_indices`
    returns for that FNR value.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute; ``roc[0]`` is read as the FPR
                       array and ``roc[1]`` as the TPR array.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: Thresholds by attribute, (FPR, TPR) by attribute,
             the search objective at the chosen FNR (the *negated*
             proportion-weighted sum of per-group costs), and the
             chosen FNR value.
    :rtype: tuple
    """

    def weighted_cost(indices):
        # Negated, proportion-weighted sum of each group's cost at the
        # ROC point selected by ``indices``; this is the objective that
        # is handed to _ternary_search_float.
        # NOTE(review): the same aggregation also appears in
        # find_independence_thresholds; consider a shared helper.
        total_cost = 0
        for group, roc in roc_curves.items():
            index = indices[group]
            group_cost = _cost_function(roc[0][index], roc[1][index],
                                        base_rates[group],
                                        cost_matrix)
            total_cost += group_cost * proportions[group]
        return -total_cost

    def total_cost_function(fnr_value):
        # TODO: move demo here + multiple cost
        return weighted_cost(get_fnr_indices(roc_curves, fnr_value))

    fnr_value_min_cost = _ternary_search_float(total_cost_function,
                                               0, 1, 1e-3)
    threshold_indices = get_fnr_indices(roc_curves, fnr_value_min_cost)

    # Reuse the indices already computed rather than calling
    # total_cost_function again, which would repeat get_fnr_indices.
    # NOTE(review): ``cost`` keeps the objective's (negated) sign;
    # confirm callers expect that.
    cost = weighted_cost(threshold_indices)

    fpr_tpr = {group: (roc[0][threshold_indices[group]],
                       roc[1][threshold_indices[group]])
               for group, roc in roc_curves.items()}

    thresholds = _extract_threshold(roc_curves)
    cutoffs = {group: thresholds[threshold_index]
               for group, threshold_index
               in threshold_indices.items()}

    return cutoffs, fpr_tpr, cost, fnr_value_min_cost
355
356
357
def _find_feasible_roc(roc_curves):
@@ 224-281 (lines=58) @@
221
    return indices
222
223
224
def find_independence_thresholds(roc_curves, base_rates, proportions,
                                 cost_matrix):
    """Compute thresholds that achieve independence and minimize cost.

    A single acceptance rate shared by all groups is searched for in
    ``[0, 1]`` with :func:`_ternary_search_float`. Each group's
    operating point is then the ROC index that
    :func:`get_acceptance_rate_indices` returns for that rate.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute; ``roc[0]`` is read as the FPR
                       array and ``roc[1]`` as the TPR array.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: Thresholds by attribute, (FPR, TPR) by attribute, and
             the acceptance rate selected by the cost search. No cost
             value is returned.
    :rtype: tuple
    """

    def weighted_cost(indices):
        # Negated, proportion-weighted sum of each group's cost at the
        # ROC point selected by ``indices``; this is the objective that
        # is handed to _ternary_search_float.
        # NOTE(review): the same aggregation also appears in
        # find_fnr_thresholds; consider a shared helper.
        total_cost = 0
        for group, roc in roc_curves.items():
            index = indices[group]
            group_cost = _cost_function(roc[0][index], roc[1][index],
                                        base_rates[group],
                                        cost_matrix)
            total_cost += group_cost * proportions[group]
        return -total_cost

    def total_cost_function(acceptance_rate_value):
        # TODO: move demo here + multiple cost
        return weighted_cost(
            get_acceptance_rate_indices(roc_curves, base_rates,
                                        acceptance_rate_value))

    acceptance_rate_min_cost = _ternary_search_float(total_cost_function,
                                                     0, 1, 1e-3)
    threshold_indices = get_acceptance_rate_indices(roc_curves, base_rates,
                                                    acceptance_rate_min_cost)

    thresholds = _extract_threshold(roc_curves)
    cutoffs = {group: thresholds[threshold_index]
               for group, threshold_index
               in threshold_indices.items()}

    fpr_tpr = {group: (roc[0][threshold_indices[group]],
                       roc[1][threshold_indices[group]])
               for group, roc in roc_curves.items()}

    return cutoffs, fpr_tpr, acceptance_rate_min_cost
282
283
284
def get_fnr_indices(roc_curves, fnr_value):