Passed
Push — master ( 999448...189cd4 )
by Shlomi
01:46
created

GenderBiasWE.plot_dist_projections_on_direction()   A

Complexity

Conditions 2

Size

Total Lines 7
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 6
nop 3
dl 0
loc 7
rs 10
c 0
b 0
f 0
1
import copy
2
3
from .core import BiasWordsEmbedding
4
from .data import BOLUKBASI_DATA
5
from .utils import generate_one_word_forms, generate_words_forms
6
7
8
class GenderBiasWE(BiasWordsEmbedding):
    """Gender flavour of ``BiasWordsEmbedding`` backed by ``BOLUKBASI_DATA``.

    On construction the instance takes a private copy of
    ``BOLUKBASI_DATA['gender']``, filters its word groups through
    ``_filter_words_by_model`` and asks the parent class to identify the
    direction from ``'he'`` / ``'she'`` and the definitional pairs.

    The public methods are thin wrappers around the parent implementations
    that substitute the packaged word lists when the caller keeps the
    string defaults (``'bolukbasi'`` / ``'professions'``) or ``None``.
    """

    def __init__(self, model, only_lower=True, verbose=False):
        """Initialise the parent, load the gender data and set the direction.

        Args:
            model: Forwarded unchanged to ``BiasWordsEmbedding.__init__``.
            only_lower: Forwarded to the parent; this class later reads it
                back as ``self.only_lower`` (presumably stored by the
                parent -- confirm against ``BiasWordsEmbedding``).
            verbose: Forwarded unchanged to the parent.
        """
        super().__init__(model, only_lower, verbose)
        self._initialize_data()
        self._identify_direction('he', 'she',
                                 self._data['definitional_pairs'],
                                 'pca')

    def _initialize_data(self):
        """Populate ``self._data`` from the packaged gender word lists."""
        # Deep copy: the entries below are reassigned and the
        # 'word_group_keys' list is appended to, which must not leak into
        # the shared module-level BOLUKBASI_DATA.
        self._data = copy.deepcopy(BOLUKBASI_DATA['gender'])

        specific_key = 'specific_full_with_definitional'

        if not self.only_lower:
            # Expand the specific words through generate_words_forms before
            # they are filtered against the model.
            self._data[specific_key] = generate_words_forms(
                self._data[specific_key])

        for key in self._data['word_group_keys']:
            self._data[key] = self._filter_words_by_model(self._data[key])

        self._data['neutral_words'] = self._extract_neutral_words(
            self._data[specific_key])
        # Registered only after the filtering loop above, so 'neutral_words'
        # itself is not passed through _filter_words_by_model here.
        self._data['word_group_keys'].append('neutral_words')

    def plot_dist_projections_on_direction(self, word_groups='bolukbasi',
                                           ax=None):
        """Plot the projection distributions of word groups on the direction.

        Args:
            word_groups: Either the string ``'bolukbasi'`` (default), which
                selects every group listed in ``self._data['word_group_keys']``,
                or a value passed straight to the parent implementation.
            ax: Forwarded unchanged to the parent implementation.

        Returns:
            Whatever the parent implementation returns (``None`` if it
            returns nothing).
        """
        # isinstance guard mirrors calc_direct_bias: only a real string is
        # compared against the sentinel value.
        if isinstance(word_groups, str) and word_groups == 'bolukbasi':
            word_groups = {key: self._data[key]
                           for key in self._data['word_group_keys']}

        return super().plot_dist_projections_on_direction(word_groups, ax)

    def calc_direct_bias(self, neutral_words='professions', c=None):
        """Compute the direct bias via the parent implementation.

        Args:
            neutral_words: Either the string ``'professions'`` (default),
                which selects ``self._data['neutral_profession_names']``,
                or a value passed straight to the parent implementation.
            c: Forwarded to the parent implementation in both cases.

        Returns:
            The result of ``BiasWordsEmbedding.calc_direct_bias``.
        """
        if isinstance(neutral_words, str) and neutral_words == 'professions':
            neutral_words = self._data['neutral_profession_names']

        # ``c`` is forwarded for caller-supplied words as well; dropping it
        # would silently ignore the caller's argument.
        return super().calc_direct_bias(neutral_words, c)

    def debias(self, method='hard', neutral_words=None, equality_sets=None,
               inplace=True):
        """Debias through the parent, filling in the packaged defaults.

        Args:
            method: Forwarded to the parent. For ``'hard'`` and
                ``'neutralize'`` a missing ``neutral_words`` is replaced by
                ``self._data['neutral_words']``; for ``'hard'`` a missing
                ``equality_sets`` is replaced by the definitional pairs.
            neutral_words: Words forwarded to the parent, or ``None`` to use
                the default described above.
            equality_sets: Sets forwarded to the parent, or ``None`` to use
                the default described above.
            inplace: Forwarded unchanged to the parent.

        Returns:
            The result of ``BiasWordsEmbedding.debias``.
        """
        if method in ('hard', 'neutralize') and neutral_words is None:
            neutral_words = self._data['neutral_words']

        if method == 'hard' and equality_sets is None:
            equality_sets = self._data['definitional_pairs']

            if not self.only_lower:
                # Internal invariant on the packaged data: the expansion
                # below unpacks every entry as exactly two words.
                assert all(len(equality_set) == 2
                           for equality_set in equality_sets), (
                    'currently supporting only equality pairs '
                    'if only_lower is False')
                # TODO: refactor
                # Pair up the generated forms of both words position by
                # position; the set drops duplicate pairs.
                equality_sets = {
                    (form1, form2)
                    for word1, word2 in equality_sets
                    for form1, form2 in zip(generate_one_word_forms(word1),
                                            generate_one_word_forms(word2))
                }

        return super().debias(method, neutral_words, equality_sets,
                              inplace)

    def learn_full_specific_words(self, seed_specific_words='bolukbasi',
                                  max_non_specific_examples=None,
                                  debug=None):
        """Learn the full specific-words list via the parent implementation.

        Args:
            seed_specific_words: Either the string ``'bolukbasi'`` (default),
                which selects ``self._data['specific_seed']``, or a value
                passed straight to the parent implementation.
            max_non_specific_examples: Forwarded unchanged to the parent.
            debug: Forwarded unchanged to the parent.

        Returns:
            The result of ``BiasWordsEmbedding.learn_full_specific_words``.
        """
        # isinstance guard mirrors calc_direct_bias, so array-like seeds are
        # never compared element-wise against the sentinel string.
        if (isinstance(seed_specific_words, str)
                and seed_specific_words == 'bolukbasi'):
            seed_specific_words = self._data['specific_seed']

        return super().learn_full_specific_words(seed_specific_words,
                                                 max_non_specific_examples,
                                                 debug)
78