|
1
|
|
|
import copy |
|
2
|
|
|
|
|
3
|
|
|
from .core import BiasWordsEmbedding |
|
4
|
|
|
from .data import BOLUKBASI_DATA |
|
5
|
|
|
from .utils import generate_one_word_forms, generate_words_forms |
|
6
|
|
|
|
|
7
|
|
|
|
|
8
|
|
|
class GenderBiasWE(BiasWordsEmbedding):
    """Gender-bias measurement and debiasing for a word embedding.

    Specializes :class:`BiasWordsEmbedding` with the gender word lists taken
    from ``BOLUKBASI_DATA['gender']``. On construction the gender direction
    is identified from the ``'he'``/``'she'`` pair together with the
    definitional pairs, using the ``'pca'`` method.
    """

    def __init__(self, model, only_lower=True, verbose=False):
        """Wrap *model* and identify its gender direction.

        :param model: Word embedding model, forwarded unchanged to
                      :class:`BiasWordsEmbedding`.
        :param bool only_lower: Forwarded to the base class. When False,
                                additional word forms are generated for the
                                gender-specific words and for the default
                                equality pairs used by :meth:`debias`.
        :param bool verbose: Forwarded to the base class.
        """
        super().__init__(model, only_lower, verbose)

        self._initialize_data()

        self._identify_direction('he', 'she',
                                 self._data['definitional_pairs'],
                                 'pca')

    def _initialize_data(self):
        """Populate ``self._data`` from the Bolukbasi gender data.

        The module-level ``BOLUKBASI_DATA['gender']`` dict is deep-copied,
        so the in-place modifications below (list replacement and the
        ``append`` to ``word_group_keys``) never leak into the shared data.
        """
        self._data = copy.deepcopy(BOLUKBASI_DATA['gender'])

        if not self.only_lower:
            # Expand the specific words with the extra word forms produced
            # by generate_words_forms (exact forms depend on that helper).
            self._data['specific_full_with_definitional'] = \
                generate_words_forms(
                    self._data['specific_full_with_definitional'])

        # Filter every word group against the model; presumably this drops
        # words missing from the model's vocabulary -- TODO confirm.
        for key in self._data['word_group_keys']:
            self._data[key] = self._filter_words_by_model(self._data[key])

        # Derive the neutral words from the (possibly expanded) specific
        # words and register them as one more word group.
        self._data['neutral_words'] = self._extract_neutral_words(
            self._data['specific_full_with_definitional'])
        self._data['word_group_keys'].append('neutral_words')

    def plot_dist_projections_on_direction(self, word_groups='bolukbasi',
                                           ax=None):
        """Plot the projections of word groups on the gender direction.

        :param word_groups: The string ``'bolukbasi'`` selects every group
                            listed in ``self._data['word_group_keys']``;
                            any other value is forwarded unchanged to the
                            base implementation.
        :param ax: Forwarded unchanged to the base implementation.
        :return: Whatever the base implementation returns.
        """
        # isinstance guard: comparing a non-string (e.g. an array-like)
        # with ``==`` could produce an element-wise result instead of bool.
        if isinstance(word_groups, str) and word_groups == 'bolukbasi':
            word_groups = {key: self._data[key]
                           for key in self._data['word_group_keys']}

        return super().plot_dist_projections_on_direction(word_groups, ax)

    def calc_direct_bias(self, neutral_words='professions', c=None):
        """Compute the direct bias over a set of neutral words.

        :param neutral_words: The string ``'professions'`` selects
                              ``self._data['neutral_profession_names']``;
                              any other value is forwarded unchanged.
        :param c: Forwarded to the base implementation in every case.
        :return: Whatever the base implementation returns.
        """
        if isinstance(neutral_words, str) and neutral_words == 'professions':
            neutral_words = self._data['neutral_profession_names']

        # Forward ``c`` on both paths so a caller-supplied value is not
        # ignored when custom neutral words are given.
        return super().calc_direct_bias(neutral_words, c)

    def debias(self, method='hard', neutral_words=None, equality_sets=None,
               inplace=True):
        """Debias the embedding, filling in the gender defaults.

        :param str method: Debiasing method, forwarded to the base class.
        :param neutral_words: For ``'hard'`` and ``'neutralize'``, ``None``
                              selects ``self._data['neutral_words']``.
        :param equality_sets: For ``'hard'``, ``None`` selects
                              ``self._data['definitional_pairs']``; when
                              ``only_lower`` is False those default pairs
                              are expanded to paired word forms.
                              Caller-supplied sets are forwarded unchanged.
        :param bool inplace: Forwarded to the base class.
        :return: Whatever the base implementation returns.
        """
        if method in ('hard', 'neutralize') and neutral_words is None:
            neutral_words = self._data['neutral_words']

        if method == 'hard' and equality_sets is None:
            equality_sets = self._data['definitional_pairs']

            if not self.only_lower:
                # The expansion below zips the forms of exactly two words.
                assert all(len(equality_set) == 2
                           for equality_set in equality_sets), \
                    ('currently supporting only equality pairs '
                     'if only_lower is False')
                # TODO: refactor
                # Pair up the i-th generated form of each word of a pair;
                # a set removes duplicated pairs.
                equality_sets = {
                    (candidate1, candidate2)
                    for word1, word2 in equality_sets
                    for candidate1, candidate2 in zip(
                        generate_one_word_forms(word1),
                        generate_one_word_forms(word2))}

        return super().debias(method, neutral_words, equality_sets, inplace)

    def learn_full_specific_words(self, seed_specific_words='bolukbasi',
                                  max_non_specific_examples=None,
                                  debug=None):
        """Learn the full set of gender-specific words from a seed list.

        :param seed_specific_words: The string ``'bolukbasi'`` selects
                                    ``self._data['specific_seed']``; any
                                    other value is forwarded unchanged.
        :param max_non_specific_examples: Forwarded to the base class.
        :param debug: Forwarded to the base class.
        :return: Whatever the base implementation returns.
        """
        if (isinstance(seed_specific_words, str)
                and seed_specific_words == 'bolukbasi'):
            seed_specific_words = self._data['specific_seed']

        return super().learn_full_specific_words(seed_specific_words,
                                                 max_non_specific_examples,
                                                 debug)
|
78
|
|
|
|