| 1 |  |  | import copy | 
            
                                                                                                            
                            
            
                                    
            
            
                | 2 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 3 |  |  | from .core import BiasWordsEmbedding | 
            
                                                                                                            
                            
            
                                    
            
            
                | 4 |  |  | from .data import BOLUKBASI_DATA | 
            
                                                                                                            
                            
            
                                    
            
            
                | 5 |  |  | from .utils import generate_one_word_forms, generate_words_forms | 
            
                                                                                                            
                            
            
                                    
            
            
                | 6 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 7 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                class GenderBiasWE(BiasWordsEmbedding):
                    """Gender-bias specialization of :class:`BiasWordsEmbedding`.

                    Loads the word lists from ``BOLUKBASI_DATA['gender']``, filters them
                    against the wrapped model, and identifies the gender direction by
                    calling ``_identify_direction('he', 'she', definitional_pairs, 'pca')``.
                    """

                    def __init__(self, model, only_lower=True):
                        """Wrap ``model`` and identify its gender direction.

                        :param model: the word-embedding model, passed unchanged to
                            :class:`BiasWordsEmbedding`.
                        :param only_lower: passed to the base class. When False, the
                            specific-word list is expanded with extra word forms
                            (see ``_initialize_data``).
                        """
                        super().__init__(model, only_lower)
                        self._initialize_data()
                        self._identify_direction('he', 'she',
                                                 self._data['definitional_pairs'],
                                                 'pca')

                    def _initialize_data(self):
                        """Build ``self._data`` from the Bolukbasi gender data.

                        Steps:

                        1. Deep-copy ``BOLUKBASI_DATA['gender']`` so the shared
                           module-level data is never mutated by this instance.
                        2. If ``only_lower`` is False, replace
                           ``specific_full_with_definitional`` with the output of
                           ``generate_words_forms`` applied to it.
                        3. Pass every non-empty list whose first element is a string
                           through ``_filter_words_by_model``.
                        4. Store ``_extract_neutral_words(specific_full_with_definitional)``
                           under ``neutral_words``.
                        """
                        self._data = copy.deepcopy(BOLUKBASI_DATA['gender'])

                        if not self.only_lower:
                            specific_words = self._data['specific_full_with_definitional']
                            self._data['specific_full_with_definitional'] = \
                                generate_words_forms(specific_words)

                        # Only lists of strings are word lists to filter. Entries whose
                        # first element is not a string (e.g. lists of pairs) are kept as-is.
                        # Assigning to existing keys while iterating is safe because the
                        # dict's key set does not change.
                        for key in self._data:
                            value = self._data[key]
                            if (isinstance(value, list)
                                    and value
                                    and isinstance(value[0], str)):
                                self._data[key] = self._filter_words_by_model(value)

                        specific_words = self._data['specific_full_with_definitional']
                        self._data['neutral_words'] = \
                            self._extract_neutral_words(specific_words)

                    def calc_direct_bias(self, neutral_words='professions', c=None):
                        """Calculate the direct bias over a set of neutral words.

                        :param neutral_words: the string ``'professions'`` selects the
                            built-in ``neutral_profession_names`` list. Any other value
                            is passed through to the base class unchanged.
                        :param c: forwarded to the base ``calc_direct_bias`` in every
                            case. Previously it was dropped unless ``neutral_words`` was
                            ``'professions'``.
                        :return: whatever the base class ``calc_direct_bias`` returns.
                        """
                        # The isinstance guard keeps the comparison from running
                        # elementwise when neutral_words is an array-like.
                        if isinstance(neutral_words, str) and neutral_words == 'professions':
                            neutral_words = self._data['neutral_profession_names']
                        return super().calc_direct_bias(neutral_words, c)

                    def debias(self, method='hard', neutral_words=None, equality_sets=None,
                               inplace=True, verbose=False):
                        """Debias the embedding, filling in gender-specific defaults.

                        :param method: debias method name, forwarded to the base class.
                        :param neutral_words: for ``'hard'`` and ``'neutralize'``,
                            defaults to the derived ``neutral_words`` list when None.
                        :param equality_sets: for ``'hard'``, defaults to
                            ``definitional_pairs`` when None. If ``only_lower`` is False,
                            each default pair is expanded into paired word forms via
                            ``generate_one_word_forms``. User-supplied sets are used
                            as-is.
                        :param inplace: forwarded to the base class.
                        :param verbose: forwarded to the base class.
                        :return: whatever the base class ``debias`` returns.
                        :raises AssertionError: if the default equality sets are
                            expanded while one of them is not a pair.
                        """
                        if method in ['hard', 'neutralize']:
                            if neutral_words is None:
                                neutral_words = self._data['neutral_words']

                        if method == 'hard' and equality_sets is None:
                            equality_sets = self._data['definitional_pairs']

                            if not self.only_lower:
                                assert all(len(equality_set) == 2
                                           for equality_set in equality_sets), \
                                    'currently supporting only equality pairs if only_lower is False'
                                # TODO: refactor
                                # Pair up corresponding forms of both words, e.g. the
                                # i-th generated form of word1 with the i-th of word2.
                                equality_sets = {
                                    (candidate1, candidate2)
                                    for word1, word2 in equality_sets
                                    for candidate1, candidate2 in zip(
                                        generate_one_word_forms(word1),
                                        generate_one_word_forms(word2))
                                }

                        return super().debias(method, neutral_words, equality_sets,
                                              inplace, verbose)

                    def learn_full_specific_words(self, seed_specific_words='bolukbasi',
                                                  max_non_specific_examples=None,
                                                  debug=None):
                        """Learn the full set of gender-specific words from a seed set.

                        :param seed_specific_words: the string ``'bolukbasi'`` selects
                            the built-in ``specific_seed`` list. Any other value is
                            passed through unchanged.
                        :param max_non_specific_examples: forwarded to the base class.
                        :param debug: forwarded to the base class.
                        :return: whatever the base class
                            ``learn_full_specific_words`` returns.
                        """
                        # Same str guard as calc_direct_bias: avoids an elementwise
                        # comparison when an array-like seed list is passed.
                        if (isinstance(seed_specific_words, str)
                                and seed_specific_words == 'bolukbasi'):
                            seed_specific_words = self._data['specific_seed']

                        return super().learn_full_specific_words(seed_specific_words,
                                                                 max_non_specific_examples,
                                                                 debug)
            
                                                        
            
                                    
            
            
                | 74 |  |  |  |