1
|
|
|
import copy |
2
|
|
|
|
3
|
|
|
from .core import BiasWordsEmbedding |
4
|
|
|
from .data import BOLUKBASI_DATA |
5
|
|
|
from .utils import generate_one_word_forms, generate_words_forms |
6
|
|
|
|
7
|
|
|
|
8
|
|
|
class GenderBiasWE(BiasWordsEmbedding):
    """Gender-specific bias analysis and mitigation for a word embedding.

    Builds on :class:`BiasWordsEmbedding`. It supplies the gender word
    lists from ``BOLUKBASI_DATA['gender']`` as defaults for the generic
    operations: direct bias, debiasing and learning gender-specific words.

    :param model: Word embedding model, forwarded unchanged to
                  :class:`BiasWordsEmbedding`.
    :param bool only_lower: Forwarded to the base class. When False, the
                            bundled word lists are also expanded with
                            other word forms (see :meth:`_initialize_data`
                            and :meth:`debias`).
    """

    def __init__(self, model, only_lower=True):
        super().__init__(model, only_lower)
        self._initialize_data()
        # Identify the gender direction from the 'he'/'she' pair plus the
        # bundled definitional pairs, using the 'pca' method. The result is
        # presumably stored on the instance by the base class (confirm in
        # BiasWordsEmbedding._identify_direction).
        self._identify_direction('he', 'she',
                                 self._data['definitional_pairs'],
                                 'pca')

    def _initialize_data(self):
        """Load the gender word lists into ``self._data``.

        Steps:

        1. Deep-copy ``BOLUKBASI_DATA['gender']``.
        2. If ``only_lower`` is False, replace
           ``specific_full_with_definitional`` with the output of
           ``generate_words_forms``. That output is presumably the other
           casing variants of each word; confirm in ``utils``.
        3. Pass every non-empty list whose first element is a string
           through ``_filter_words_by_model``.
        4. Derive ``neutral_words`` from ``specific_full_with_definitional``
           via ``_extract_neutral_words``.
        """
        # Deep copy so the in-place changes below never mutate the
        # shared, module-level BOLUKBASI_DATA.
        self._data = copy.deepcopy(BOLUKBASI_DATA['gender'])

        if not self.only_lower:
            self._data['specific_full_with_definitional'] = \
                generate_words_forms(
                    self._data['specific_full_with_definitional'])

        # Only values that look like word lists are filtered. "Looks like"
        # is judged by the first element only. Reassigning existing keys
        # does not change the dict's size, so iterating it here is safe.
        for key in self._data:
            words = self._data[key]
            if (isinstance(words, list)
                    and words
                    and isinstance(words[0], str)):
                self._data[key] = self._filter_words_by_model(words)

        self._data['neutral_words'] = self._extract_neutral_words(
            self._data['specific_full_with_definitional'])

    def calc_direct_bias(self, neutral_words='professions', c=None):
        """Compute the direct bias of the embedding over neutral words.

        :param neutral_words: Either the string ``'professions'``
                              (default), which selects the bundled
                              ``neutral_profession_names`` list, or an
                              explicit collection of words.
        :param c: Forwarded to
                  :meth:`BiasWordsEmbedding.calc_direct_bias` in both
                  cases.
        :return: Whatever the base-class implementation returns.
        """
        if isinstance(neutral_words, str) and neutral_words == 'professions':
            neutral_words = self._data['neutral_profession_names']

        # `c` used to be dropped for caller-supplied word lists. It is now
        # forwarded the same way the 'professions' branch always did,
        # including when it is None.
        return super().calc_direct_bias(neutral_words, c)

    def debias(self, method='hard', neutral_words=None, equality_sets=None,
               inplace=True, verbose=False):
        """Debias the embedding, filling in gender-specific defaults.

        For ``method`` equal to ``'hard'`` or ``'neutralize'``:

        * ``neutral_words`` defaults to ``self._data['neutral_words']``.
        * For ``'hard'``, ``equality_sets`` defaults to the bundled
          definitional pairs.
        * If ``only_lower`` is False and equality sets are available, each
          pair is expanded into the corresponding pairs of word forms from
          ``generate_one_word_forms``. The result is a set of 2-tuples.

        Other methods are forwarded to the base class unchanged.

        :param method: Debiasing method name, forwarded to the base class.
        :param neutral_words: Words to neutralize, or None for the default.
        :param equality_sets: Equality sets, or None for the default.
        :param inplace: Forwarded unchanged to the base class.
        :param verbose: Forwarded unchanged to the base class.
        :return: Whatever :meth:`BiasWordsEmbedding.debias` returns.
        :raises AssertionError: If ``only_lower`` is False and some
                                equality set does not have exactly two
                                words.
        """
        # pylint: disable=C0301

        if method in ['hard', 'neutralize']:
            if neutral_words is None:
                neutral_words = self._data['neutral_words']

            if method == 'hard' and equality_sets is None:
                equality_sets = self._data['definitional_pairs']

            # For 'neutralize' without explicit sets, equality_sets is
            # still None here. Iterating it would raise TypeError, so only
            # expand when there is something to expand.
            if not self.only_lower and equality_sets is not None:
                assert all(len(equality_set) == 2
                           for equality_set in equality_sets), 'currently supporting only equality pairs if only_lower is False'
                # TODO: refactor
                # zip pairs the word forms of both words positionally and
                # stops at the shorter list of forms.
                equality_sets = {(candidate1, candidate2)
                                 for word1, word2 in equality_sets
                                 for candidate1, candidate2
                                 in zip(generate_one_word_forms(word1),
                                        generate_one_word_forms(word2))}

        return super().debias(method, neutral_words, equality_sets,
                              inplace, verbose)

    def learn_full_specific_words(self, seed_specific_words='bolukbasi',
                                  max_non_specific_examples=None,
                                  debug=None):
        """Learn the full set of gender-specific words from a seed list.

        :param seed_specific_words: Either the string ``'bolukbasi'``
                                    (default), which selects the bundled
                                    ``specific_seed`` list, or an explicit
                                    collection of seed words.
        :param max_non_specific_examples: Forwarded unchanged to the base
                                          class.
        :param debug: Forwarded unchanged to the base class.
        :return: Whatever
                 :meth:`BiasWordsEmbedding.learn_full_specific_words`
                 returns.
        """
        # Check the type first, as calc_direct_bias does, so that
        # array-like seeds are never compared element-wise against a str.
        if (isinstance(seed_specific_words, str)
                and seed_specific_words == 'bolukbasi'):
            seed_specific_words = self._data['specific_seed']

        return super().learn_full_specific_words(seed_specific_words,
                                                 max_non_specific_examples,
                                                 debug)
74
|
|
|
|