|
1
|
|
|
# Copyright (c) 2014, Salesforce.com, Inc. All rights reserved. |
|
2
|
|
|
# |
|
3
|
|
|
# Redistribution and use in source and binary forms, with or without |
|
4
|
|
|
# modification, are permitted provided that the following conditions |
|
5
|
|
|
# are met: |
|
6
|
|
|
# |
|
7
|
|
|
# - Redistributions of source code must retain the above copyright |
|
8
|
|
|
# notice, this list of conditions and the following disclaimer. |
|
9
|
|
|
# - Redistributions in binary form must reproduce the above copyright |
|
10
|
|
|
# notice, this list of conditions and the following disclaimer in the |
|
11
|
|
|
# documentation and/or other materials provided with the distribution. |
|
12
|
|
|
# - Neither the name of Salesforce.com nor the names of its contributors |
|
13
|
|
|
# may be used to endorse or promote products derived from this |
|
14
|
|
|
# software without specific prior written permission. |
|
15
|
|
|
# |
|
16
|
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS |
|
17
|
|
|
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT |
|
18
|
|
|
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS |
|
19
|
|
|
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE |
|
20
|
|
|
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, |
|
21
|
|
|
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, |
|
22
|
|
|
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS |
|
23
|
|
|
# OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND |
|
24
|
|
|
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR |
|
25
|
|
|
# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE |
|
26
|
|
|
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
|
27
|
|
|
|
|
28
|
|
|
import os |
|
29
|
|
|
import shutil |
|
30
|
|
|
import numpy |
|
31
|
|
|
import scipy |
|
32
|
|
|
import scipy.misc |
|
33
|
|
|
import scipy.ndimage |
|
34
|
|
|
from distributions.dbg.random import sample_discrete, sample_discrete_log |
|
35
|
|
|
from distributions.lp.models import nich |
|
36
|
|
|
from distributions.lp.clustering import PitmanYor |
|
37
|
|
|
from distributions.lp.mixture import MixtureIdTracker |
|
38
|
|
|
from distributions.io.stream import json_stream_load, json_stream_dump |
|
39
|
|
|
from multiprocessing import Process |
|
40
|
|
|
import parsable |
|
41
|
|
|
parsable = parsable.Parsable() |
|
42
|
|
|
|
|
43
|
|
|
|
|
44
|
|
|
ROOT = os.path.dirname(os.path.abspath(__file__))
# Input samples are written under data/, rendered images under results/.
DATA = os.path.join(ROOT, 'data')
RESULTS = os.path.join(ROOT, 'results')
# Gzipped json stream of (x, y) sample pairs produced by create_dataset().
SAMPLES = os.path.join(DATA, 'samples.json.gz')
# Source image; unpacked as a 2-d (width, height) array below, so it is
# assumed grayscale.  NOTE(review): scipy.misc.imread is deprecated/removed
# in modern scipy -- confirm the pinned scipy version still provides it.
IMAGE = scipy.misc.imread(os.path.join(ROOT, 'fox.png'))
SAMPLE_COUNT = 10000  # default number of points sampled from the image
PASSES = 10  # default number of sweeps for the iterative algorithms
EMPTY_GROUP_COUNT = 10  # spare empty groups kept available in each mixture
|
52
|
|
|
|
|
53
|
|
|
|
|
54
|
|
|
# Make sure the data and results directories exist before any command runs.
for _directory in (DATA, RESULTS):
    if not os.path.exists(_directory):
        os.makedirs(_directory)
