|
1
|
|
|
import random |
|
2
|
|
|
from functools import reduce |
|
3
|
|
|
from types import MethodType |
|
4
|
|
|
from collections import Counter, OrderedDict |
|
5
|
|
|
|
|
6
|
|
|
import attr |
|
7
|
|
|
import numpy as np |
|
8
|
|
|
from scipy.stats import entropy as scipy_entropy |
|
9
|
|
|
|
|
10
|
|
|
|
|
11
|
|
|
import logging |
|
12
|
|
|
logger = logging.getLogger(__name__) |
|
13
|
|
|
|
|
14
|
|
|
|
|
15
|
|
|
from .definitions import SCALE_PLACEMENT, DISCREETIZATION |
|
16
|
|
|
|
|
17
|
|
|
|
|
18
|
|
|
def distr(ds_scheme, psm):
    """Return the fraction of datapoints that fall in each class of a discretization scheme.

    Each id in ``psm.datapoint_ids`` is an outlet id. It is mapped back to its outlet name
    through ``psm.scale`` (outlet name -> outlet id), then to a class label through
    ``ds_scheme.poster_name2ideology_label``.

    :param ds_scheme: a DiscreetizationScheme (iterable of ``(class_name, outlet_names)`` pairs)
    :param psm: object exposing ``scale`` (outlet name -> outlet id) and ``datapoint_ids``
    :return: list of floats, one per class in scheme order; all zeros when there are no datapoints
    :raises KeyError: if a datapoint id is missing from the scale or its outlet is not in any bin
    """
    outlet_id2name = {_id: name for name, _id in psm.scale.items()}
    name2label = ds_scheme.poster_name2ideology_label
    counts = Counter(name2label[outlet_id2name[x]] for x in psm.datapoint_ids)
    total = sum(counts.values())
    if total == 0:
        # No datapoints: report an all-zero distribution instead of dividing by zero.
        return [0.0 for _ in ds_scheme]
    return [counts[class_name] / float(total) for class_name, _ in ds_scheme]
|
24
|
|
|
|
|
25
|
|
|
def init_population(self, class_names, datapoint_ids, pool_size):
    """Seed ``self.population`` with ``pool_size`` random bin designs.

    ``k`` classes need ``k - 1`` separators, so the random vectors have length
    ``len(class_names) - 1``. The class names are then kept on the instance under the literal
    attribute name ``__class_names`` (no name mangling happens at module level), where
    ``evolve`` can fall back to them.
    """
    nb_separators = len(class_names) - 1
    self.population.init_random(datapoint_ids, pool_size, nb_separators)
    self.__class_names = class_names
|
29
|
|
|
|
|
30
|
|
|
def evolve(self, nb_generations, prob=0.35, class_names=None):
    """Evolve the population for ``nb_generations`` more generations and return its best scheme.

    :param int nb_generations: number of additional generations to run
    :param float prob: per-gene mutation probability forwarded to ``Population.evolve``
    :param class_names: names of the resulting classes; when falsy, the names stored by
        ``init_population`` (attribute ``__class_names``) are used
    :return: DiscreetizationScheme built from the first design of the population pool over
        ``self.scale``
    """
    self.population.evolve(nb_generations, prob=prob)
    names = class_names if class_names else self.__class_names
    best_design = self.population.pool[0]
    return DiscreetizationScheme.from_design(best_design, list(self.scale.items()), names)
|
36
|
|
|
|
|
37
|
|
|
|
|
38
|
|
|
# def evolve(self, class_names, datapoint_ids, vectors_length, pool_size, prob=0.2, max_generation=100): |
|
39
|
|
|
# self.population.evolve(datapoint_ids, vectors_length, pool_size, prob=prob, max_generation=max_generation) |
|
40
|
|
|
# return DiscreetizationScheme.from_design(self.population.pool[0], [(k,v) for k,v in self.scale.items()], class_names) |
|
41
|
|
|
|
|
42
|
|
|
class PoliticalSpectrumManager(object):
    """Singleton factory for the shared PoliticalSpectrum object.

    Calling ``PoliticalSpectrumManager()`` never yields a manager instance. The first call builds
    a PoliticalSpectrum from SCALE_PLACEMENT and DISCREETIZATION, attaches a Population to it and
    binds the module-level ``init_population`` / ``evolve`` functions onto it as methods. Every
    call returns that cached object. Because ``__new__`` returns an object of a different class,
    Python does not run ``__init__`` on it.
    """
    _instance = None

    def __new__(cls):
        if not cls._instance:
            # Cache first, then decorate the cached object in place.
            cls._instance = spectrum = PoliticalSpectrum(SCALE_PLACEMENT, DISCREETIZATION)
            spectrum.population = Population(spectrum)
            spectrum.init_population = MethodType(init_population, spectrum)
            spectrum.evolve = MethodType(evolve, spectrum)
        return cls._instance
