Passed
Pull Request — dev (#23) · by Shlomi · created 02:43

ethically.we.weat.calc_all_weat()   F

Complexity
    Conditions: 15

Size
    Total Lines: 100
    Code Lines: 56

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0
Metric   Value
eloc     56
dl       0
loc      100
rs       2.9998
c        0
b        0
f        0
cc       15
nop      6

How to fix

Long Method

Small methods make your code easier to understand, particularly when combined with a good name. And if a method is small, finding a good name for it is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method; a minimal before/after sketch follows.
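As a rough illustration (the snippet and names below are made up for this note, not taken from the module under review), Extract Method turns a commented block inside a long function into a small, well-named helper:

# Before: a comment labels a cohesive block inside a longer function.
def report_pvalues(pvalues):
    # format p-values in scientific notation
    formatted = ['{:0.1e}'.format(p) if p else p for p in pvalues]
    return ', '.join(str(p) for p in formatted)

# After: the block is extracted, and the comment becomes the method name.
def _format_pvalues(pvalues):
    """Format p-values in scientific notation."""
    return ['{:0.1e}'.format(p) if p else p for p in pvalues]

def report_pvalues(pvalues):
    return ', '.join(str(p) for p in _format_pvalues(pvalues))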

Complexity

Complex classes like ethically.we.weat.calc_all_weat() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
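Since ethically.we.weat.calc_all_weat() is a module-level function rather than a class, the closest analogue here is Extract Method: pull cohesive blocks of the function out into private helpers. The split sketched below is hypothetical (the helper names are not part of the current code) and relies on the module-level names that appear in the source further down:

def _resolve_weat_data(weat_data):
    """Normalize the weat_data argument ('caliskan', an index, or a tuple of
    indices) into a list of WEAT stimuli dicts."""
    if weat_data == 'caliskan':
        return WEAT_DATA
    if isinstance(weat_data, int):
        return WEAT_DATA[weat_data:weat_data + 1]
    if isinstance(weat_data, tuple):
        return [WEAT_DATA[index] for index in weat_data]
    return weat_data


def _warn_if_exact_pvalue_may_be_slow(weat_data, pvalue_kwargs):
    """Warn when the exact permutation test is likely to take a long time."""
    if not pvalue_kwargs or pvalue_kwargs.get('method') == PVALUE_DEFUALT_METHOD:
        max_word_set_len = max(len(stimuli[ws]['words'])
                               for stimuli in weat_data
                               for ws in WEAT_WORD_SETS)
        if max_word_set_len > PVALUE_EXACT_WARNING_LEN:
            warnings.warn('At least one stimulus has a word set bigger than {};'
                          ' consider passing method=\'approximate\' in'
                          ' pvalue_kwargs.'.format(PVALUE_EXACT_WARNING_LEN))

With those two helpers (plus a third for assembling the results DataFrame), calc_all_weat() shrinks to a short orchestration function, which should bring its complexity metrics down.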

"""
Compute WEAT score of a Word Embedding.

WEAT is a bias measurement method for word embedding,
which is inspired by the `IAT <https://en.wikipedia.org/wiki/Implicit-association_test>`_
(Implicit Association Test) for humans.
It measures the similarity between two sets of *target words*
(e.g., programmer, engineer, scientist, ... and nurse, teacher, librarian, ...)
and two sets of *attribute words* (e.g., man, male, ... and woman, female ...).
A p-value is calculated using a permutation-test.

Reference:
    - Caliskan, A., Bryson, J. J., & Narayanan, A. (2017).
      `Semantics derived automatically
      from language corpora contain human-like biases
      <http://opus.bath.ac.uk/55288/>`_.
      Science, 356(6334), 183-186.

.. important::
    The effect size and pvalue in the WEAT have
    entirely different meaning from those reported in IATs (original finding).
    Refer to the paper for more details.

Stimulus and original finding from:

- [0, 1, 2]
  A. G. Greenwald, D. E. McGhee, J. L. Schwartz,
  Measuring individual differences in implicit cognition:
  the implicit association test.,
  Journal of personality and social psychology 74, 1464 (1998).

- [3, 4]:
  M. Bertrand, S. Mullainathan, Are Emily and Greg more employable
  than Lakisha and Jamal? a field experiment on labor market discrimination,
  The American Economic Review 94, 991 (2004).

- [5, 6, 9]:
  B. A. Nosek, M. Banaji, A. G. Greenwald, Harvesting implicit group attitudes
  and beliefs from a demonstration web site.,
  Group Dynamics: Theory, Research, and Practice 6, 101 (2002).

- [7]:
  B. A. Nosek, M. R. Banaji, A. G. Greenwald, Math=male, me=female,
  therefore math≠me.,
  Journal of Personality and Social Psychology 83, 44 (2002).

- [8]
  P. D. Turney, P. Pantel, From frequency to meaning:
  Vector space models of semantics,
  Journal of Artificial Intelligence Research 37, 141 (2010).
"""
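# The quantities computed in this module follow Caliskan et al. (2017):
#   s(w, A, B)    = mean_{a in A} cos(w, a) - mean_{b in B} cos(w, b)
#   s(X, Y, A, B) = sum_{x in X} s(x, A, B) - sum_{y in Y} s(y, A, B)
#   effect size d = (mean_{x in X} s(x, A, B) - mean_{y in Y} s(y, A, B))
#                   / std_dev_{w in X union Y} s(w, A, B)
# and the p-value is a one-sided permutation test on the assignment of the
# target words to X and Y.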
# pylint: disable=C0301

import copy
import random
import warnings

import numpy as np
import pandas as pd
from mlxtend.evaluate import permutation_test

from ethically.consts import RANDOM_STATE
from ethically.utils import _warning_setup
from ethically.we.data import WEAT_DATA
from ethically.we.utils import (
    assert_gensim_keyed_vectors, cosine_similarities_by_words,
)


FILTER_BY_OPTIONS = ['model', 'data']
RESULTS_DF_COLUMNS = ['Target words', 'Attrib. words',
                      'Nt', 'Na', 's', 'd', 'p']
PVALUE_METHODS = ['exact', 'approximate']
PVALUE_DEFUALT_METHOD = 'exact'
ORIGINAL_DF_COLUMNS = ['original_' + key for key in ['N', 'd', 'p']]
WEAT_WORD_SETS = ['first_target', 'second_target',
                  'first_attribute', 'second_attribute']
PVALUE_EXACT_WARNING_LEN = 10

_warning_setup()


def _calc_association_target_attributes(model, target_word,
                                        first_attribute_words,
                                        second_attribute_words):
    """Calc the association s(w, A, B) of a target word with both attribute sets."""
    # pylint: disable=line-too-long

    assert_gensim_keyed_vectors(model)

    with warnings.catch_warnings():
        warnings.simplefilter('ignore', FutureWarning)
        first_mean = (cosine_similarities_by_words(model,
                                                   target_word,
                                                   first_attribute_words)
                      .mean())
        second_mean = (cosine_similarities_by_words(model,
                                                    target_word,
                                                    second_attribute_words)
                       .mean())

    return first_mean - second_mean


def _calc_association_all_targets_attributes(model, target_words,
                                             first_attribute_words,
                                             second_attribute_words):
    """Calc s(w, A, B) for every target word."""
    return [_calc_association_target_attributes(model, target_word,
                                                first_attribute_words,
                                                second_attribute_words)
            for target_word in target_words]


def _calc_weat_score(model,
                     first_target_words, second_target_words,
                     first_attribute_words, second_attribute_words):
    """Calc the WEAT test statistic s(X, Y, A, B)."""

    (first_associations,
     second_associations) = _calc_weat_associations(model,
                                                    first_target_words,
                                                    second_target_words,
                                                    first_attribute_words,
                                                    second_attribute_words)

    return sum(first_associations) - sum(second_associations)


def _calc_weat_pvalue(first_associations, second_associations,
                      method=PVALUE_DEFUALT_METHOD):
    """Calc the one-sided p-value of the permutation test."""

    if method not in PVALUE_METHODS:
        raise ValueError('method should be one of {}, {} was given'.format(
            PVALUE_METHODS, method))

    pvalue = permutation_test(first_associations, second_associations,
                              func='x_mean > y_mean',
                              method=method,
                              seed=RANDOM_STATE)  # seed has no effect for the exact method
    return pvalue


def _calc_weat_associations(model,
                            first_target_words, second_target_words,
                            first_attribute_words, second_attribute_words):
    """Calc the association lists of both target word sets with the attribute sets."""

    assert len(first_target_words) == len(second_target_words)
    assert len(first_attribute_words) == len(second_attribute_words)

    first_associations = _calc_association_all_targets_attributes(model,
                                                                  first_target_words,
                                                                  first_attribute_words,
                                                                  second_attribute_words)

    second_associations = _calc_association_all_targets_attributes(model,
                                                                   second_target_words,
                                                                   first_attribute_words,
                                                                   second_attribute_words)

    return first_associations, second_associations


def _filter_by_data_weat_stimuli(stimuli):
    """Filter WEAT stimuli words if there is a `remove` key.

    Some of the words from Caliskan et al. (2017) seem not to be used.

    Modify the stimuli object in place.
    """

    for group in stimuli:
        if 'remove' in stimuli[group]:
            words_to_remove = stimuli[group]['remove']
            stimuli[group]['words'] = [word for word in stimuli[group]['words']
                                       if word not in words_to_remove]


def _sample_if_bigger(seq, length):
    """Deterministically sample seq down to length if it is bigger."""
    random.seed(RANDOM_STATE)
    if len(seq) > length:
        seq = random.sample(seq, length)
    return seq


def _filter_by_model_weat_stimuli(stimuli, model):
    """Filter out WEAT stimuli words that do not exist in the model.

    Modify the stimuli object in place.
    """

    for group_category in ['target', 'attribute']:
        first_group = 'first_' + group_category
        second_group = 'second_' + group_category

        first_words = [word for word in stimuli[first_group]['words']
                       if word in model]
        second_words = [word for word in stimuli[second_group]['words']
                        if word in model]

        min_len = min(len(first_words), len(second_words))

        # sample to make the first and second word sets
        # of equal size
        first_words = _sample_if_bigger(first_words, min_len)
        second_words = _sample_if_bigger(second_words, min_len)

        first_words.sort()
        second_words.sort()

        stimuli[first_group]['words'] = first_words
        stimuli[second_group]['words'] = second_words


def _filter_weat_data(weat_data, model, filter_by):
    """Filter the WEAT stimuli in place, either by the data's `remove` key or by the model."""

    if filter_by not in FILTER_BY_OPTIONS:
        raise ValueError('filter_by should be one of {}, {} was given'.format(
            FILTER_BY_OPTIONS, filter_by))

    if filter_by == 'data':
        for stimuli in weat_data:
            _filter_by_data_weat_stimuli(stimuli)

    elif filter_by == 'model':
        for stimuli in weat_data:
            _filter_by_model_weat_stimuli(stimuli, model)
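# Each target/attribute argument below is a dict with a 'name' (used only for
# labelling the result row) and a 'words' list, for example
# {'name': 'Flowers', 'words': ['rose', 'tulip', 'daisy']} (illustrative values).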
def calc_single_weat(model,
                     first_target, second_target,
                     first_attribute, second_attribute,
                     with_pvalue=True, pvalue_kwargs=None):
    """
    Calc the WEAT result of a word embedding.

    :param model: Word embedding model of ``gensim.model.KeyedVectors``
    :param dict first_target: First target words list and its name
    :param dict second_target: Second target words list and its name
    :param dict first_attribute: First attribute words list and its name
    :param dict second_attribute: Second attribute words list and its name
    :param bool with_pvalue: Whether to calculate the p-value of the
                             WEAT score (might be computationally expensive)
    :return: WEAT result (score, effect size, Nt, Na and p-value)
    """

    if pvalue_kwargs is None:
        pvalue_kwargs = {}

    (first_associations,
     second_associations) = _calc_weat_associations(model,
                                                    first_target['words'],
                                                    second_target['words'],
                                                    first_attribute['words'],
                                                    second_attribute['words'])

    if first_associations and second_associations:
        score = sum(first_associations) - sum(second_associations)
        std_dev = np.std(first_associations + second_associations, ddof=0)
        effect_size = ((np.mean(first_associations) - np.mean(second_associations))
                       / std_dev)

        pvalue = None
        if with_pvalue:
            pvalue = _calc_weat_pvalue(first_associations,
                                       second_associations,
                                       **pvalue_kwargs)
    else:
        score, std_dev, effect_size, pvalue = None, None, None, None

    return {'Target words': '{} vs. {}'.format(first_target['name'],
                                               second_target['name']),
            'Attrib. words': '{} vs. {}'.format(first_attribute['name'],
                                                second_attribute['name']),
            's': score,
            'd': effect_size,
            'p': pvalue,
            'Nt': '{}x2'.format(len(first_target['words'])),
            'Na': '{}x2'.format(len(first_attribute['words']))}


def calc_weat_pleasant_unpleasant_attribute(model,
                                            first_target, second_target,
                                            with_pvalue=True, pvalue_kwargs=None):
    """
    Calc the WEAT result with pleasant vs. unpleasant attributes.

    :param model: Word embedding model of ``gensim.model.KeyedVectors``
    :param dict first_target: First target words list and its name
    :param dict second_target: Second target words list and its name
    :param bool with_pvalue: Whether to calculate the p-value of the
                             WEAT score (might be computationally expensive)
    :return: WEAT result (score, effect size, Nt, Na and p-value)
    """

    weat_data = {'first_attribute': copy.deepcopy(WEAT_DATA[0]['first_attribute']),
                 'second_attribute': copy.deepcopy(WEAT_DATA[0]['second_attribute']),
                 'first_target': first_target,
                 'second_target': second_target}

    _filter_by_model_weat_stimuli(weat_data, model)

    if pvalue_kwargs is None:
        pvalue_kwargs = {}

    return calc_single_weat(model,
                            **weat_data,
                            with_pvalue=with_pvalue, pvalue_kwargs=pvalue_kwargs)


def calc_all_weat(model, weat_data='caliskan', filter_by='model',
                  with_original_finding=False,
                  with_pvalue=True, pvalue_kwargs=None):
    """
    Calc the WEAT results of a word embedding on multiple cases.

    Note that the effect size and p-value in the WEAT have
    an entirely different meaning from those reported in IATs (original finding).
    Refer to the paper for more details.

    :param model: Word embedding model of ``gensim.model.KeyedVectors``
    :param weat_data: WEAT cases data.
                      - If `'caliskan'` (default), then all
                        the experiments from the original paper will be used.
                      - If an integer, then the specific experiment by index
                        from the original paper will be used.
                      - If a tuple of integers, then the specific experiments
                        by indices from the original paper will be used.

    :param str filter_by: Whether to filter the word lists
                          by the `model` (`'model'`)
                          or by the `remove` key in `weat_data` (`'data'`).
    :param bool with_original_finding: Show the original findings from the IAT
                                       alongside the WEAT results
    :param bool with_pvalue: Whether to calculate the p-value of the
                             WEAT results (might be computationally expensive)
    :return: :class:`pandas.DataFrame` of WEAT results
             (score, effect size, Nt, Na and p-value)
    """

    if weat_data == 'caliskan':
        weat_data = WEAT_DATA
    elif isinstance(weat_data, int):
        index = weat_data
        weat_data = WEAT_DATA[index:index + 1]
    elif isinstance(weat_data, tuple):
        weat_data = [WEAT_DATA[index] for index in weat_data]

    if (not pvalue_kwargs
            or pvalue_kwargs.get('method', PVALUE_DEFUALT_METHOD) == PVALUE_DEFUALT_METHOD):
        max_word_set_len = max(len(stimuli[ws]['words'])
                               for stimuli in weat_data
                               for ws in WEAT_WORD_SETS)
        if max_word_set_len > PVALUE_EXACT_WARNING_LEN:
            warnings.warn('At least one stimulus has a word set bigger'
                          ' than {}, and the computation might take a while.'
                          ' Consider using \'approximate\' as method'
                          ' for pvalue_kwargs.'.format(PVALUE_EXACT_WARNING_LEN))

    actual_weat_data = copy.deepcopy(weat_data)

    _filter_weat_data(actual_weat_data,
                      model,
                      filter_by)

    if weat_data != actual_weat_data:
        warnings.warn('Given weat_data was filtered by {}.'
                      .format(filter_by))

    results = []
    for stimuli in actual_weat_data:
        result = calc_single_weat(model,
                                  stimuli['first_target'],
                                  stimuli['second_target'],
                                  stimuli['first_attribute'],
                                  stimuli['second_attribute'],
                                  with_pvalue, pvalue_kwargs)

        # TODO: refactor - check before if one group is without words
        # because of the filtering
        if not all(group['words'] for group in stimuli.values()
                   if 'words' in group):
            # the keys must match the ones returned by calc_single_weat
            result['s'] = None
            result['d'] = None
            result['p'] = None

        result['stimuli'] = stimuli

        if with_original_finding:
            result.update({'original_' + k: v
                           for k, v in stimuli['original_finding'].items()})
        results.append(result)

    results_df = pd.DataFrame(results)
    results_df = results_df.replace('nan', None)
    results_df = results_df.fillna('')

    # if not results_df.empty:
    cols = RESULTS_DF_COLUMNS[:]
    if with_original_finding:
        cols += ORIGINAL_DF_COLUMNS
    if not with_pvalue:
        cols.remove('p')
    else:
        results_df['p'] = results_df['p'].apply(lambda pvalue: '{:0.1e}'.format(pvalue)  # pylint: disable=W0108
                                                if pvalue else pvalue)

    results_df = results_df[cols]
    results_df = results_df.round(4)

    return results_df
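For reference, a minimal usage sketch of the public entry point (not part of the reviewed file; it assumes gensim and its downloader data are available, and any gensim KeyedVectors embedding would work just as well):

import gensim.downloader

from ethically.we.weat import calc_all_weat

# Any KeyedVectors word embedding works; a pretrained GloVe model is used here
# only as an example.
model = gensim.downloader.load('glove-wiki-gigaword-100')

# Run the first two WEAT experiments from Caliskan et al. (2017) and show the
# original IAT findings next to the embedding's results.
results_df = calc_all_weat(model, weat_data=(0, 1),
                           with_original_finding=True,
                           with_pvalue=True)
print(results_df)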