1
|
|
|
import copy |
2
|
|
|
|
3
|
|
|
from .core import BiasWordsEmbedding |
4
|
|
|
from .data import BOLUKBASI_DATA |
5
|
|
|
from .utils import generate_one_word_forms, generate_words_forms |
6
|
|
|
|
7
|
|
|
|
8
|
|
|
class GenderBiasWE(BiasWordsEmbedding):
    """Audit and adjust the gender bias in an English words embedding.

    :param model: Words embedding model of ``gensim.model.KeyedVectors``
    :param bool only_lower: Whether the words embedding contains
                            only lower case words
    :param bool verbose: Set verbosity
    :param bool identify_direction: Whether to identify the gender
                                    direction on construction
    """

    def __init__(self, model, only_lower=False, verbose=False,
                 identify_direction=True):
        super().__init__(model, only_lower, verbose)
        self._initialize_data()
        if identify_direction:
            # The gender direction is anchored on the she-he axis and
            # refined with PCA over the definitional pairs
            # (Bolukbasi et al., 2016).
            self._identify_direction('she', 'he',
                                     self._data['definitional_pairs'],
                                     'pca')

    def _initialize_data(self):
        """Load Bolukbasi's gender word lists and adapt them to the model.

        Populates ``self._data`` with word groups filtered to the model
        vocabulary, plus a sorted ``neutral_words`` group.
        """
        # Deep copy so the per-instance mutations below never leak back
        # into the shared module-level constant.
        self._data = copy.deepcopy(BOLUKBASI_DATA['gender'])

        if not self.only_lower:
            # The embedding may contain capitalized/inflected variants;
            # expand each word to all of its forms.
            self._data['specific_full_with_definitional'] = \
                generate_words_forms(
                    self._data['specific_full_with_definitional'])  # pylint: disable=C0301

        # Keep only words that actually exist in the model vocabulary.
        for key in self._data['word_group_keys']:
            self._data[key] = self._filter_words_by_model(self._data[key])

        self._data['neutral_words'] = self._extract_neutral_words(
            self._data['specific_full_with_definitional'])  # pylint: disable=C0301
        self._data['neutral_words'].sort()
        self._data['word_group_keys'].append('neutral_words')

    def plot_projection_scores(self, words='professions', n_extreme=10,
                               ax=None, axis_projection_step=None):
        """Plot the projection of words on the gender direction.

        ``words`` defaults to Bolukbasi's profession names.
        """
        if words == 'professions':
            words = self._data['profession_names']

        return super().plot_projection_scores(words, n_extreme,
                                              ax, axis_projection_step)

    def plot_dist_projections_on_direction(self, word_groups='bolukbasi',
                                           ax=None):
        """Plot the distribution of projections of the word groups.

        ``word_groups`` defaults to all of Bolukbasi's word groups.
        """
        if word_groups == 'bolukbasi':
            word_groups = {key: self._data[key]
                           for key in self._data['word_group_keys']}

        return super().plot_dist_projections_on_direction(word_groups, ax)

    @classmethod
    def plot_bias_across_words_embeddings(cls, words_embedding_bias_dict,
                                          ax=None, scatter_kwargs=None):
        """Plot the direct bias of multiple words embeddings.

        Uses Bolukbasi's neutral profession names as the word set.
        """
        # pylint: disable=W0221
        words = BOLUKBASI_DATA['gender']['neutral_profession_names']
        # TODO: is it correct for inheritance of class method?
        super(cls, cls).plot_bias_across_words_embeddings(words_embedding_bias_dict,  # pylint: disable=C0301
                                                          words,
                                                          ax,
                                                          scatter_kwargs)

    def calc_direct_bias(self, neutral_words='professions', c=None):
        """Calculate the direct gender bias of a group of neutral words.

        ``neutral_words`` defaults to Bolukbasi's neutral profession
        names; ``c`` is forwarded to the base-class strictness parameter.
        """
        if isinstance(neutral_words, str) and neutral_words == 'professions':
            neutral_words = self._data['neutral_profession_names']

        # BUGFIX: `c` used to be dropped when custom neutral words were
        # given; it is now forwarded on both paths (default c=None keeps
        # the previous behavior).
        return super().calc_direct_bias(neutral_words, c)

    def generate_closest_words_indirect_bias(self,
                                             neutral_positive_end,
                                             neutral_negative_end,
                                             words='professions', n_extreme=5):
        """Tabulate indirect bias of the words closest to the two ends.

        ``words`` defaults to Bolukbasi's profession names.
        """
        # pylint: disable=C0301
        if words == 'professions':
            words = self._data['profession_names']

        return super().generate_closest_words_indirect_bias(neutral_positive_end,
                                                            neutral_negative_end,
                                                            words,
                                                            n_extreme=n_extreme)

    def debias(self, method='hard', neutral_words=None, equality_sets=None,
               inplace=True):
        """Debias the words embedding (``'hard'`` or ``'neutralize'``).

        Defaults the neutral words and equality sets to Bolukbasi's data
        when the caller does not provide them.
        """
        # pylint: disable=C0301
        if method in ['hard', 'neutralize']:
            if neutral_words is None:
                neutral_words = self._data['neutral_words']

        if method == 'hard' and equality_sets is None:
            equality_sets = self._data['definitional_pairs']

            if not self.only_lower:
                assert all(len(equality_set) == 2
                           for equality_set in equality_sets), 'currently supporting only equality pairs if only_lower is False'
                # TODO: refactor
                # Expand each definitional pair into all case-variant
                # pairs (e.g. she/he, She/He, SHE/HE).
                equality_sets = {(candidate1, candidate2)
                                 for word1, word2 in equality_sets
                                 for candidate1, candidate2 in zip(generate_one_word_forms(word1),
                                                                   generate_one_word_forms(word2))}

        return super().debias(method, neutral_words, equality_sets,
                              inplace)

    def learn_full_specific_words(self, seed_specific_words='bolukbasi',
                                  max_non_specific_examples=None,
                                  debug=None):
        """Learn the full set of gender-specific words from seed words.

        ``seed_specific_words`` defaults to Bolukbasi's specific seed set.
        """
        if seed_specific_words == 'bolukbasi':
            seed_specific_words = self._data['specific_seed']

        return super().learn_full_specific_words(seed_specific_words,
                                                 max_non_specific_examples,
                                                 debug)