Passed
Push — master ( a15aca...2d6850 )
by Shlomi
01:44
created

ethically.we.bias.GenderBiasWE.debias()   B

Complexity

Conditions 6

Size

Total Lines 21
Code Lines 16

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 6
eloc 16
nop 6
dl 0
loc 21
rs 8.6666
c 0
b 0
f 0
1
import copy
2
3
from .core import BiasWordsEmbedding
4
from .data import BOLUKBASI_DATA
5
from .utils import generate_one_word_forms, generate_words_forms
6
7
8
class GenderBiasWE(BiasWordsEmbedding):
    """Gender bias measurement and debiasing on top of a word embedding.

    Specializes ``BiasWordsEmbedding`` by loading the ``'gender'`` entry of
    ``BOLUKBASI_DATA`` and identifying the bias direction from the
    ``'he'``/``'she'`` pair, the definitional pairs and the ``'pca'`` method.
    """

    def __init__(self, model, only_lower=True):
        """Load the gender word lists and identify the gender direction.

        :param model: Word embedding model, forwarded unchanged to
                      ``BiasWordsEmbedding.__init__``.
        :param bool only_lower: Forwarded to the base class. When false,
                                extra word forms are generated for the
                                gender-specific words and for the default
                                equality pairs used by :meth:`debias`.
        """
        super().__init__(model, only_lower)
        # NOTE(review): ``self.only_lower`` is read below and is presumably
        # set by the base-class constructor -- confirm against ``core``.
        self._initialize_data()
        self._identify_direction('he', 'she',
                                 self._data['definitional_pairs'],
                                 'pca')

    def _initialize_data(self):
        """Build ``self._data`` from the ``'gender'`` word lists.

        Every non-empty list of strings in the data is passed through
        ``self._filter_words_by_model``, and a ``'neutral_words'`` entry is
        derived from the gender-specific words with
        ``self._extract_neutral_words``.
        """
        # Deep copy so the in-place rewrites below never touch the shared
        # module-level BOLUKBASI_DATA.
        self._data = copy.deepcopy(BOLUKBASI_DATA['gender'])

        specific_key = 'specific_full_with_definitional'

        if not self.only_lower:
            self._data[specific_key] = generate_words_forms(
                self._data[specific_key])

        # Only existing keys are reassigned, so the dict does not change
        # size while it is being iterated.
        for key in self._data:
            values = self._data[key]
            # Other shapes (empty lists, lists of pairs, non-lists) are
            # left as they are.
            if (isinstance(values, list)
                    and values
                    and isinstance(values[0], str)):
                self._data[key] = self._filter_words_by_model(values)

        self._data['neutral_words'] = self._extract_neutral_words(
            self._data[specific_key])

    def calc_direct_bias(self, neutral_words='professions', c=None):
        """Compute the direct bias through the base-class implementation.

        :param neutral_words: Words to measure, or the string
                              ``'professions'`` to use the bundled
                              ``'neutral_profession_names'`` list.
        :param c: Forwarded as the second argument of the base-class
                  ``calc_direct_bias``; its meaning is defined there.
        :return: Whatever the base-class ``calc_direct_bias`` returns.
        """
        # The isinstance guard keeps the sentinel comparison from being
        # applied to non-string inputs such as word lists.
        if isinstance(neutral_words, str) and neutral_words == 'professions':
            neutral_words = self._data['neutral_profession_names']

        # ``c`` is forwarded on both paths; previously it was dropped when
        # the caller supplied custom neutral words.
        return super().calc_direct_bias(neutral_words, c)

    def debias(self, method='hard', neutral_words=None, equality_sets=None,
               inplace=True, verbose=False):
        """Debias through the base class, filling in gender defaults.

        :param str method: Forwarded to the base class. For ``'hard'`` and
                           ``'neutralize'`` a missing ``neutral_words`` is
                           defaulted; for ``'hard'`` a missing
                           ``equality_sets`` is defaulted as well.
        :param neutral_words: Words to neutralize; ``None`` selects the
                              ``'neutral_words'`` list built at
                              initialization.
        :param equality_sets: Word sets to equalize; ``None`` selects the
                              definitional pairs, expanded with
                              ``generate_one_word_forms`` when
                              ``only_lower`` is false.
        :param inplace: Forwarded unchanged to the base class.
        :param verbose: Forwarded unchanged to the base class.
        :return: Whatever the base-class ``debias`` returns.
        """
        if method in ['hard', 'neutralize']:
            if neutral_words is None:
                neutral_words = self._data['neutral_words']

        if method == 'hard' and equality_sets is None:
            equality_sets = self._data['definitional_pairs']

            if not self.only_lower:
                # Internal invariant on the bundled data: the pairwise
                # unpacking below requires exactly two words per set.
                assert all(len(equality_set) == 2
                           for equality_set in equality_sets), (
                    'currently supporting only equality pairs '
                    'if only_lower is False')
                # TODO: refactor
                # Pair up the generated forms of both words position by
                # position; zip stops at the shorter of the two sequences.
                equality_sets = {
                    (candidate1, candidate2)
                    for word1, word2 in equality_sets
                    for candidate1, candidate2 in zip(
                        generate_one_word_forms(word1),
                        generate_one_word_forms(word2))
                }

        return super().debias(method, neutral_words, equality_sets,
                              inplace, verbose)

    def learn_full_specific_words(self, seed_specific_words='bolukbasi',
                                  max_non_specific_examples=None,
                                  debug=None):
        """Learn the full gender-specific word list via the base class.

        :param seed_specific_words: Seed words, or the string
                                    ``'bolukbasi'`` to use the bundled
                                    ``'specific_seed'`` list.
        :param max_non_specific_examples: Forwarded unchanged to the base
                                          class.
        :param debug: Forwarded unchanged to the base class.
        :return: Whatever the base-class ``learn_full_specific_words``
                 returns.
        """
        # Same sentinel guard as in ``calc_direct_bias``: only a plain
        # string can select the bundled seed list.
        if (isinstance(seed_specific_words, str)
                and seed_specific_words == 'bolukbasi'):
            seed_specific_words = self._data['specific_seed']

        return super().learn_full_specific_words(seed_specific_words,
                                                 max_non_specific_examples,
                                                 debug)
74