1
|
|
|
import copy |
2
|
|
|
|
3
|
|
|
from .core import BiasWordsEmbedding |
4
|
|
|
from .data import BOLUKBASI_DATA |
5
|
|
|
from .utils import generate_one_word_forms, generate_words_forms |
6
|
|
|
|
7
|
|
|
|
8
|
|
|
class GenderBiasWE(BiasWordsEmbedding):
    """Gender-specific bias measurement and debiasing of a word embedding.

    Specializes :class:`BiasWordsEmbedding` with the gender word lists stored
    under ``BOLUKBASI_DATA['gender']``, and identifies the gender direction
    from the ``'he'``/``'she'`` pair together with the definitional pairs
    using the ``'pca'`` method.

    :param model: Word-embedding model, passed unchanged to
                  :class:`BiasWordsEmbedding` (its expected type is defined
                  by the base class).
    :param bool only_lower: Passed to the base class. When False, the
                            specific word list and the default equality
                            pairs are also expanded to other word forms via
                            ``generate_words_forms`` /
                            ``generate_one_word_forms``.
    :param bool verbose: Passed to the base class.
    """

    def __init__(self, model, only_lower=True, verbose=False):
        super().__init__(model, only_lower, verbose)
        # Word lists must be ready before the direction is identified,
        # because the definitional pairs are read from self._data.
        self._initialize_data()
        self._identify_direction('he', 'she',
                                 self._data['definitional_pairs'],
                                 'pca')

    def _initialize_data(self):
        """Load, expand and filter the gender word lists into ``self._data``.

        ``BOLUKBASI_DATA['gender']`` is deep-copied first because this method
        reassigns its entries and appends to ``word_group_keys``; doing that
        on the shared module-level object would leak into other instances.
        """
        self._data = copy.deepcopy(BOLUKBASI_DATA['gender'])

        if not self.only_lower:
            # Also cover the other surface forms of the specific words
            # (which forms exactly is decided by generate_words_forms).
            self._data['specific_full_with_definitional'] = \
                generate_words_forms(
                    self._data['specific_full_with_definitional'])

        # Filter every word group through the model; presumably this drops
        # words missing from the embedding's vocabulary (see base class).
        for key in self._data['word_group_keys']:
            self._data[key] = self._filter_words_by_model(self._data[key])

        self._data['neutral_words'] = self._extract_neutral_words(
            self._data['specific_full_with_definitional'])
        # Register the derived group so it is plotted alongside the others.
        self._data['word_group_keys'].append('neutral_words')

    def plot_dist_projections_on_direction(self, word_groups='bolukbasi',
                                           ax=None):
        """Plot the projections of word groups on the gender direction.

        :param word_groups: Word groups in the form the base-class method
                            expects, or the string ``'bolukbasi'`` (default)
                            to use a dict of every group named in
                            ``self._data['word_group_keys']``.
        :param ax: Forwarded unchanged to the base-class method.
        :return: Whatever the base-class method returns.
        """
        if isinstance(word_groups, str) and word_groups == 'bolukbasi':
            word_groups = {key: self._data[key]
                           for key in self._data['word_group_keys']}

        return super().plot_dist_projections_on_direction(word_groups, ax)

    def calc_direct_bias(self, neutral_words='professions', c=None):
        """Calculate the direct bias of ``neutral_words`` via the base class.

        :param neutral_words: The string ``'professions'`` (default) selects
                              ``self._data['neutral_profession_names']``;
                              any other value is forwarded unchanged.
        :param c: Forwarded to the base-class method for every choice of
                  ``neutral_words``.
        :return: The base-class result.
        """
        if isinstance(neutral_words, str) and neutral_words == 'professions':
            neutral_words = self._data['neutral_profession_names']

        # Forward ``c`` in all cases; it used to be dropped for custom lists.
        return super().calc_direct_bias(neutral_words, c)

    def debias(self, method='hard', neutral_words=None, equality_sets=None,
               inplace=True):
        """Debias via the base-class ``debias``, filling in gender defaults.

        :param str method: Forwarded. ``'hard'`` and ``'neutralize'`` get a
                           default for ``neutral_words``; ``'hard'`` also
                           gets a default for ``equality_sets``.
        :param neutral_words: Defaults to ``self._data['neutral_words']``
                              for the ``'hard'`` and ``'neutralize'`` methods.
        :param equality_sets: For ``'hard'`` only, defaults to
                              ``self._data['definitional_pairs']``. When
                              ``only_lower`` is False those default pairs are
                              expanded to aligned word forms of both words.
                              User-supplied sets are forwarded as given.
        :param bool inplace: Forwarded unchanged.
        :return: The base-class result.
        :raises AssertionError: If ``only_lower`` is False and a default
                                definitional set is not a pair.
        """
        # pylint: disable=C0301

        if method in ['hard', 'neutralize'] and neutral_words is None:
            neutral_words = self._data['neutral_words']

        if method == 'hard' and equality_sets is None:
            equality_sets = self._data['definitional_pairs']

            # Expansion is applied to the (non-None) default pairs only;
            # running it for 'neutralize' would iterate over None.
            if not self.only_lower:
                equality_sets = self._expand_equality_pairs(equality_sets)

        return super().debias(method, neutral_words, equality_sets,
                              inplace)

    @staticmethod
    def _expand_equality_pairs(equality_pairs):
        """Return a set of aligned word-form pairs for each (word1, word2)."""
        assert all(len(equality_set) == 2
                   for equality_set in equality_pairs), 'currently supporting only equality pairs if only_lower is False'
        # zip pairs the i-th form of word1 with the i-th form of word2;
        # assumes generate_one_word_forms yields forms in a fixed order.
        return {(candidate1, candidate2)
                for word1, word2 in equality_pairs
                for candidate1, candidate2 in zip(generate_one_word_forms(word1),
                                                  generate_one_word_forms(word2))}

    def learn_full_specific_words(self, seed_specific_words='bolukbasi',
                                  max_non_specific_examples=None,
                                  debug=None):
        """Learn the full gender-specific word list via the base class.

        :param seed_specific_words: Seed words, or the string ``'bolukbasi'``
                                    (default) to use
                                    ``self._data['specific_seed']``.
        :param max_non_specific_examples: Forwarded unchanged.
        :param debug: Forwarded unchanged.
        :return: The base-class result.
        """
        if (isinstance(seed_specific_words, str)
                and seed_specific_words == 'bolukbasi'):
            seed_specific_words = self._data['specific_seed']

        return super().learn_full_specific_words(seed_specific_words,
                                                 max_non_specific_examples,
                                                 debug)
78
|
|
|
|