# Copyright (c) 2014, Salesforce.com, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# - Redistributions of source code must retain the above copyright
#   notice, this list of conditions and the following disclaimer.
# - Redistributions in binary form must reproduce the above copyright
#   notice, this list of conditions and the following disclaimer in the
#   documentation and/or other materials provided with the distribution.
# - Neither the name of Salesforce.com nor the names of its contributors
#   may be used to endorse or promote products derived from this
#   software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
# OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
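
# Overview (summary comment, inferred from the code below): this demo fits a
# two-dimensional Pitman-Yor mixture to points sampled from the dark pixels
# of fox.png, then renders the learned density back out as an image. It
# compares three inference strategies -- sequential initialization, Gibbs
# sampling, and subsample annealing -- writing results to the results/ dir.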

import os
import shutil
import numpy
import scipy
import scipy.misc
import scipy.ndimage
from distributions.dbg.random import sample_discrete, sample_discrete_log
from distributions.lp.models import nich
from distributions.lp.clustering import PitmanYor
from distributions.lp.mixture import MixtureIdTracker
from distributions.io.stream import json_stream_load, json_stream_dump
from multiprocessing import Process
import parsable
parsable = parsable.Parsable()


ROOT = os.path.dirname(os.path.abspath(__file__))
DATA = os.path.join(ROOT, 'data')
RESULTS = os.path.join(ROOT, 'results')
SAMPLES = os.path.join(DATA, 'samples.json.gz')
IMAGE = scipy.misc.imread(os.path.join(ROOT, 'fox.png'))
SAMPLE_COUNT = 10000
PASSES = 10
EMPTY_GROUP_COUNT = 10


for dirname in [DATA, RESULTS]:
    if not os.path.exists(dirname):
        os.makedirs(dirname)


class ImageModel(object):
    def __init__(self):
        # Pitman-Yor clustering prior.
        self.clustering = PitmanYor.from_dict({
            'alpha': 100.0,
            'd': 0.1,
        })
        # Normal-inverse-chi-squared (nich) likelihood, shared by both
        # coordinates.
        self.feature = nich.Shared.from_dict({
            'mu': 0.0,
            'kappa': 0.1,
            'sigmasq': 0.01,
            'nu': 1.0,
        })

    class Mixture(object):
        def __init__(self):
            self.clustering = PitmanYor.Mixture()
            self.feature_x = nich.Mixture()
            self.feature_y = nich.Mixture()
            self.id_tracker = MixtureIdTracker()

        def __len__(self):
            return len(self.clustering)

        def init(self, model, empty_group_count=EMPTY_GROUP_COUNT):
            assert empty_group_count >= 1
            counts = [0] * empty_group_count
            self.clustering.init(model.clustering, counts)
            assert len(self.clustering) == len(counts)
            self.id_tracker.init(len(counts))

            self.feature_x.clear()
            self.feature_y.clear()
            for _ in xrange(empty_group_count):
                self.feature_x.add_group(model.feature)
                self.feature_y.add_group(model.feature)
            self.feature_x.init(model.feature)
            self.feature_y.init(model.feature)

        def score_value(self, model, xy, scores):
            x, y = xy
            self.clustering.score_value(model.clustering, scores)
            self.feature_x.score_value(model.feature, x, scores)
            self.feature_y.score_value(model.feature, y, scores)

        def add_value(self, model, groupid, xy):
            x, y = xy
            group_added = self.clustering.add_value(model.clustering, groupid)
            self.feature_x.add_value(model.feature, groupid, x)
            self.feature_y.add_value(model.feature, groupid, y)
            if group_added:
                # The clustering grew a new group; mirror it in both
                # features and in the id tracker.
                self.feature_x.add_group(model.feature)
                self.feature_y.add_group(model.feature)
                self.id_tracker.add_group()

        def remove_value(self, model, groupid, xy):
            x, y = xy
            group_removed = self.clustering.remove_value(
                model.clustering,
                groupid)
            self.feature_x.remove_value(model.feature, groupid, x)
            self.feature_y.remove_value(model.feature, groupid, y)
            if group_removed:
                # The clustering dropped a group; mirror the removal.
                self.feature_x.remove_group(model.feature, groupid)
                self.feature_y.remove_group(model.feature, groupid)
                self.id_tracker.remove_group(groupid)
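
# The inference routines below all share one primitive, sketched here for
# orientation (names as defined above; scores is a float32 array sized to
# the current group count):
#
#   scores.resize(len(mixture))               # one slot per current group
#   mixture.score_value(model, xy, scores)    # unnormalized log posteriors
#   groupid = sample_discrete_log(scores)     # sample a group in log space
#   mixture.add_value(model, groupid, xy)     # update sufficient statistics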


def sample_from_image(image, sample_count):
    # Invert intensities so that dark (ink) pixels carry high probability
    # mass, then sample pixel positions and rescale to [-1, 1] x [-1, 1].
    image = -1.0 * image
    image -= image.min()
    x_pmf = image.sum(axis=1)
    y_pmfs = image.copy()
    for y_pmf in y_pmfs:
        y_pmf /= (y_pmf.sum() + 1e-8)

    x_scale = 2.0 / (image.shape[0] - 1)
    y_scale = 2.0 / (image.shape[1] - 1)

    for _ in xrange(sample_count):
        x = sample_discrete(x_pmf)
        y = sample_discrete(y_pmfs[x])
        yield (x * x_scale - 1.0, y * y_scale - 1.0)


def synthesize_image(model, mixture):
    width, height = IMAGE.shape
    image = numpy.zeros((width, height))
    scores = numpy.zeros(len(mixture), dtype=numpy.float32)
    x_scale = 2.0 / (width - 1)
    y_scale = 2.0 / (height - 1)
    for x in xrange(width):
        for y in xrange(height):
            xy = (x * x_scale - 1.0, y * y_scale - 1.0)
            mixture.score_value(model, xy, scores)
            prob = numpy.exp(scores, out=scores).sum()
            image[x, y] = prob

    # Rescale to inverted 8-bit grayscale: high density -> dark pixel.
    image /= image.max()
    image -= 1.0
    image *= -255
    return image.astype(numpy.uint8)


def visualize_dataset(samples):
    width, height = IMAGE.shape
    x_scale = 2.0 / (width - 1)
    y_scale = 2.0 / (height - 1)
    image = numpy.zeros((width, height))
    for x, y in samples:
        x = int(round((x + 1.0) / x_scale))
        y = int(round((y + 1.0) / y_scale))
        image[x, y] += 1
    image = scipy.ndimage.gaussian_filter(image, sigma=1)
    # Invert for display: densely sampled regions render dark, like the
    # original image.
    image *= -255.0 / image.max()
    image -= image.min()
    return image.astype(numpy.uint8)


@parsable.command
def create_dataset(sample_count=SAMPLE_COUNT):
    '''
    Extract dataset from image.
    '''
    scipy.misc.imsave(os.path.join(RESULTS, 'original.png'), IMAGE)
    print 'sampling {} points from image'.format(sample_count)
    samples = sample_from_image(IMAGE, sample_count)
    json_stream_dump(samples, SAMPLES)
    image = visualize_dataset(json_stream_load(SAMPLES))
    scipy.misc.imsave(os.path.join(RESULTS, 'samples.png'), image)


@parsable.command
def compress_sequential():
    '''
    Compress image via sequential initialization.
    '''
    assert os.path.exists(SAMPLES), 'first create dataset'
    print 'sequential start'
    model = ImageModel()
    mixture = ImageModel.Mixture()
    mixture.init(model)
    scores = numpy.zeros(1, dtype=numpy.float32)

    for xy in json_stream_load(SAMPLES):
        scores.resize(len(mixture))
        mixture.score_value(model, xy, scores)
        groupid = sample_discrete_log(scores)
        mixture.add_value(model, groupid, xy)

    print 'sequential found {} components'.format(len(mixture))
    image = synthesize_image(model, mixture)
    scipy.misc.imsave(os.path.join(RESULTS, 'sequential.png'), image)


@parsable.command
def compress_gibbs(passes=PASSES):
    '''
    Compress image via gibbs sampling.
    '''
    assert passes >= 0
    assert os.path.exists(SAMPLES), 'first create dataset'
    print 'prior+gibbs start {} passes'.format(passes)
    model = ImageModel()
    mixture = ImageModel.Mixture()
    mixture.init(model)
    scores = numpy.zeros(1, dtype=numpy.float32)
    assignments = {}

    # Initialize assignments from the clustering prior alone, ignoring
    # the likelihood terms.
    for i, xy in enumerate(json_stream_load(SAMPLES)):
        scores.resize(len(mixture))
        mixture.clustering.score_value(model.clustering, scores)
        groupid = sample_discrete_log(scores)
        mixture.add_value(model, groupid, xy)
        assignments[i] = mixture.id_tracker.packed_to_global(groupid)

    print 'prior+gibbs init with {} components'.format(len(mixture))

    # Gibbs sweeps: remove each value, rescore it against the current
    # mixture, and resample its assignment.
    for _ in xrange(passes):
        for i, xy in enumerate(json_stream_load(SAMPLES)):
            groupid = mixture.id_tracker.global_to_packed(assignments[i])
            mixture.remove_value(model, groupid, xy)
            scores.resize(len(mixture))
            mixture.score_value(model, xy, scores)
            groupid = sample_discrete_log(scores)
            mixture.add_value(model, groupid, xy)
            assignments[i] = mixture.id_tracker.packed_to_global(groupid)

    print 'prior+gibbs found {} components'.format(len(mixture))
    image = synthesize_image(model, mixture)
    scipy.misc.imsave(os.path.join(RESULTS, 'prior_gibbs.png'), image)


@parsable.command
def compress_seq_gibbs(passes=PASSES):
    '''
    Compress image via sequentially-initialized gibbs sampling.
    '''
    assert passes >= 1
    assert os.path.exists(SAMPLES), 'first create dataset'
    print 'seq+gibbs start {} passes'.format(passes)
    model = ImageModel()
    mixture = ImageModel.Mixture()
    mixture.init(model)
    scores = numpy.zeros(1, dtype=numpy.float32)
    assignments = {}

    # Initialize sequentially, scoring each value against the mixture
    # learned from the values seen so far (this counts as the first pass).
    for i, xy in enumerate(json_stream_load(SAMPLES)):
        scores.resize(len(mixture))
        mixture.score_value(model, xy, scores)
        groupid = sample_discrete_log(scores)
        mixture.add_value(model, groupid, xy)
        assignments[i] = mixture.id_tracker.packed_to_global(groupid)

    print 'seq+gibbs init with {} components'.format(len(mixture))

    for _ in xrange(passes - 1):
        for i, xy in enumerate(json_stream_load(SAMPLES)):
            groupid = mixture.id_tracker.global_to_packed(assignments[i])
            mixture.remove_value(model, groupid, xy)
            scores.resize(len(mixture))
            mixture.score_value(model, xy, scores)
            groupid = sample_discrete_log(scores)
            mixture.add_value(model, groupid, xy)
            assignments[i] = mixture.id_tracker.packed_to_global(groupid)

    print 'seq+gibbs found {} components'.format(len(mixture))
    image = synthesize_image(model, mixture)
    scipy.misc.imsave(os.path.join(RESULTS, 'seq_gibbs.png'), image)


def json_loop_load(filename):
    # Cycle through the dataset endlessly, yielding (index, item) pairs.
    while True:
        for i, item in enumerate(json_stream_load(filename)):
            yield i, item
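
# Note on the annealing code below: to_add and to_remove are two independent
# cursors into the same endlessly-looped stream, and the schedule drives the
# adding cursor faster than the removing one (at a rate ratio of
# passes : passes - 1), so the tracked subsample grows from empty to the
# full dataset over the course of the run.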


def annealing_schedule(passes):
    passes = float(passes)
    assert passes >= 1
    add_rate = passes
    remove_rate = passes - 1
    state = add_rate
    while True:
        if state >= 0:
            state -= remove_rate
            yield True
        else:
            state += add_rate
            yield False
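
# For example, with passes=2: add_rate=2, remove_rate=1, and the schedule
# yields True, True, True, False, then repeats True, True, False -- about
# two adds for every remove, so each datapoint is added twice and removed
# once on average.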


@parsable.command
def compress_annealing(passes=PASSES):
    '''
    Compress image via subsample annealing.
    '''
    assert passes >= 1
    assert os.path.exists(SAMPLES), 'first create dataset'
    print 'annealing start {} passes'.format(passes)
    model = ImageModel()
    mixture = ImageModel.Mixture()
    mixture.init(model)
    scores = numpy.zeros(1, dtype=numpy.float32)
    assignments = {}

    to_add = json_loop_load(SAMPLES)
    to_remove = json_loop_load(SAMPLES)

    for next_action_is_add in annealing_schedule(passes):
        if next_action_is_add:
            i, xy = to_add.next()
            if i in assignments:
                # The adding cursor has wrapped around to a sample that is
                # already assigned, so the subsample now covers the full
                # dataset and annealing is done.
                break
            scores.resize(len(mixture))
            mixture.score_value(model, xy, scores)
            groupid = sample_discrete_log(scores)
            mixture.add_value(model, groupid, xy)
            assignments[i] = mixture.id_tracker.packed_to_global(groupid)
        else:
            i, xy = to_remove.next()
            groupid = mixture.id_tracker.global_to_packed(assignments.pop(i))
            mixture.remove_value(model, groupid, xy)

    print 'annealing found {} components'.format(len(mixture))
    image = synthesize_image(model, mixture)
    scipy.misc.imsave(os.path.join(RESULTS, 'annealing.png'), image)
345
|
|
|
|
346
|
|
|
|
347
|
|
|
@parsable.command |
348
|
|
|
def clean(): |
349
|
|
|
''' |
350
|
|
|
Clean out dataset and results. |
351
|
|
|
''' |
352
|
|
|
for dirname in [DATA, RESULTS]: |
353
|
|
|
if not os.path.exists(dirname): |
354
|
|
|
shutil.rmtree(dirname) |


@parsable.command
def run(sample_count=SAMPLE_COUNT, passes=PASSES):
    '''
    Generate all datasets and run all algorithms.
    See index.html for results.
    '''
    create_dataset(sample_count)

    procs = [
        Process(target=compress_sequential),
        Process(target=compress_gibbs, args=(passes,)),
        Process(target=compress_annealing, args=(passes,)),
        Process(target=compress_seq_gibbs, args=(passes,)),
    ]
    for proc in procs:
        proc.start()
    for proc in procs:
        proc.join()


if __name__ == '__main__':
    parsable.dispatch()
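
# Example command-line usage (assuming this script is saved as main.py;
# parsable dispatches on the function name):
#
#   python main.py create_dataset
#   python main.py compress_annealing
#   python main.py run    # create the dataset, then run all algorithms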