@@ 296-354 (lines=59) @@ | ||
293 | return indices |
|
294 | ||
295 | ||
def find_fnr_thresholds(roc_curves, base_rates, proportions,
                        cost_matrix):
    """Compute thresholds that achieve equal FNRs and minimize cost.

    Also known as **equal opportunity**.

    A single FNR value shared by all groups is searched for on the
    interval ``[0, 1]`` with ``_ternary_search_float`` (tolerance
    ``1e-3``). For every candidate value, ``get_fnr_indices`` picks
    one ROC index per group, and the objective is the negated,
    proportion-weighted sum of the per-group ``_cost_function`` values.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute. Element ``0`` of each curve is
                       read as the FPR values and element ``1`` as
                       the TPR values.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: A 4-tuple of: thresholds by attribute, ``(FPR, TPR)``
             by attribute, the objective value at the selected FNR
             (i.e. the *negated* weighted total cost), and the
             selected FNR value.
    :rtype: tuple

    """

    def total_cost_function(fnr_value):
        # TODO: move demo here + multiple cost
        indices = get_fnr_indices(roc_curves, fnr_value)

        total_cost = 0

        for group, roc in roc_curves.items():
            index = indices[group]

            fpr = roc[0][index]
            tpr = roc[1][index]

            # Weight each group's cost by its share of the population.
            total_cost += (_cost_function(fpr, tpr,
                                          base_rates[group],
                                          cost_matrix)
                           * proportions[group])

        # The sign is flipped for the search routine.
        # NOTE(review): whether `_ternary_search_float` minimizes or
        # maximizes its argument is not visible here - confirm.
        return -total_cost

    fnr_value_min_cost = _ternary_search_float(total_cost_function,
                                               0, 1, 1e-3)
    threshold_indices = get_fnr_indices(roc_curves, fnr_value_min_cost)

    # Objective value (negated weighted cost) at the selected FNR.
    cost = total_cost_function(fnr_value_min_cost)

    fpr_tpr = {group: (roc[0][threshold_indices[group]],
                       roc[1][threshold_indices[group]])
               for group, roc in roc_curves.items()}

    # One threshold sequence, indexed with each group's own index.
    thresholds = _extract_threshold(roc_curves)
    cutoffs = {group: thresholds[threshold_index]
               for group, threshold_index
               in threshold_indices.items()}

    return cutoffs, fpr_tpr, cost, fnr_value_min_cost
|
355 | ||
356 | ||
357 | def _find_feasible_roc(roc_curves): |
|
@@ 224-281 (lines=58) @@ | ||
221 | return indices |
|
222 | ||
223 | ||
def find_independence_thresholds(roc_curves, base_rates, proportions,
                                 cost_matrix):
    """Compute thresholds that achieve independence and minimize cost.

    A single acceptance rate shared by all groups is searched for on
    the interval ``[0, 1]`` with ``_ternary_search_float`` (tolerance
    ``1e-3``). For every candidate value,
    ``get_acceptance_rate_indices`` picks one ROC index per group, and
    the objective is the negated, proportion-weighted sum of the
    per-group ``_cost_function`` values.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute. Element ``0`` of each curve is
                       read as the FPR values and element ``1`` as
                       the TPR values.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: A 3-tuple of: thresholds by attribute, ``(FPR, TPR)``
             by attribute, and the selected acceptance rate. The
             cost itself is not returned.
    :rtype: tuple

    """

    def total_cost_function(acceptance_rate_value):
        # TODO: move demo here + multiple cost
        indices = get_acceptance_rate_indices(roc_curves, base_rates,
                                              acceptance_rate_value)

        total_cost = 0

        for group, roc in roc_curves.items():
            index = indices[group]

            fpr = roc[0][index]
            tpr = roc[1][index]

            # Weight each group's cost by its share of the population.
            total_cost += (_cost_function(fpr, tpr,
                                          base_rates[group],
                                          cost_matrix)
                           * proportions[group])

        # The sign is flipped for the search routine.
        # NOTE(review): whether `_ternary_search_float` minimizes or
        # maximizes its argument is not visible here - confirm.
        return -total_cost

    acceptance_rate_min_cost = _ternary_search_float(total_cost_function,
                                                     0, 1, 1e-3)
    threshold_indices = get_acceptance_rate_indices(
        roc_curves, base_rates, acceptance_rate_min_cost)

    # One threshold sequence, indexed with each group's own index.
    thresholds = _extract_threshold(roc_curves)

    cutoffs = {group: thresholds[threshold_index]
               for group, threshold_index
               in threshold_indices.items()}

    fpr_tpr = {group: (roc[0][threshold_indices[group]],
                       roc[1][threshold_indices[group]])
               for group, roc in roc_curves.items()}

    return cutoffs, fpr_tpr, acceptance_rate_min_cost
|
282 | ||
283 | ||
284 | def get_fnr_indices(roc_curves, fnr_value): |