@@ 295-353 (lines=59) @@ | ||
292 | return indices |
|
293 | ||
294 | ||
def find_fnr_thresholds(roc_curves, base_rates, proportions,
                        cost_matrix):
    """Compute thresholds that achieve equal FNRs and minimize cost.

    Also known as **equal opportunity**.

    A single FNR value shared by every group is searched for in
    ``[0, 1]``. For a candidate value, each group's ROC index is looked
    up with ``get_fnr_indices``, and the per-group costs
    (``_cost_function``) are weighted by the group proportions and
    summed.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute; ``roc[0]`` is read as FPR and
                       ``roc[1]`` as TPR.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: A 4-tuple of
             (thresholds by attribute,
             (FPR, TPR) by attribute,
             objective value at the chosen point,
             the shared FNR value).
             NOTE: the objective value is the *negated*
             proportion-weighted cost, i.e. larger means cheaper.
    :rtype: tuple

    """

    def negated_cost_at(indices):
        """Return minus the proportion-weighted cost at per-group indices."""
        total_cost = 0

        for group, roc in roc_curves.items():
            index = indices[group]

            fpr = roc[0][index]
            tpr = roc[1][index]

            group_cost = _cost_function(fpr, tpr,
                                        base_rates[group],
                                        cost_matrix)
            group_cost *= proportions[group]

            total_cost += group_cost

        # Negated because _ternary_search_float is presumably a maximizer,
        # so maximizing -cost minimizes cost (confirm against its definition).
        return -total_cost

    def total_cost_function(fnr_value):
        # TODO: move demo here + multiple cost
        return negated_cost_at(get_fnr_indices(roc_curves, fnr_value))

    fnr_value_min_cost = _ternary_search_float(total_cost_function,
                                               0, 1, 1e-3)
    threshold_indices = get_fnr_indices(roc_curves, fnr_value_min_cost)

    # Reuse the indices just computed rather than re-deriving them through
    # total_cost_function (which would call get_fnr_indices a second time).
    cost = negated_cost_at(threshold_indices)

    fpr_tpr = {group: (roc[0][threshold_indices[group]],
                       roc[1][threshold_indices[group]])
               for group, roc in roc_curves.items()}

    # The same threshold sequence is indexed for every group; presumably
    # all ROC curves share one threshold grid -- verify _extract_threshold.
    thresholds = _extract_threshold(roc_curves)
    cutoffs = {group: thresholds[threshold_index]
               for group, threshold_index
               in threshold_indices.items()}

    return cutoffs, fpr_tpr, cost, fnr_value_min_cost
|
354 | ||
355 | ||
356 | def _find_feasible_roc(roc_curves): |
|
@@ 223-280 (lines=58) @@ | ||
220 | return indices |
|
221 | ||
222 | ||
def find_independence_thresholds(roc_curves, base_rates, proportions,
                                 cost_matrix):
    """Compute thresholds that achieve independence and minimize cost.

    A single acceptance rate shared by every group is searched for in
    ``[0, 1]``. For a candidate rate, each group's ROC index is looked
    up with ``get_acceptance_rate_indices``, and the per-group costs
    (``_cost_function``) are weighted by the group proportions and
    summed.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute; ``roc[0]`` is read as FPR and
                       ``roc[1]`` as TPR.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: A 3-tuple of
             (thresholds by attribute,
             (FPR, TPR) by attribute,
             the shared acceptance rate that was selected).
             NOTE(review): unlike ``find_fnr_thresholds``, no cost value
             is returned; the shape is kept as-is for existing callers.
    :rtype: tuple

    """

    def total_cost_function(acceptance_rate_value):
        # TODO: move demo here + multiple cost
        indices = get_acceptance_rate_indices(roc_curves, base_rates,
                                              acceptance_rate_value)

        total_cost = 0

        for group, roc in roc_curves.items():
            index = indices[group]

            fpr = roc[0][index]
            tpr = roc[1][index]

            group_cost = _cost_function(fpr, tpr,
                                        base_rates[group],
                                        cost_matrix)
            group_cost *= proportions[group]

            total_cost += group_cost

        # Negated because _ternary_search_float is presumably a maximizer,
        # so maximizing -cost minimizes cost (confirm against its definition).
        return -total_cost

    acceptance_rate_min_cost = _ternary_search_float(total_cost_function,
                                                     0, 1, 1e-3)
    threshold_indices = get_acceptance_rate_indices(roc_curves, base_rates,
                                                    acceptance_rate_min_cost)

    # The same threshold sequence is indexed for every group; presumably
    # all ROC curves share one threshold grid -- verify _extract_threshold.
    thresholds = _extract_threshold(roc_curves)

    cutoffs = {group: thresholds[threshold_index]
               for group, threshold_index
               in threshold_indices.items()}

    fpr_tpr = {group: (roc[0][threshold_indices[group]],
                       roc[1][threshold_indices[group]])
               for group, roc in roc_curves.items()}

    return cutoffs, fpr_tpr, acceptance_rate_min_cost
|
281 | ||
282 | ||
283 | def get_fnr_indices(roc_curves, fnr_value): |