Completed
Push — master ( 13dc98...362bd5 )
by Shlomi
25s queued 12s

ethically.we.weat.calc_all_weat()   F

Complexity
    Conditions      15

Size
    Total Lines     100
    Code Lines      56

Duplication
    Lines           0
    Ratio           0 %

Importance
    Changes         0

Metric    Value
eloc      56
dl        0
loc       100
rs        2.9998
c         0
b         0
f         0
cc        15
nop       6

How to fix: Long Method, Complexity

Long Method

Small methods make your code easier to understand, in particular when combined with a good name. Moreover, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method: moving a coherent chunk of the method's body into a new, well-named method.

Complexity

Complex classes like ethically.we.weat.calc_all_weat() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
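A minimal sketch of Extract Class (hypothetical `Customer`/`Address` names, for illustration only): fields sharing a `billing_` prefix hint at a cohesive component that can move into its own class.

```python
class Address:
    """Extracted from Customer: the billing_* fields formed a component."""

    def __init__(self, street, city):
        self.street = street
        self.city = city

    def label(self):
        return '{}, {}'.format(self.street, self.city)


class Customer:
    # Before the refactoring, billing_street and billing_city lived
    # directly on Customer; their shared prefix suggested Extract Class.
    def __init__(self, name, billing_street, billing_city):
        self.name = name
        self.billing_address = Address(billing_street, billing_city)


customer = Customer('Ada', '1 Main St', 'Springfield')
print(customer.billing_address.label())
```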

"""
Compute WEAT score of a Word Embedding.

WEAT is a bias measurement method for word embedding,
which is inspired by the `IAT <https://en.wikipedia.org/wiki/Implicit-association_test>`_
(Implicit Association Test) for humans.
It measures the similarity between two sets of *target words*
(e.g., programmer, engineer, scientist, ... and nurse, teacher, librarian, ...)
and two sets of *attribute words* (e.g., man, male, ... and woman, female ...).
A p-value is calculated using a permutation-test.

Reference:
    - Caliskan, A., Bryson, J. J., & Narayanan, A. (2017).
      `Semantics derived automatically
      from language corpora contain human-like biases
      <http://opus.bath.ac.uk/55288/>`_.
      Science, 356(6334), 183-186.

.. important::
    The effect size and p-value in the WEAT have
    an entirely different meaning from those reported in IATs (original finding).
    Refer to the paper for more details.

Stimuli and original findings from:

- [0, 1, 2]:
  A. G. Greenwald, D. E. McGhee, J. L. Schwartz,
  Measuring individual differences in implicit cognition:
  the implicit association test.,
  Journal of Personality and Social Psychology 74, 1464 (1998).

- [3, 4]:
  M. Bertrand, S. Mullainathan, Are Emily and Greg more employable
  than Lakisha and Jamal? A field experiment on labor market discrimination,
  The American Economic Review 94, 991 (2004).

- [5, 6, 9]:
  B. A. Nosek, M. Banaji, A. G. Greenwald, Harvesting implicit group attitudes
  and beliefs from a demonstration web site.,
  Group Dynamics: Theory, Research, and Practice 6, 101 (2002).

- [7]:
  B. A. Nosek, M. R. Banaji, A. G. Greenwald, Math=male, me=female,
  therefore math≠me.,
  Journal of Personality and Social Psychology 83, 44 (2002).

- [8]:
  P. D. Turney, P. Pantel, From frequency to meaning:
  Vector space models of semantics,
  Journal of Artificial Intelligence Research 37, 141 (2010).
"""

# pylint: disable=C0301

import copy
import random
import warnings

import numpy as np
import pandas as pd
from mlxtend.evaluate import permutation_test

from ethically.consts import RANDOM_STATE
from ethically.utils import _warning_setup
from ethically.we.data import WEAT_DATA
from ethically.we.utils import (
    assert_gensim_keyed_vectors, cosine_similarities_by_words,
)


FILTER_BY_OPTIONS = ['model', 'data']
RESULTS_DF_COLUMNS = ['Target words', 'Attrib. words',
                      'Nt', 'Na', 's', 'd', 'p']
PVALUE_METHODS = ['exact', 'approximate']
PVALUE_DEFUALT_METHOD = 'exact'
ORIGINAL_DF_COLUMNS = ['original_' + key for key in ['N', 'd', 'p']]
WEAT_WORD_SETS = ['first_target', 'second_target',
                  'first_attribute', 'second_attribute']
PVALUE_EXACT_WARNING_LEN = 10

_warning_setup()


def _calc_association_target_attributes(model, target_word,
                                        first_attribute_words,
                                        second_attribute_words):
    # pylint: disable=line-too-long

    assert_gensim_keyed_vectors(model)

    with warnings.catch_warnings():
        warnings.simplefilter('ignore', FutureWarning)

        first_mean = (cosine_similarities_by_words(model,
                                                   target_word,
                                                   first_attribute_words)
                      .mean())

        second_mean = (cosine_similarities_by_words(model,
                                                    target_word,
                                                    second_attribute_words)
                       .mean())

    return first_mean - second_mean


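The function above computes the WEAT per-word association s(w, A, B): the mean cosine similarity of a target word to the first attribute set, minus its mean similarity to the second. A self-contained sketch of the same formula with plain NumPy and toy vectors (not the library's gensim-backed implementation):

```python
import numpy as np


def association(w, first_attrs, second_attrs):
    """s(w, A, B) = mean_a cos(w, a) - mean_b cos(w, b)."""
    def cos(u, v):
        return np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))
    return (np.mean([cos(w, a) for a in first_attrs])
            - np.mean([cos(w, b) for b in second_attrs]))


# Toy 2-d "embedding": w is aligned with the first attribute direction,
# so its association should be strongly positive.
w = np.array([1.0, 0.0])
first_attrs = [np.array([1.0, 0.1]), np.array([0.9, 0.0])]
second_attrs = [np.array([0.0, 1.0]), np.array([0.1, 0.9])]
print(association(w, first_attrs, second_attrs))
```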
def _calc_association_all_targets_attributes(model, target_words,
                                             first_attribute_words,
                                             second_attribute_words):
    return [_calc_association_target_attributes(model, target_word,
                                                first_attribute_words,
                                                second_attribute_words)
            for target_word in target_words]


def _calc_weat_score(model,
                     first_target_words, second_target_words,
                     first_attribute_words, second_attribute_words):

    (first_associations,
     second_associations) = _calc_weat_associations(model,
                                                    first_target_words,
                                                    second_target_words,
                                                    first_attribute_words,
                                                    second_attribute_words)

    return sum(first_associations) - sum(second_associations)


def _calc_weat_pvalue(first_associations, second_associations,
                      method=PVALUE_DEFUALT_METHOD):

    if method not in PVALUE_METHODS:
        raise ValueError('method should be one of {}, {} was given'.format(
            PVALUE_METHODS, method))

    pvalue = permutation_test(first_associations, second_associations,
                              func=lambda x, y: sum(x) - sum(y),
                              method=method,
                              seed=RANDOM_STATE)  # seed has no effect for the exact method
    return pvalue


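The permutation test (delegated here to mlxtend) asks how often a random re-assignment of the target words to the two groups yields a statistic sum(x) - sum(y) at least as large as the observed one. A minimal, self-contained approximate version (hypothetical `permutation_pvalue` helper for illustration, not the mlxtend API):

```python
import random


def permutation_pvalue(first, second, n_resamples=10000, seed=0):
    """One-sided approximate permutation test for sum(first) - sum(second)."""
    rng = random.Random(seed)
    observed = sum(first) - sum(second)
    pooled = list(first) + list(second)
    count = 0
    for _ in range(n_resamples):
        rng.shuffle(pooled)
        # Re-split the pooled values into two groups of the original sizes.
        stat = sum(pooled[:len(first)]) - sum(pooled[len(first):])
        if stat >= observed:
            count += 1
    return count / n_resamples


# Clearly separated groups should yield a small p-value.
print(permutation_pvalue([0.9, 0.8, 0.85], [0.1, 0.2, 0.15]))
```

The `exact` method in mlxtend enumerates all splits instead of sampling, which is why large word sets trigger the runtime warning further below.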
def _calc_weat_associations(model,
                            first_target_words, second_target_words,
                            first_attribute_words, second_attribute_words):

    assert len(first_target_words) == len(second_target_words)
    assert len(first_attribute_words) == len(second_attribute_words)

    first_associations = _calc_association_all_targets_attributes(model,
                                                                  first_target_words,
                                                                  first_attribute_words,
                                                                  second_attribute_words)

    second_associations = _calc_association_all_targets_attributes(model,
                                                                   second_target_words,
                                                                   first_attribute_words,
                                                                   second_attribute_words)

    return first_associations, second_associations


def _filter_by_data_weat_stimuli(stimuli):
    """Filter WEAT stimuli words if there is a `remove` key.

    Some of the words from Caliskan et al. (2017) seem to be unused.

    Modifies the stimuli object in place.
    """

    for group in stimuli:
        if 'remove' in stimuli[group]:
            words_to_remove = stimuli[group]['remove']
            stimuli[group]['words'] = [word for word in stimuli[group]['words']
                                       if word not in words_to_remove]


def _sample_if_bigger(seq, length):
    random.seed(RANDOM_STATE)
    if len(seq) > length:
        seq = random.sample(seq, length)
    return seq


def _filter_by_model_weat_stimuli(stimuli, model):
    """Filter WEAT stimuli words if they do not exist in the model.

    Modifies the stimuli object in place.
    """

    for group_category in ['target', 'attribute']:
        first_group = 'first_' + group_category
        second_group = 'second_' + group_category

        first_words = [word for word in stimuli[first_group]['words']
                       if word in model]
        second_words = [word for word in stimuli[second_group]['words']
                        if word in model]

        min_len = min(len(first_words), len(second_words))

        # sample to make the first and second word sets
        # equal in size
        first_words = _sample_if_bigger(first_words, min_len)
        second_words = _sample_if_bigger(second_words, min_len)

        first_words.sort()
        second_words.sort()

        stimuli[first_group]['words'] = first_words
        stimuli[second_group]['words'] = second_words


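The filtering above keeps paired word sets the same size: out-of-vocabulary words are dropped, the longer set is down-sampled with a fixed seed, and both are sorted. The trimming step can be sketched on plain lists (hypothetical `trim_to_equal_size` helper, for illustration only):

```python
import random


def trim_to_equal_size(first, second, seed=42):
    """Randomly drop extra items so both lists end up the same length."""
    rng = random.Random(seed)  # fixed seed keeps the sampling reproducible
    min_len = min(len(first), len(second))
    if len(first) > min_len:
        first = rng.sample(first, min_len)
    if len(second) > min_len:
        second = rng.sample(second, min_len)
    return sorted(first), sorted(second)


first, second = trim_to_equal_size(['a', 'b', 'c', 'd'], ['x', 'y'])
print(len(first) == len(second) == 2)
```

Keeping the sets equal-sized matters because `_calc_weat_associations` asserts matching lengths before computing the score.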
def _filter_weat_data(weat_data, model, filter_by):
    """Filter the WEAT data in place."""

    if filter_by not in FILTER_BY_OPTIONS:
        raise ValueError('filter_by should be one of {}, {} was given'.format(
            FILTER_BY_OPTIONS, filter_by))

    if filter_by == 'data':
        for stimuli in weat_data:
            _filter_by_data_weat_stimuli(stimuli)

    elif filter_by == 'model':
        for stimuli in weat_data:
            _filter_by_model_weat_stimuli(stimuli, model)


def calc_single_weat(model,
                     first_target, second_target,
                     first_attribute, second_attribute,
                     with_pvalue=True, pvalue_kwargs=None):
    """
    Calc the WEAT result of a word embedding.

    :param model: Word embedding model of ``gensim.model.KeyedVectors``
    :param dict first_target: First target words list and its name
    :param dict second_target: Second target words list and its name
    :param dict first_attribute: First attribute words list and its name
    :param dict second_attribute: Second attribute words list and its name
    :param bool with_pvalue: Whether to calculate the p-value of the
                             WEAT score (might be computationally expensive)
    :return: WEAT result (score, effect size, Nt, Na and p-value)
    """

    if pvalue_kwargs is None:
        pvalue_kwargs = {}

    (first_associations,
     second_associations) = _calc_weat_associations(model,
                                                    first_target['words'],
                                                    second_target['words'],
                                                    first_attribute['words'],
                                                    second_attribute['words'])

    if first_associations and second_associations:
        score = sum(first_associations) - sum(second_associations)
        std_dev = np.std(first_associations + second_associations, ddof=0)
        effect_size = ((np.mean(first_associations) - np.mean(second_associations))
                       / std_dev)

        pvalue = None
        if with_pvalue:
            pvalue = _calc_weat_pvalue(first_associations,
                                       second_associations,
                                       **pvalue_kwargs)
    else:
        score, std_dev, effect_size, pvalue = None, None, None, None

    return {'Target words': '{} vs. {}'.format(first_target['name'],
                                               second_target['name']),
            'Attrib. words': '{} vs. {}'.format(first_attribute['name'],
                                                second_attribute['name']),
            's': score,
            'd': effect_size,
            'p': pvalue,
            'Nt': '{}x2'.format(len(first_target['words'])),
            'Na': '{}x2'.format(len(first_attribute['words']))}


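The effect size `d` computed above is the difference of the two groups' mean associations, divided by the standard deviation of all associations pooled together (population std, `ddof=0`), as in Caliskan et al. A self-contained sketch of just that formula:

```python
import numpy as np


def weat_effect_size(first_associations, second_associations):
    """d = (mean(first) - mean(second)) / std(first + second)."""
    pooled = np.array(first_associations + second_associations)
    return ((np.mean(first_associations) - np.mean(second_associations))
            / np.std(pooled, ddof=0))  # population std, matching the code above


# Strongly opposed association values give a large effect size.
d = weat_effect_size([0.5, 0.7], [-0.5, -0.7])
print(round(d, 2))
```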
def calc_weat_pleasant_unpleasant_attribute(model,
                                            first_target, second_target,
                                            with_pvalue=True, pvalue_kwargs=None):
    """
    Calc the WEAT result with pleasant vs. unpleasant attributes.

    :param model: Word embedding model of ``gensim.model.KeyedVectors``
    :param dict first_target: First target words list and its name
    :param dict second_target: Second target words list and its name
    :param bool with_pvalue: Whether to calculate the p-value of the
                             WEAT score (might be computationally expensive)
    :return: WEAT result (score, effect size, Nt, Na and p-value)
    """

    weat_data = {'first_attribute': copy.deepcopy(WEAT_DATA[0]['first_attribute']),
                 'second_attribute': copy.deepcopy(WEAT_DATA[0]['second_attribute']),
                 'first_target': first_target,
                 'second_target': second_target}

    _filter_by_model_weat_stimuli(weat_data, model)

    if pvalue_kwargs is None:
        pvalue_kwargs = {}

    return calc_single_weat(model,
                            **weat_data,
                            with_pvalue=with_pvalue, pvalue_kwargs=pvalue_kwargs)


def calc_all_weat(model, weat_data='caliskan', filter_by='model',
                  with_original_finding=False,
                  with_pvalue=True, pvalue_kwargs=None):
    """
    Calc the WEAT results of a word embedding on multiple cases.

    Note that the effect size and p-value in the WEAT have
    an entirely different meaning from those reported in IATs (original finding).
    Refer to the paper for more details.

    :param model: Word embedding model of ``gensim.model.KeyedVectors``
    :param weat_data: WEAT cases data.
                      - If `'caliskan'` (default), all the experiments
                        from the original paper will be used.
                      - If an integer, the specific experiment by index
                        from the original paper will be used.
                      - If a tuple of integers, the specific experiments
                        by indices from the original paper will be used.
    :param str filter_by: Whether to filter the word lists
                          by the `model` (`'model'`)
                          or by the `remove` key in `weat_data` (`'data'`).
    :param bool with_original_finding: Show the original findings as well
    :param bool with_pvalue: Whether to calculate the p-value of the
                             WEAT results (might be computationally expensive)
    :return: :class:`pandas.DataFrame` of WEAT results
             (score, effect size, Nt, Na and p-value)
    """

    if weat_data == 'caliskan':
        weat_data = WEAT_DATA
    elif isinstance(weat_data, int):
        index = weat_data
        weat_data = WEAT_DATA[index:index + 1]
    elif isinstance(weat_data, tuple):
        weat_data = [WEAT_DATA[index] for index in weat_data]

    if (not pvalue_kwargs
            or pvalue_kwargs.get('method') == PVALUE_DEFUALT_METHOD):
        max_word_set_len = max(len(stimuli[ws]['words'])
                               for stimuli in weat_data
                               for ws in WEAT_WORD_SETS)
        if max_word_set_len > PVALUE_EXACT_WARNING_LEN:
            warnings.warn('At least one stimuli has a word set bigger'
                          ' than {}, and the computation might take a while.'
                          ' Consider using \'approximate\' as method'
                          ' for pvalue_kwargs.'.format(PVALUE_EXACT_WARNING_LEN))

    actual_weat_data = copy.deepcopy(weat_data)

    _filter_weat_data(actual_weat_data,
                      model,
                      filter_by)

    if weat_data != actual_weat_data:
        warnings.warn('Given weat_data was filtered by {}.'
                      .format(filter_by))

    results = []
    for stimuli in actual_weat_data:
        result = calc_single_weat(model,
                                  stimuli['first_target'],
                                  stimuli['second_target'],
                                  stimuli['first_attribute'],
                                  stimuli['second_attribute'],
                                  with_pvalue, pvalue_kwargs)

        # TODO: refactor - check before if one group is without words
        # because of the filtering
        if not all(group['words'] for group in stimuli.values()
                   if 'words' in group):
            result['s'] = None
            result['d'] = None
            result['p'] = None

        result['stimuli'] = stimuli

        if with_original_finding:
            result.update({'original_' + k: v
                           for k, v in stimuli['original_finding'].items()})
        results.append(result)

    results_df = pd.DataFrame(results)
    results_df = results_df.replace('nan', None)
    results_df = results_df.fillna('')

    cols = RESULTS_DF_COLUMNS[:]
    if with_original_finding:
        cols += ORIGINAL_DF_COLUMNS
    if not with_pvalue:
        cols.remove('p')
    else:
        results_df['p'] = results_df['p'].apply(lambda pvalue: '{:0.1e}'.format(pvalue)  # pylint: disable=W0108
                                                if pvalue else pvalue)

    results_df = results_df[cols]
    results_df = results_df.round(4)

    return results_df
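The `weat_data` dispatch at the top of `calc_all_weat` accepts `'caliskan'`, a single index, or a tuple of indices. It can be illustrated with a plain list standing in for `WEAT_DATA` (hypothetical `select_cases` helper and stand-in data, for illustration only):

```python
DATA = ['case0', 'case1', 'case2', 'case3']  # stand-in for WEAT_DATA


def select_cases(spec='all'):
    """'all' -> everything; int -> one case; tuple -> chosen cases."""
    if spec == 'all':
        return DATA
    if isinstance(spec, int):
        # Slicing keeps the result a list, mirroring WEAT_DATA[index:index + 1].
        return DATA[spec:spec + 1]
    if isinstance(spec, tuple):
        return [DATA[i] for i in spec]
    raise ValueError('unsupported spec: {!r}'.format(spec))


print(select_cases(1))        # a one-element list, not a bare item
print(select_cases((0, 2)))
```

Using a slice rather than plain indexing for the integer case is what lets the rest of the function iterate uniformly over a list of stimuli.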