| 1 |  |  | import copy | 
            
                                                                                                            
                            
            
                                    
            
            
                | 2 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 3 |  |  | from .core import BiasWordsEmbedding | 
            
                                                                                                            
                            
            
                                    
            
            
                | 4 |  |  | from .data import BOLUKBASI_DATA | 
            
                                                                                                            
                            
            
                                    
            
            
                | 5 |  |  | from .utils import generate_one_word_forms, generate_words_forms | 
            
                                                                                                            
                            
            
                                    
            
            
                | 6 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 7 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 8 |  |  | class GenderBiasWE(BiasWordsEmbedding): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 9 |  |  |     """Audit and Adjust the Gender Bias in English Words Embedding. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 10 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 11 |  |  |     :param model: Words embedding model of ``gensim.model.KeyedVectors`` | 
            
                                                                                                            
                            
            
                                    
            
            
                | 12 |  |  |     :param bool only_lower: Whether the words embedding contrains | 
            
                                                                                                            
                            
            
                                    
            
            
                | 13 |  |  |                             only lower case words | 
            
                                                                                                            
                            
            
                                    
            
            
                | 14 |  |  |     :param bool verbose: Set vebosity | 
            
                                                                                                            
                            
            
                                    
            
            
                | 15 |  |  |     """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 16 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 17 |  |  |     def __init__(self, model, only_lower=False, verbose=False, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 18 |  |  |                  identify_direction=True): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 19 |  |  |         super().__init__(model, only_lower, verbose) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 20 |  |  |         self._initialize_data() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 21 |  |  |         if identify_direction: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 22 |  |  |             self._identify_direction('she', 'he', | 
            
                                                                                                            
                            
            
                                    
            
            
                | 23 |  |  |                                      self._data['definitional_pairs'], | 
            
                                                                                                            
                            
            
                                    
            
            
                | 24 |  |  |                                      'pca') | 
            
                                                                                                            
                            
            
                                    
            
            
                | 25 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 26 |  |  |     def _initialize_data(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 27 |  |  |         self._data = copy.deepcopy(BOLUKBASI_DATA['gender']) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 28 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 29 |  |  |         if not self.only_lower: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 30 |  |  |             self._data['specific_full_with_definitional'] = \ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 31 |  |  |                 generate_words_forms(self | 
            
                                                                                                            
                            
            
                                    
            
            
                | 32 |  |  |                                      ._data['specific_full_with_definitional'])  # pylint: disable=C0301 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 33 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 34 |  |  |         for key in self._data['word_group_keys']: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 35 |  |  |             self._data[key] = (self._filter_words_by_model(self | 
            
                                                                                                            
                            
            
                                    
            
            
                | 36 |  |  |                                                            ._data[key])) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 37 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 38 |  |  |         self._data['neutral_words'] = self._extract_neutral_words(self | 
            
                                                                                                            
                            
            
                                    
            
            
                | 39 |  |  |                                                                   ._data['specific_full_with_definitional'])  # pylint: disable=C0301 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 40 |  |  |         self._data['neutral_words'].sort() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 41 |  |  |         self._data['word_group_keys'].append('neutral_words') | 
            
                                                                                                            
                            
            
                                    
            
            
                | 42 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 43 |  |  |     def plot_projection_scores(self, words='professions', n_extreme=10, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 44 |  |  |                                ax=None, axis_projection_step=None): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 45 |  |  |         if words == 'professions': | 
            
                                                                                                            
                            
            
                                    
            
            
                | 46 |  |  |             words = self._data['profession_names'] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 47 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 48 |  |  |         return super().plot_projection_scores(words, n_extreme, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 49 |  |  |                                               ax, axis_projection_step) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 50 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 51 |  |  |     def plot_dist_projections_on_direction(self, word_groups='bolukbasi', | 
            
                                                                                                            
                            
            
                                    
            
            
                | 52 |  |  |                                            ax=None): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 53 |  |  |         if word_groups == 'bolukbasi': | 
            
                                                                                                            
                            
            
                                    
            
            
                | 54 |  |  |             word_groups = {key: self._data[key] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 55 |  |  |                            for key in self._data['word_group_keys']} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 56 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 57 |  |  |         return super().plot_dist_projections_on_direction(word_groups, ax) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 58 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 59 |  |  |     @classmethod | 
            
                                                                                                            
                            
            
                                    
            
            
                | 60 |  |  |     def plot_bias_across_words_embeddings(cls, words_embedding_bias_dict, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 61 |  |  |                                           ax=None, scatter_kwargs=None): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 62 |  |  |         # pylint: disable=W0221 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 63 |  |  |         words = BOLUKBASI_DATA['gender']['neutral_profession_names'] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 64 |  |  |         # TODO: is it correct for inhertence of class method? | 
            
                                                                                                            
                            
            
                                    
            
            
                | 65 |  |  |         super(cls, cls).plot_bias_across_words_embeddings(words_embedding_bias_dict,  # pylint: disable=C0301 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 66 |  |  |                                                           words, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 67 |  |  |                                                           ax, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 68 |  |  |                                                           scatter_kwargs) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 69 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 70 |  |  |     def calc_direct_bias(self, neutral_words='professions', c=None): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 71 |  |  |         if isinstance(neutral_words, str) and neutral_words == 'professions': | 
            
                                                                                                            
                            
            
                                    
            
            
                | 72 |  |  |             return super().calc_direct_bias( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 73 |  |  |                 self._data['neutral_profession_names'], c) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 74 |  |  |         else: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 75 |  |  |             return super().calc_direct_bias(neutral_words) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 76 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 77 |  |  |     def generate_closest_words_indirect_bias(self, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 78 |  |  |                                              neutral_positive_end, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 79 |  |  |                                              neutral_negative_end, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 80 |  |  |                                              words='professions', n_extreme=5): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 81 |  |  |         # pylint: disable=C0301 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 82 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 83 |  |  |         if words == 'professions': | 
            
                                                                                                            
                            
            
                                    
            
            
                | 84 |  |  |             words = self._data['profession_names'] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 85 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 86 |  |  |         return super().generate_closest_words_indirect_bias(neutral_positive_end, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 87 |  |  |                                                             neutral_negative_end, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 88 |  |  |                                                             words, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 89 |  |  |                                                             n_extreme=n_extreme) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 90 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 91 |  |  |     def debias(self, method='hard', neutral_words=None, equality_sets=None, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 92 |  |  |                inplace=True): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 93 |  |  |         # pylint: disable=C0301 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 94 |  |  |         if method in ['hard', 'neutralize']: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 95 |  |  |             if neutral_words is None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 96 |  |  |                 neutral_words = self._data['neutral_words'] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 97 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 98 |  |  |         if method == 'hard' and equality_sets is None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 99 |  |  |             equality_sets = self._data['definitional_pairs'] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 100 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 101 |  |  |             if not self.only_lower: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 102 |  |  |                 assert all(len(equality_set) == 2 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 103 |  |  |                            for equality_set in equality_sets), 'currently supporting only equality pairs if only_lower is False' | 
            
                                                                                                            
                            
            
                                    
            
            
                | 104 |  |  |                 # TODO: refactor | 
            
                                                                                                            
                            
            
                                    
            
            
                | 105 |  |  |                 equality_sets = {(candidate1, candidate2) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 106 |  |  |                                  for word1, word2 in equality_sets | 
            
                                                                                                            
                            
            
                                    
            
            
                | 107 |  |  |                                  for candidate1, candidate2 in zip(generate_one_word_forms(word1), | 
            
                                                                                                            
                            
            
                                    
            
            
                | 108 |  |  |                                                                    generate_one_word_forms(word2))} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 109 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 110 |  |  |         return super().debias(method, neutral_words, equality_sets, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 111 |  |  |                               inplace) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 112 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def learn_full_specific_words(self, seed_specific_words='bolukbasi',
                                              max_non_specific_examples=None,
                                              debug=None):
                    """Resolve the seed words and delegate to the parent implementation.

                    When ``seed_specific_words`` equals the string ``'bolukbasi'``
                    (the default), the seed list is read from
                    ``self._data['specific_seed']``; any other value is forwarded
                    unchanged.

                    ``max_non_specific_examples`` and ``debug`` are passed through
                    to the parent class untouched, and whatever the parent's
                    ``learn_full_specific_words`` returns is returned as-is.
                    """
                    # 'bolukbasi' is a sentinel selecting the seed list stored on the
                    # instance rather than a caller-supplied one.
                    seed_words = (self._data['specific_seed']
                                  if seed_specific_words == 'bolukbasi'
                                  else seed_specific_words)

                    parent_learn = super().learn_full_specific_words
                    return parent_learn(seed_words, max_non_specific_examples, debug)
            
                                                                                                            
                            
            
                                    
            
            
                | 122 |  |  |  |