|
51
|
|
|
|
|
52
|
|
|
|
|
53
|
|
|
def _correct_order(instance, attribute, converted_value): |
|
54
|
|
|
instance._schemes = {k: OrderedDict([(class_label, [_ for _ in instance.scale if _ in outlet_names]) for class_label, outlet_names in scheme.items()]) |
|
55
|
|
|
for k, scheme in converted_value.items()} |
|
56
|
|
|
|
|
57
|
|
|
def build_schemes(schemes_dict, scale=None):
    """Yield ``(scheme_name, DiscreetizationScheme)`` pairs from raw scheme definitions.

    Each value of ``schemes_dict`` is either:
      * a dict with a ``'design'`` key (integer bin separators) and an optional
        ``'class_names'`` key, turned into a scheme with ``DiscreetizationScheme.from_design``
        over ``scale`` (default class names are used when ``'class_names'`` is absent); or
      * anything else, passed unchanged to ``DiscreetizationScheme``.

    :param dict schemes_dict: scheme name -> scheme definition
    :param scale: list of ``(outlet_name, outlet_id)`` tuples the designs index into
    :raises ValueError: if a design contains a separator that is not an integer
    """
    for scheme_name, scheme_data in schemes_dict.items():
        if isinstance(scheme_data, dict):
            design = scheme_data['design']
            # bool is an int subclass but never a meaningful separator; numpy ints are accepted.
            bad = [x for x in design if isinstance(x, bool) or not isinstance(x, (int, np.integer))]
            if bad:
                raise ValueError("Scheme '{}': design separators must be integers; got {}".format(scheme_name, bad))
            yield scheme_name, DiscreetizationScheme.from_design(design, scale, scheme_data.get('class_names'))
        else:
            yield scheme_name, DiscreetizationScheme(scheme_data)