|
57
|
|
|
|
|
58
|
|
|
|
|
59
|
|
|
class ImageModel(object):
    """Nonparametric 2-d mixture: a Pitman-Yor partition whose components
    model x and y independently with normal-inverse-chi-squared features."""

    def __init__(self):
        # Pitman-Yor prior over the partition of samples into components.
        self.clustering = PitmanYor.from_dict(dict(alpha=100.0, d=0.1))
        # Shared NICH prior, reused for both the x and y coordinates.
        self.feature = nich.Shared.from_dict(
            dict(mu=0.0, kappa=0.1, sigmasq=0.01, nu=1.0))

    class Mixture(object):
        """Posterior state: one group per component plus empty spares,
        with a tracker mapping packed group ids to stable global ids."""

        def __init__(self):
            self.clustering = PitmanYor.Mixture()
            self.feature_x = nich.Mixture()
            self.feature_y = nich.Mixture()
            self.id_tracker = MixtureIdTracker()

        def __len__(self):
            # The clustering mixture is the authority on group count.
            return len(self.clustering)

        def init(self, model, empty_group_count=EMPTY_GROUP_COUNT):
            """Reset to `empty_group_count` empty groups under `model`."""
            assert empty_group_count >= 1
            group_counts = [0] * empty_group_count
            self.clustering.init(model.clustering, group_counts)
            assert len(self.clustering) == len(group_counts)
            self.id_tracker.init(len(group_counts))

            self.feature_x.clear()
            self.feature_y.clear()
            for _ in xrange(empty_group_count):
                self.feature_x.add_group(model.feature)
                self.feature_y.add_group(model.feature)
            self.feature_x.init(model.feature)
            self.feature_y.init(model.feature)

        def score_value(self, model, xy, scores):
            """Accumulate per-group log predictive scores for one sample."""
            px, py = xy
            self.clustering.score_value(model.clustering, scores)
            self.feature_x.score_value(model.feature, px, scores)
            self.feature_y.score_value(model.feature, py, scores)

        def add_value(self, model, groupid, xy):
            """Assign one sample to `groupid`, growing a spare group when
            the clustering reports a fresh component was created."""
            px, py = xy
            created = self.clustering.add_value(model.clustering, groupid)
            self.feature_x.add_value(model.feature, groupid, px)
            self.feature_y.add_value(model.feature, groupid, py)
            if created:
                self.feature_x.add_group(model.feature)
                self.feature_y.add_group(model.feature)
                self.id_tracker.add_group()

        def remove_value(self, model, groupid, xy):
            """Unassign one sample, dropping its group if it became empty."""
            px, py = xy
            emptied = self.clustering.remove_value(
                model.clustering,
                groupid)
            self.feature_x.remove_value(model.feature, groupid, px)
            self.feature_y.remove_value(model.feature, groupid, py)
            if emptied:
                self.feature_x.remove_group(model.feature, groupid)
                self.feature_y.remove_group(model.feature, groupid)
                self.id_tracker.remove_group(groupid)
|
124
|
|
|
|
|
125
|
|
|
|
|
126
|
|
|
def sample_from_image(image, sample_count):
    """Yield `sample_count` (x, y) points in [-1, 1]^2 drawn from the image,
    with darker pixels carrying more probability mass."""
    # Flip intensities so ink (low pixel values) becomes high mass.
    inverted = -1.0 * image
    inverted -= inverted.min()
    row_pmf = inverted.sum(axis=1)  # marginal distribution over rows
    col_pmfs = inverted.copy()      # per-row conditional over columns
    for col_pmf in col_pmfs:
        col_pmf /= (col_pmf.sum() + 1e-8)  # guard against all-zero rows

    rows, cols = inverted.shape
    x_scale = 2.0 / (rows - 1)
    y_scale = 2.0 / (cols - 1)

    for _ in xrange(sample_count):
        i = sample_discrete(row_pmf)
        j = sample_discrete(col_pmfs[i])
        yield (i * x_scale - 1.0, j * y_scale - 1.0)
|
141
|
|
|
|
|
142
|
|
|
|
|
143
|
|
|
def synthesize_image(model, mixture):
    """Render the mixture's predictive density over the IMAGE grid as a
    uint8 image, with high density drawn dark like the source image."""
    rows, cols = IMAGE.shape
    density = numpy.zeros((rows, cols))
    scores = numpy.zeros(len(mixture), dtype=numpy.float32)
    x_scale = 2.0 / (rows - 1)
    y_scale = 2.0 / (cols - 1)
    for i in xrange(rows):
        px = i * x_scale - 1.0
        for j in xrange(cols):
            py = j * y_scale - 1.0
            mixture.score_value(model, (px, py), scores)
            # Exponentiate in place, then total mass across all groups.
            density[i, j] = numpy.exp(scores, out=scores).sum()

    # Normalize to [0, 1], then invert and stretch to [0, 255].
    density /= density.max()
    density -= 1.0
    density *= -255
    return density.astype(numpy.uint8)
