Passed
Push — master ( fe038f...170db5 )
by Shlomi
03:29 queued 01:43
created

ethically.we.bias   A

Complexity

Total Complexity 23

Size/Duplication

Total Lines 122
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 87
dl 0
loc 122
rs 10
c 0
b 0
f 0
wmc 23

9 Methods

Rating   Name   Duplication   Size   Complexity  
A GenderBiasWE.generate_closest_words_indirect_bias() 0 13 2
A GenderBiasWE.calc_direct_bias() 0 6 3
A GenderBiasWE._initialize_data() 0 16 3
B GenderBiasWE.debias() 0 21 6
A GenderBiasWE.plot_projection_scores() 0 7 2
A GenderBiasWE.plot_dist_projections_on_direction() 0 7 2
A GenderBiasWE.learn_full_specific_words() 0 9 2
A GenderBiasWE.plot_bias_across_words_embeddings() 0 10 1
A GenderBiasWE.__init__() 0 8 2
1
import copy
2
3
from .core import BiasWordsEmbedding
4
from .data import BOLUKBASI_DATA
5
from .utils import generate_one_word_forms, generate_words_forms
6
7
8
class GenderBiasWE(BiasWordsEmbedding):
    """Audit and Adjust the Gender Bias in English Words Embedding.

    :param model: Words embedding model of ``gensim.model.KeyedVectors``
    :param bool only_lower: Whether the words embedding contains
                            only lower case words
    :param bool verbose: Set verbosity
    :param bool identify_direction: Whether to identify the gender
                                    direction on construction
                                    (PCA over the definitional pairs)
    """

    def __init__(self, model, only_lower=False, verbose=False,
                 identify_direction=True):
        super().__init__(model, only_lower, verbose)
        self._initialize_data()

        if identify_direction:
            # 'she'/'he' fix the sign of the direction; the direction
            # itself is extracted by PCA over the definitional pairs.
            self._identify_direction('she', 'he',
                                     self._data['definitional_pairs'],
                                     'pca')

    def _initialize_data(self):
        """Load the Bolukbasi gender data and restrict it to the model vocab."""
        self._data = copy.deepcopy(BOLUKBASI_DATA['gender'])

        if not self.only_lower:
            # The embedding may contain cased variants - expand each seed
            # word to its generated forms so they are covered as well.
            self._data['specific_full_with_definitional'] = \
                generate_words_forms(self
                                     ._data['specific_full_with_definitional'])  # pylint: disable=C0301

        # Drop any word that is not present in the model vocabulary.
        for key in self._data['word_group_keys']:
            self._data[key] = (self._filter_words_by_model(self
                                                           ._data[key]))

        self._data['neutral_words'] = self._extract_neutral_words(self
                                                                  ._data['specific_full_with_definitional'])  # pylint: disable=C0301
        self._data['neutral_words'].sort()
        self._data['word_group_keys'].append('neutral_words')

    def plot_projection_scores(self, words='professions', n_extreme=10,
                               ax=None, axis_projection_step=None):
        """Plot projection scores on the gender direction.

        ``words='professions'`` (the default) uses the built-in
        profession names list; otherwise pass an explicit word list.
        """
        if words == 'professions':
            words = self._data['profession_names']

        return super().plot_projection_scores(words, n_extreme,
                                              ax, axis_projection_step)

    def plot_dist_projections_on_direction(self, word_groups='bolukbasi',
                                           ax=None):
        """Plot the distribution of projections of the Bolukbasi word groups
        (or the given ``word_groups`` dict) on the gender direction.
        """
        if word_groups == 'bolukbasi':
            word_groups = {key: self._data[key]
                           for key in self._data['word_group_keys']}

        return super().plot_dist_projections_on_direction(word_groups, ax)

    @classmethod
    def plot_bias_across_words_embeddings(cls, words_embedding_bias_dict,
                                          ax=None, scatter_kwargs=None):
        """Plot the direct bias of multiple words embeddings against each
        other, using the neutral profession names as the word set.
        """
        # pylint: disable=W0221
        words = BOLUKBASI_DATA['gender']['neutral_profession_names']
        # TODO: is it correct for inhertence of class method?
        # FIX: return the plot result (e.g. the axes), consistent with
        # the other plot methods of this class.
        return super(cls, cls).plot_bias_across_words_embeddings(words_embedding_bias_dict,  # pylint: disable=C0301
                                                                 words,
                                                                 ax,
                                                                 scatter_kwargs)

    def calc_direct_bias(self, neutral_words='professions', c=None):
        """Calculate the direct gender bias of ``neutral_words``
        (default: the built-in neutral profession names) with strictness ``c``.
        """
        if isinstance(neutral_words, str) and neutral_words == 'professions':
            return super().calc_direct_bias(
                self._data['neutral_profession_names'], c)
        else:
            # BUG FIX: `c` was previously dropped for custom word lists,
            # silently ignoring the caller's strictness parameter.
            return super().calc_direct_bias(neutral_words, c)

    def generate_closest_words_indirect_bias(self,
                                             neutral_positive_end,
                                             neutral_negative_end,
                                             words='professions', n_extreme=5):
        """Generate the closest words to the given neutral ends, with their
        indirect gender bias, for ``words`` (default: profession names).
        """
        # pylint: disable=C0301

        if words == 'professions':
            words = self._data['profession_names']

        return super().generate_closest_words_indirect_bias(neutral_positive_end,
                                                            neutral_negative_end,
                                                            words,
                                                            n_extreme=n_extreme)

    def debias(self, method='hard', neutral_words=None, equality_sets=None,
               inplace=True):
        """Debias the words embedding with the 'hard' or 'neutralize' method,
        defaulting to the built-in neutral words and definitional pairs.
        """
        # pylint: disable=C0301
        if method in ['hard', 'neutralize']:
            if neutral_words is None:
                neutral_words = self._data['neutral_words']

        if method == 'hard' and equality_sets is None:
            equality_sets = self._data['definitional_pairs']

            if not self.only_lower:
                # Word-form generation below assumes pairs, not larger sets.
                assert all(len(equality_set) == 2
                           for equality_set in equality_sets), 'currently supporting only equality pairs if only_lower is False'
                # TODO: refactor
                # Pair up the corresponding generated forms of both words
                # (e.g. cased variants) as additional equality pairs.
                equality_sets = {(candidate1, candidate2)
                                 for word1, word2 in equality_sets
                                 for candidate1, candidate2 in zip(generate_one_word_forms(word1),
                                                                   generate_one_word_forms(word2))}

        return super().debias(method, neutral_words, equality_sets,
                              inplace)

    def learn_full_specific_words(self, seed_specific_words='bolukbasi',
                                  max_non_specific_examples=None,
                                  debug=None):
        """Learn the full set of gender-specific words from seed words
        (default: the Bolukbasi specific seed set).
        """
        if seed_specific_words == 'bolukbasi':
            seed_specific_words = self._data['specific_seed']

        return super().learn_full_specific_words(seed_specific_words,
                                                 max_non_specific_examples,
                                                 debug)
122