|
65
|
|
|
|
|
66
|
|
|
@attr.s
class PoliticalSpectrum(object):
    """Ordered political scale of outlets plus a registry of named discretization schemes.

    ``scale`` maps outlet name -> outlet id, in scale order. ``_schemes`` maps a scheme name to a
    DiscreetizationScheme, and ``_cur`` names the active one, which the properties below read.
    ``datapoint_ids`` holds one outlet id per datapoint and feeds the class distributions.
    """
    scale = attr.ib(init=True, converter=lambda x: OrderedDict(x), repr=True)
    # NOTE(review): designs are resolved against the module-level SCALE_PLACEMENT rather than this
    # instance's ``scale`` (a converter cannot see the instance) -- confirm the two always agree.
    _schemes = attr.ib(init=True, converter=lambda x: {name: scheme for name, scheme in build_schemes(x, scale=SCALE_PLACEMENT)})
    _cur = attr.ib(init=True, default='legacy-scheme', converter=str, repr=True)
    # attr.Factory gives every instance its own list; a literal ``[]`` default is one shared object.
    datapoint_ids = attr.ib(init=True, default=attr.Factory(list), repr=False)
    # Reverse of ``scale``: outlet id -> outlet name.
    _outlet_id2outlet_name = attr.ib(init=False, default=attr.Factory(lambda self: OrderedDict((v, k) for k, v in self.scale.items()), takes_self=True), repr=False)

    def _key(self, scheme):
        """Derive a scheme name by joining its class names with '-'."""
        return '-'.join(str(class_name) for class_name, _ in scheme)

    def _unknown_scheme_error(self, requested):
        """Build the KeyError raised when ``requested`` is not a registered scheme name."""
        return KeyError("The schemes implemented are [{}]. Requested '{}' instead.".format(', '.join(self._schemes.keys()), requested))

    def _register(self, name, scheme, action):
        """Store a DiscreetizationScheme copy of ``scheme`` under ``name``, make it current and log it.

        :param str action: leading phrase of the log message
        """
        self._schemes[name] = DiscreetizationScheme(scheme)
        self._cur = name
        logger.info("{} '{}' with doc classes [{}].".format(action, name, ', '.join(class_name for class_name, _ in scheme)))
        if self.datapoint_ids:
            logger.info("Classes' distribution: [{}]".format(', '.join(['{:.2f}'.format(x) for x in distr(self._schemes[name], self)])))

    def __getitem__(self, item):
        """Return the scheme registered as ``item``.

        :raises KeyError: listing the known scheme names when ``item`` is not registered
        """
        if item not in self._schemes:
            raise self._unknown_scheme_error(item)
        return self._schemes[item]

    def __iter__(self):
        """Iterate over ``(scheme_name, DiscreetizationScheme)`` pairs."""
        return ((k, v) for k, v in self._schemes.items())

    @property
    def discreetization_scheme(self):
        """The currently active DiscreetizationScheme."""
        return self._schemes[self._cur]

    @discreetization_scheme.setter
    def discreetization_scheme(self, scheme):
        """Activate a registered scheme, or register a new one and activate it.

        Accepts:
          * a str: name of an already registered scheme;
          * a list ``[name, scheme_data]``: registered under ``name``;
          * a list of ``(class_name, outlet_names)`` pairs: registered under a name derived from
            its class names;
          * a DiscreetizationScheme: registered under a name derived from its class names.

        :raises KeyError: for an unknown scheme name
        :raises ValueError: for an empty list or any other input type
        """
        if isinstance(scheme, str):
            if scheme not in self._schemes:
                raise self._unknown_scheme_error(scheme)
            self._cur = scheme
        elif isinstance(scheme, list):
            if not scheme:
                raise ValueError("Cannot register an empty discreetization scheme")
            if isinstance(scheme[0], str):  # [scheme_name, scheme_data]
                name, scheme = scheme[0], scheme[1]
            else:  # scheme_data only: derive the name from its class names
                name = self._key(scheme)
            self._register(name, scheme, "Registered new discreetization scheme")
        elif isinstance(scheme, DiscreetizationScheme):
            self._register(self._key(scheme), scheme, "Registered discreetization scheme")
        else:
            raise ValueError("Input should be either a string or a DiscreetizationScheme or a [str, DiscreetizationScheme] list (1st element is the name/key)")

    def distribution(self, scheme):
        """Class distribution of ``datapoint_ids`` under an arbitrary ``scheme``."""
        return distr(scheme, self)

    @property
    def class_distribution(self):
        """Class distribution of ``datapoint_ids`` under the active scheme."""
        return distr(self._schemes[self._cur], self)

    @property
    def poster_id2ideology_label(self):
        """Outlet id -> class name, for every outlet binned by the active scheme."""
        return OrderedDict([(self.scale[name], class_label) for class_label, outlet_names in self._schemes[self._cur] for name in outlet_names])

    @property
    def class_names(self):
        """The normalized class names matching the discrete bins applied on the 10-point scale of ideological consistency"""
        return self._schemes[self._cur].class_names