|
160
|
|
|
|
|
161
|
|
|
|
|
162
|
|
|
def visualize_dataset(samples):
    """Histogram (x, y) samples back onto the IMAGE grid, smooth, and
    render as a uint8 image with dense regions drawn dark."""
    rows, cols = IMAGE.shape
    x_scale = 2.0 / (rows - 1)
    y_scale = 2.0 / (cols - 1)
    counts = numpy.zeros((rows, cols))
    for sx, sy in samples:
        # Map [-1, 1] coordinates back to pixel indices.
        i = int(round((sx + 1.0) / x_scale))
        j = int(round((sy + 1.0) / y_scale))
        counts[i, j] += 1
    blurred = scipy.ndimage.gaussian_filter(counts, sigma=1)
    # Invert so dense regions are dark, then shift into [0, 255].
    blurred *= -255.0 / blurred.max()
    blurred -= blurred.min()
    return blurred.astype(numpy.uint8)
|
175
|
|
|
|
|
176
|
|
|
|
|
177
|
|
|
@parsable.command |
|
178
|
|
|
def create_dataset(sample_count=SAMPLE_COUNT): |
|
179
|
|
|
''' |
|
180
|
|
|
Extract dataset from image. |
|
181
|
|
|
''' |
|
182
|
|
|
scipy.misc.imsave(os.path.join(RESULTS, 'original.png'), IMAGE) |
|
183
|
|
|
print 'sampling {} points from image'.format(sample_count) |
|
184
|
|
|
samples = sample_from_image(IMAGE, sample_count) |
|
185
|
|
|
json_stream_dump(samples, SAMPLES) |
|
186
|
|
|
image = visualize_dataset(json_stream_load(SAMPLES)) |
|
187
|
|
|
scipy.misc.imsave(os.path.join(RESULTS, 'samples.png'), image) |
|
188
|
|
|
|
|
189
|
|
|
|
|
190
|
|
|
@parsable.command |
|
191
|
|
|
def compress_sequential(): |
|
192
|
|
|
''' |
|
193
|
|
|
Compress image via sequential initialization. |
|
194
|
|
|
''' |
|
195
|
|
|
assert os.path.exists(SAMPLES), 'first create dataset' |
|
196
|
|
|
print 'sequential start' |
|
197
|
|
|
model = ImageModel() |
|
198
|
|
|
mixture = ImageModel.Mixture() |
|
199
|
|
|
mixture.init(model) |
|
200
|
|
|
scores = numpy.zeros(1, dtype=numpy.float32) |
|
201
|
|
|
|
|
202
|
|
|
for xy in json_stream_load(SAMPLES): |
|
203
|
|
|
scores.resize(len(mixture)) |
|
204
|
|
|
mixture.score_value(model, xy, scores) |
|
205
|
|
|
groupid = sample_discrete_log(scores) |
|
206
|
|
|
mixture.add_value(model, groupid, xy) |
|
207
|
|
|
|
|
208
|
|
|
print 'sequential found {} components'.format(len(mixture)) |
|
209
|
|
|
image = synthesize_image(model, mixture) |
|
210
|
|
|
scipy.misc.imsave(os.path.join(RESULTS, 'sequential.png'), image) |
|
211
|
|
|
|
|
212
|
|
|
|
|
213
|
|
View Code Duplication |
@parsable.command |
|
|
|
|
|
|
214
|
|
|
def compress_gibbs(passes=PASSES): |
|
215
|
|
|
''' |
|
216
|
|
|
Compress image via gibbs sampling. |
|
217
|
|
|
''' |
|
218
|
|
|
assert passes >= 0 |
|
219
|
|
|
assert os.path.exists(SAMPLES), 'first create dataset' |
|
220
|
|
|
print 'prior+gibbs start {} passes'.format(passes) |
|
221
|
|
|
model = ImageModel() |
|
222
|
|
|
mixture = ImageModel.Mixture() |
|
223
|
|
|
mixture.init(model) |
|
224
|
|
|
scores = numpy.zeros(1, dtype=numpy.float32) |
|
225
|
|
|
assignments = {} |
|
226
|
|
|
|
|
227
|
|
|
for i, xy in enumerate(json_stream_load(SAMPLES)): |
|
228
|
|
|
scores.resize(len(mixture)) |
|
229
|
|
|
mixture.clustering.score_value(model.clustering, scores) |
|
230
|
|
|
groupid = sample_discrete_log(scores) |
|
231
|
|
|
mixture.add_value(model, groupid, xy) |
|
232
|
|
|
assignments[i] = mixture.id_tracker.packed_to_global(groupid) |
|
233
|
|
|
|
|
234
|
|
|
print 'prior+gibbs init with {} components'.format(len(mixture)) |
|
235
|
|
|
|
|
236
|
|
|
for _ in xrange(passes): |
|
237
|
|
|
for i, xy in enumerate(json_stream_load(SAMPLES)): |
|
238
|
|
|
groupid = mixture.id_tracker.global_to_packed(assignments[i]) |
|
239
|
|
|
mixture.remove_value(model, groupid, xy) |
|
240
|
|
|
scores.resize(len(mixture)) |
|
241
|
|
|
mixture.score_value(model, xy, scores) |
|
242
|
|
|
groupid = sample_discrete_log(scores) |
|
243
|
|
|
mixture.add_value(model, groupid, xy) |
|
244
|
|
|
assignments[i] = mixture.id_tracker.packed_to_global(groupid) |
|
245
|
|
|
|
|
246
|
|
|
print 'prior+gibbs found {} components'.format(len(mixture)) |
|
247
|
|
|
image = synthesize_image(model, mixture) |
|
248
|
|
|
scipy.misc.imsave(os.path.join(RESULTS, 'prior_gibbs.png'), image) |
|
249
|
|
|
|
|
250
|
|
|
|
|
251
|
|
View Code Duplication |
@parsable.command |
|
|
|
|
|
|
252
|
|
|
def compress_seq_gibbs(passes=PASSES): |
|
253
|
|
|
''' |
|
254
|
|
|
Compress image via sequentiall-initialized gibbs sampling. |
|
255
|
|
|
''' |
|
256
|
|
|
assert passes >= 1 |
|
257
|
|
|
assert os.path.exists(SAMPLES), 'first create dataset' |
|
258
|
|
|
print 'seq+gibbs start {} passes'.format(passes) |
|
259
|
|
|
model = ImageModel() |
|
260
|
|
|
mixture = ImageModel.Mixture() |
|
261
|
|
|
mixture.init(model) |
|
262
|
|
|
scores = numpy.zeros(1, dtype=numpy.float32) |
|
263
|
|
|
assignments = {} |
|
264
|
|
|
|
|
265
|
|
|
for i, xy in enumerate(json_stream_load(SAMPLES)): |
|
266
|
|
|
scores.resize(len(mixture)) |
|
267
|
|
|
mixture.score_value(model, xy, scores) |
|
268
|
|
|
groupid = sample_discrete_log(scores) |
|
269
|
|
|
mixture.add_value(model, groupid, xy) |
|
270
|
|
|
assignments[i] = mixture.id_tracker.packed_to_global(groupid) |
|
271
|
|
|
|
|
272
|
|
|
print 'seq+gibbs init with {} components'.format(len(mixture)) |
|
273
|
|
|
|
|
274
|
|
|
for _ in xrange(passes - 1): |
|
275
|
|
|
for i, xy in enumerate(json_stream_load(SAMPLES)): |
|
276
|
|
|
groupid = mixture.id_tracker.global_to_packed(assignments[i]) |
|
277
|
|
|
mixture.remove_value(model, groupid, xy) |
|
278
|
|
|
scores.resize(len(mixture)) |
|
279
|
|
|
mixture.score_value(model, xy, scores) |
|
280
|
|
|
groupid = sample_discrete_log(scores) |
|
281
|
|
|
mixture.add_value(model, groupid, xy) |
|
282
|
|
|
assignments[i] = mixture.id_tracker.packed_to_global(groupid) |
|
283
|
|
|
|
|
284
|
|
|
print 'seq+gibbs found {} components'.format(len(mixture)) |
|
285
|
|
|
image = synthesize_image(model, mixture) |
|
286
|
|
|
scipy.misc.imsave(os.path.join(RESULTS, 'seq_gibbs.png'), image) |
|
287
|
|
|
|
|
288
|
|
|
|
|
289
|
|
|
def json_loop_load(filename):
    """Endlessly cycle through (index, item) pairs from a json stream file,
    restarting the stream (and the indices) at each end of file."""
    while True:
        for pair in enumerate(json_stream_load(filename)):
            yield pair