|
150
|
|
|
|
|
151
|
|
|
|
|
152
|
|
|
@attr.s(cmp=True, repr=True, slots=True)
class DiscreetizationScheme(object):
    """Ordered assignment of outlet names to named classes ("bins").

    Built from an iterable of ``(class_name, outlet_names)`` pairs. Iterating a scheme yields
    the same pairs, in order.
    """
    _bins = attr.ib(init=True, converter=lambda x: OrderedDict([(class_name, outlet_names_list) for class_name, outlet_names_list in x]), repr=True, cmp=True)
    # Reverse lookup outlet name -> class name, derived from ``_bins``.
    poster_name2ideology_label = attr.ib(init=False, default=attr.Factory(lambda self: self._name2label(self._bins), takes_self=True), repr=False)

    @staticmethod
    def _name2label(bins):
        """Invert a class name -> outlet names mapping into outlet name -> class name."""
        return OrderedDict([(name, class_label) for class_label, outlet_names in bins.items() for name in outlet_names])

    def __iter__(self):
        for class_name, items in self._bins.items():
            yield class_name, items

    def __str__(self):
        return "[{}]".format(',\n'.join("('{}', [{}])".format(class_name, ', '.join("'{}'".format(name) for name in outlets_list)) for class_name, outlets_list in self))

    @property
    def class_names(self):
        """Class names, in bin order."""
        return list(self._bins.keys())

    @class_names.setter
    def class_names(self, class_names):
        """Rename the classes, keeping each bin's outlets and the bin order.

        :raises RuntimeError: if the number of names differs from the number of bins
        """
        class_names = list(class_names)
        if len(class_names) != len(self._bins):
            raise RuntimeError("Please give equal number of class names ({} given) as the number of defined bins ({})".format(len(class_names), len(self._bins)))
        # Pair the new names with the outlet lists (values), not the (old_name, outlets) items.
        self._bins = OrderedDict(zip(class_names, self._bins.values()))
        # Keep the reverse lookup consistent with the renamed bins.
        self.poster_name2ideology_label = self._name2label(self._bins)

    @classmethod
    def from_design(cls, design, scale, class_names=None):
        """Build a scheme by cutting ``scale`` at the separator indices in ``design``.

        :param design: strictly increasing separator indices; ``k`` separators give ``k + 1`` classes
        :param list of tuples scale: ``(outlet_name, outlet_id)`` pairs in scale order
        :param class_names: one name per class; defaults to ``bin_0``, ``bin_1``, ...
        :return: DiscreetizationScheme
        """
        bins = list(Bins.from_design(design, scale))
        if class_names:
            class_names = list(class_names)
            if len(class_names) != len(bins):
                # zip() below truncates to the shorter side, silently dropping bins or names.
                logger.warning("Got {} class names for {} bins; the surplus is ignored".format(len(class_names), len(bins)))
        else:
            class_names = ['bin_{}'.format(x) for x in range(len(bins))]
        return cls(list(zip(class_names, bins)))

    def to_design(self):
        """Return the separator indices that reproduce this scheme's bin sizes.

        Each separator is the cumulative number of outlets in all preceding bins, e.g. bins of
        sizes 2, 3, 4, 1 give ``[2, 5, 9]``. A single bin gives ``[]``.
        """
        seps = []
        boundary = 0
        for outlets in list(self._bins.values())[:-1]:
            boundary += len(outlets)
            seps.append(boundary)
        return seps
|
195
|
|
|
|
|
196
|
|
|
def _check_nb_bins(instance, attribute, value): |
|
197
|
|
|
if not 0 < len(value) <= 100: |
|
198
|
|
|
raise ValueError("Resulted in {} bins but they should be in [1,100]".format(len(value))) |
|
199
|
|
|
if sum(len(x) for x in value) != len(reduce(lambda x,y: x.union(y), [set(_) for _ in value])): |
|
200
|
|
|
raise ValueError("Found same item in multiple bins") |
|
201
|
|
|
|
|
202
|
|
|
|
|
203
|
|
|
@attr.s(cmp=True, repr=True, str=True, slots=True, frozen=True)
class Bins(object):
    """Immutable, validated list of bins, each a list of outlet names.

    NOTE(review): with ``str=True`` attrs generates a ``__str__`` identical to ``__repr__``, which
    appears to replace the tabular ``__str__`` defined below -- confirm whether ``str=False`` was
    intended.
    """
    bins = attr.ib(init=True, cmp=True, repr=True, validator=_check_nb_bins)

    def __getitem__(self, item):
        """Return bin ``item``; this also lets ``list(bins_obj)`` yield the bins in order."""
        return self.bins[item]

    @classmethod
    def from_design(cls, design, scale_placement):
        """Split ``scale_placement`` into bins at the separator indices in ``design``.

        :param design: strictly increasing separator indices (validated by BinDesign)
        :param list of tuples scale_placement: the ordering of outlets from liberal to conservative
            that the design bins (discreetizes to classes); only each tuple's first element (the
            outlet name) is kept
        :return: Bins
        """
        ranges = BinDesign(list(design)).ranges(len(scale_placement))
        return cls([[scale_placement[i][0] for i in a_range] for a_range in ranges])

    def __str__(self):
        """Render the bins side by side as left-aligned columns, one outlet name per row."""
        nb_rows = max(len(a_bin) for a_bin in self.bins)
        return '\n'.join(' '.join(self.__str_els(a_bin, row) for a_bin in self.bins) for row in range(nb_rows))

    def __str_els(self, a_bin, index):
        """Return cell ``index`` of a column, padded to the width of that column's longest name."""
        width = max([len(name) for name in a_bin] or [0])
        if index < len(a_bin):
            return a_bin[index] + ' ' * (width - len(a_bin[index]))
        # Row past the end of a shorter bin: a blank cell of the column's width keeps later columns aligned.
        return ' ' * width