|
293
|
|
|
|
|
294
|
|
|
|
|
295
|
|
|
def annealing_schedule(passes):
    """Yield an endless True/False stream interleaving adds and removes.

    True means "add a sample", False means "remove a sample".  Adds
    outnumber removes in the ratio passes : (passes - 1), so the window of
    assigned samples keeps growing while each sample is revisited roughly
    `passes` times.  With passes == 1 the stream is all True.
    """
    passes = float(passes)
    assert passes >= 1
    add_rate, remove_rate = passes, passes - 1
    balance = add_rate
    while True:
        add_next = balance >= 0
        balance += -remove_rate if add_next else add_rate
        yield add_next
|
308
|
|
|
|
|
309
|
|
|
|
|
310
|
|
|
@parsable.command |
|
311
|
|
|
def compress_annealing(passes=PASSES): |
|
312
|
|
|
''' |
|
313
|
|
|
Compress image via subsample annealing. |
|
314
|
|
|
''' |
|
315
|
|
|
assert passes >= 1 |
|
316
|
|
|
assert os.path.exists(SAMPLES), 'first create dataset' |
|
317
|
|
|
print 'annealing start {} passes'.format(passes) |
|
318
|
|
|
model = ImageModel() |
|
319
|
|
|
mixture = ImageModel.Mixture() |
|
320
|
|
|
mixture.init(model) |
|
321
|
|
|
scores = numpy.zeros(1, dtype=numpy.float32) |
|
322
|
|
|
assignments = {} |
|
323
|
|
|
|
|
324
|
|
|
to_add = json_loop_load(SAMPLES) |
|
325
|
|
|
to_remove = json_loop_load(SAMPLES) |
|
326
|
|
|
|
|
327
|
|
|
for next_action_is_add in annealing_schedule(passes): |
|
328
|
|
|
if next_action_is_add: |
|
329
|
|
|
i, xy = to_add.next() |
|
330
|
|
|
if i in assignments: |
|
331
|
|
|
break |
|
332
|
|
|
scores.resize(len(mixture)) |
|
333
|
|
|
mixture.score_value(model, xy, scores) |
|
334
|
|
|
groupid = sample_discrete_log(scores) |
|
335
|
|
|
mixture.add_value(model, groupid, xy) |
|
336
|
|
|
assignments[i] = mixture.id_tracker.packed_to_global(groupid) |
|
337
|
|
|
else: |
|
338
|
|
|
i, xy = to_remove.next() |
|
339
|
|
|
groupid = mixture.id_tracker.global_to_packed(assignments.pop(i)) |
|
340
|
|
|
mixture.remove_value(model, groupid, xy) |
|
341
|
|
|
|
|
342
|
|
|
print 'annealing found {} components'.format(len(mixture)) |
|
343
|
|
|
image = synthesize_image(model, mixture) |
|
344
|
|
|
scipy.misc.imsave(os.path.join(RESULTS, 'annealing.png'), image) |
|
345
|
|
|
|
|
346
|
|
|
|
|
347
|
|
|
@parsable.command
def clean():
    '''
    Clean out dataset and results.
    '''
    # BUG FIX: the guard was inverted (`if not os.path.exists(dirname)`),
    # so existing directories were never removed, and a missing directory
    # would have made shutil.rmtree raise OSError.
    for dirname in [DATA, RESULTS]:
        if os.path.exists(dirname):
            shutil.rmtree(dirname)
|
355
|
|
|
|
|
356
|
|
|
|
|
357
|
|
|
@parsable.command
def run(sample_count=SAMPLE_COUNT, passes=PASSES):
    '''
    Generate all datasets and run all algorithms.
    See index.html for results.
    '''
    create_dataset(sample_count)

    # Run the four compression strategies in parallel, one process each.
    workers = [
        Process(target=compress_sequential),
        Process(target=compress_gibbs, args=(passes,)),
        Process(target=compress_annealing, args=(passes,)),
        Process(target=compress_seq_gibbs, args=(passes,)),
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
|
375
|
|
|
|
|
376
|
|
|
if __name__ == '__main__':
    # Expose every @parsable.command function above as a CLI subcommand.
    parsable.dispatch()
|
378
|
|
|
|