|
229
|
|
|
|
|
230
|
|
|
|
|
231
|
|
|
def _check_design(instance, attribute, value): |
|
232
|
|
|
if not 0 < len(value) < 100: |
|
233
|
|
|
raise ValueError("Resulted in {} bins but they should be in [1,100]".format(len(value)+1)) |
|
234
|
|
|
for i in range(1, len(value)): |
|
235
|
|
|
if value[i] <= value[i - 1]: |
|
236
|
|
|
raise ValueError("Invalid design list. Each element should be greater than the previous one; prev: {}, current: {}".format(value[i - 1], value[i])) |
|
237
|
|
|
|
|
238
|
|
|
|
|
239
|
|
|
@attr.s(cmp=True, repr=True, str=True)
class BinDesign(object):
    """Strictly increasing separator indices that cut a sequence into contiguous bins.

    With ``seps == [s1, ..., sk]`` the bins cover indices ``[0, s1)``, ``[s1, s2)``, ...,
    ``[sk, n)``.
    """
    seps = attr.ib(init=True, converter=list, validator=_check_design, cmp=True, repr=True)

    def ranges(self, nb_elements):
        """Yield one ``range`` of element indices per bin, for a sequence of ``nb_elements`` items.

        This is a generator, so the check below only runs on the first ``next()``.

        :raises ValueError: if the last separator is not below ``nb_elements``
        """
        last = self.seps[-1]
        if last >= nb_elements:
            raise ValueError("Last bin starts from index {}, but requested to build range indices for {} elements".format(last, nb_elements))
        boundaries = [0] + list(self.seps) + [nb_elements]
        for start, stop in zip(boundaries, boundaries[1:]):
            yield range(start, stop)

    def __getitem__(self, item):
        """Index (or slice) the separator list."""
        separators = self.seps
        return separators[item]

    def __len__(self):
        """Number of separators (one fewer than the number of bins)."""
        separators = self.seps
        return len(separators)
|
255
|
|
|
|
|
256
|
|
|
|
|
257
|
|
|
|
|
258
|
|
|
|
|
259
|
|
|
@attr.s
class Population(object):
    """Genetic-algorithm population of BinDesign candidates for a PoliticalSpectrum.

    A design's fitness is the Jensen-Shannon distance between the class distribution it induces
    on ``datapoint_ids`` and ``ideal`` (uniform over its classes); lower is better. After each
    generation ``replacement`` keeps ``pool`` sorted best-first.
    """
    psm = attr.ib(init=True)
    pool = attr.ib(init=False, repr=True, cmp=True)
    # Number of outlets on the scale; separators are drawn from range(1, _nb_items_to_bin).
    _nb_items_to_bin = attr.ib(init=False, default=attr.Factory(lambda self: len(self.psm.scale), takes_self=True))

    def create_random(self, vector_length, scale_length):
        """Return ``vector_length`` distinct separators drawn from ``range(1, scale_length)``, sorted ascending.

        :raises IndexError: (from ``random.choice``) when ``vector_length >= scale_length``
        """
        candidates = list(range(1, scale_length))
        picked = []
        for _ in range(vector_length):
            choice = random.choice(candidates)
            candidates.remove(choice)
            picked.append(choice)
        return sorted(picked)

    def distr(self, design):
        """Return the fraction of ``datapoint_ids`` falling into each class induced by ``design``."""
        ds = DiscreetizationScheme.from_design(design, list(self.psm.scale.items()))
        outlet_id2name = {_id: name for name, _id in self.psm.scale.items()}
        counts = Counter(ds.poster_name2ideology_label[outlet_id2name[x]] for x in self.datapoint_ids)
        total = float(sum(counts.values()))
        return [counts[class_name] / total for class_name, _ in ds]

    def compute_fitness(self, design):
        """Compute, cache on ``design.fitness`` and return the design's fitness (the smaller the better)."""
        design.fitness = jensen_shannon_distance(self.distr(design), self.ideal)
        return design.fitness

    def init_random(self, datapoint_ids, pool_size, vector_length):
        """Reset the population to ``pool_size`` random designs of ``vector_length`` separators.

        Also stores ``datapoint_ids`` here and on ``psm``, sets the uniform target distribution
        over ``vector_length + 1`` classes and resets the generation counter.
        """
        self.datapoint_ids = datapoint_ids
        self.psm.datapoint_ids = datapoint_ids
        self.pool = [BinDesign(self.create_random(vector_length, self._nb_items_to_bin)) for _ in range(pool_size)]
        self.ideal = [float(1) / (vector_length + 1)] * (vector_length + 1)
        self.sorted = False
        self._generation_counter = 0

    def selection(self):
        """Select every pool member (all indices) as a parent."""
        self._inds = list(range(len(self.pool)))

    def operators(self, prob=0.2):
        """Mutate every design in the pool; keep in ``self._new`` only children that differ from their parent."""
        self._new = [child for child in (self.mutate(design, self._nb_items_to_bin, prob=prob) for design in self.pool) if child]

    def mutate(self, design, elements, prob=0.2):
        """Return a mutated copy of ``design``, or None when no separator moved.

        :param BinDesign design: parent design
        :param int elements: number of outlets on the scale (separators stay below it)
        :param float prob: per-separator probability of moving it
        """
        child = BinDesign(list(self._gen_genes(design, elements, prob=prob)))
        if child == design:
            return None
        return child

    def replacement(self):
        """Merge parents and children, keep the ``len(pool)`` fittest, sorted best (smallest) first."""
        def fitness(design):
            # Reuse the cached value; getattr(x, 'fitness', compute(x)) would evaluate
            # compute(x) eagerly and recompute every design on every generation.
            if hasattr(design, 'fitness'):
                return design.fitness
            return self.compute_fitness(design)

        self.pool = sorted(self.pool + self._new, key=fitness)[:len(self.pool)]
        self._generation_counter += 1
        self.sorted = True

    def evolve(self, nb_generations, prob=0.35):
        """Run ``nb_generations`` rounds of selection, mutation and replacement."""
        self._init_condition(nb_generations)
        while not self.condition():
            self.selection()
            self.operators(prob=prob)
            self.replacement()

    def _init_condition(self, nb_generations):
        """Set the generation count at which the current ``evolve`` call stops."""
        self._max_generation = self._generation_counter + nb_generations

    def condition(self):
        """Return True when evolution should stop, i.e. once the maximum generation count is reached."""
        return self._generation_counter >= self._max_generation

    def _gen_genes(self, design, elements, prob=0.5):
        """Yield the separators of a mutated copy of ``design``, keeping them strictly increasing.

        Each separator may move (with probability ``prob``) to another index strictly between the
        already-generated previous separator (or 0) and the parent's next separator (or
        ``elements``).
        """
        lower = 1
        last = len(design) - 1
        for i in range(len(design)):
            upper = design[i + 1] if i < last else elements
            gene = self._toss(design[i], self._new_index(list(range(lower, upper)), design[i]), prob, random.random())
            yield gene
            lower = gene + 1

    def _new_index(self, indices_list, current_index):
        """Pick uniformly an element of ``indices_list`` other than ``current_index``.

        Returns ``current_index`` when it is the only candidate. ``current_index`` must be in
        ``indices_list`` (``list.index`` raises ValueError otherwise).
        """
        if len(indices_list) == 1:
            return current_index
        c = random.randint(1, len(indices_list) - 1) - 1
        # c spans len-1 slots; skip over the current index's position.
        if c >= indices_list.index(current_index):
            return indices_list[c + 1]
        return indices_list[c]

    def _toss(self, v1, v2, prob, toss):
        """Return ``v2`` when ``toss < prob``, else ``v1``."""
        if toss < prob:
            return v2
        return v1
|
378
|
|
|
|
|
379
|
|
|
|
|
380
|
|
|
def jensen_shannon_distance(p, q):
    """Return the Jensen-Shannon distance (square root of the JS divergence) between two distributions.

    Both inputs go through ``scipy.stats.entropy``, which normalizes them and uses the natural
    logarithm.
    """
    p_arr, q_arr = np.array(p), np.array(q)
    midpoint = (p_arr + q_arr) / 2
    js_divergence = (scipy_entropy(p_arr, midpoint) + scipy_entropy(q_arr, midpoint)) / 2
    return np.sqrt(js_divergence)
|
388
|
|
|
|