# Copyright (c) 2014, Salesforce.com, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# - Redistributions of source code must retain the above copyright
#   notice, this list of conditions and the following disclaimer.
# - Redistributions in binary form must reproduce the above copyright
#   notice, this list of conditions and the following disclaimer in the
#   documentation and/or other materials provided with the distribution.
# - Neither the name of Salesforce.com nor the names of its contributors
#   may be used to endorse or promote products derived from this
#   software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
# OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
import shutil
import numpy
import scipy
import scipy.misc
import scipy.ndimage
from distributions.dbg.random import sample_discrete, sample_discrete_log
from distributions.lp.models import nich
from distributions.lp.clustering import PitmanYor
from distributions.lp.mixture import MixtureIdTracker
from distributions.io.stream import json_stream_load, json_stream_dump
from multiprocessing import Process
import parsable
parsable = parsable.Parsable()


ROOT = os.path.dirname(os.path.abspath(__file__))
DATA = os.path.join(ROOT, 'data')
RESULTS = os.path.join(ROOT, 'results')
SAMPLES = os.path.join(DATA, 'samples.json.gz')
IMAGE = scipy.misc.imread(os.path.join(ROOT, 'fox.png'))
SAMPLE_COUNT = 10000  # number of (x, y) points sampled from the image
PASSES = 10  # passes over the dataset for the Gibbs and annealing runs
EMPTY_GROUP_COUNT = 10  # empty groups kept on hand to propose new components


for dirname in [DATA, RESULTS]:
    if not os.path.exists(dirname):
        os.makedirs(dirname)


class ImageModel(object):
    def __init__(self):
        # Pitman-Yor clustering prior: concentration alpha and discount d
        # control how readily new mixture components are created.
        self.clustering = PitmanYor.from_dict({
            'alpha': 100.0,
            'd': 0.1,
        })
        # Normal-inverse-chi-squared (NICH) prior, shared by the x and y
        # features of each component.
        self.feature = nich.Shared.from_dict({
            'mu': 0.0,
            'kappa': 0.1,
            'sigmasq': 0.01,
            'nu': 1.0,
        })

    class Mixture(object):
        def __init__(self):
            self.clustering = PitmanYor.Mixture()
            self.feature_x = nich.Mixture()
            self.feature_y = nich.Mixture()
            self.id_tracker = MixtureIdTracker()

        def __len__(self):
            return len(self.clustering)

        def init(self, model, empty_group_count=EMPTY_GROUP_COUNT):
            assert empty_group_count >= 1
            # Seed the mixture with empty groups so that a new component
            # can always be proposed when scoring a value.
            counts = [0] * empty_group_count
            self.clustering.init(model.clustering, counts)
            assert len(self.clustering) == len(counts)
            self.id_tracker.init(len(counts))

            self.feature_x.clear()
            self.feature_y.clear()
            for _ in xrange(empty_group_count):
                self.feature_x.add_group(model.feature)
                self.feature_y.add_group(model.feature)
            self.feature_x.init(model.feature)
            self.feature_y.init(model.feature)

        def score_value(self, model, xy, scores):
            # Accumulate log predictive scores for each group into scores.
            x, y = xy
            self.clustering.score_value(model.clustering, scores)
            self.feature_x.score_value(model.feature, x, scores)
            self.feature_y.score_value(model.feature, y, scores)

        def add_value(self, model, groupid, xy):
            x, y = xy
            group_added = self.clustering.add_value(model.clustering, groupid)
            self.feature_x.add_value(model.feature, groupid, x)
            self.feature_y.add_value(model.feature, groupid, y)
            if group_added:
                # An empty group was just filled; create a fresh empty
                # group so one is always available.
                self.feature_x.add_group(model.feature)
                self.feature_y.add_group(model.feature)
                self.id_tracker.add_group()

        def remove_value(self, model, groupid, xy):
            x, y = xy
            group_removed = self.clustering.remove_value(
                model.clustering,
                groupid)
            self.feature_x.remove_value(model.feature, groupid, x)
            self.feature_y.remove_value(model.feature, groupid, y)
            if group_removed:
                self.feature_x.remove_group(model.feature, groupid)
                self.feature_y.remove_group(model.feature, groupid)
                self.id_tracker.remove_group(groupid)


def sample_from_image(image, sample_count):
    # Invert intensities so that dark pixels carry high probability mass.
    image = -1.0 * image
    image -= image.min()
    x_pmf = image.sum(axis=1)
    y_pmfs = image.copy()
    for y_pmf in y_pmfs:
        y_pmf /= (y_pmf.sum() + 1e-8)

    x_scale = 2.0 / (image.shape[0] - 1)
    y_scale = 2.0 / (image.shape[1] - 1)

    # Sample a row from the marginal, then a column from that row's
    # conditional, and map pixel indices into [-1, 1] coordinates.
    for _ in xrange(sample_count):
        x = sample_discrete(x_pmf)
        y = sample_discrete(y_pmfs[x])
        yield (x * x_scale - 1.0, y * y_scale - 1.0)


def synthesize_image(model, mixture):
    # Render the (unnormalized) posterior predictive density on the pixel
    # grid, then rescale so dense regions appear dark.
    width, height = IMAGE.shape
    image = numpy.zeros((width, height))
    scores = numpy.zeros(len(mixture), dtype=numpy.float32)
    x_scale = 2.0 / (width - 1)
    y_scale = 2.0 / (height - 1)
    for x in xrange(width):
        for y in xrange(height):
            xy = (x * x_scale - 1.0, y * y_scale - 1.0)
            mixture.score_value(model, xy, scores)
            prob = numpy.exp(scores, out=scores).sum()
            image[x, y] = prob

    image /= image.max()
    image -= 1.0
    image *= -255
    return image.astype(numpy.uint8)


def visualize_dataset(samples):
    width, height = IMAGE.shape
    x_scale = 2.0 / (width - 1)
    y_scale = 2.0 / (height - 1)
    image = numpy.zeros((width, height))
    for x, y in samples:
        x = int(round((x + 1.0) / x_scale))
        y = int(round((y + 1.0) / y_scale))
        image[x, y] += 1
    # Smooth the point counts and invert so sampled regions appear dark.
    image = scipy.ndimage.gaussian_filter(image, sigma=1)
    image *= -255.0 / image.max()
    image -= image.min()
    return image.astype(numpy.uint8)


@parsable.command
def create_dataset(sample_count=SAMPLE_COUNT):
    '''
    Extract dataset from image.
    '''
    scipy.misc.imsave(os.path.join(RESULTS, 'original.png'), IMAGE)
    print 'sampling {} points from image'.format(sample_count)
    samples = sample_from_image(IMAGE, sample_count)
    json_stream_dump(samples, SAMPLES)
    image = visualize_dataset(json_stream_load(SAMPLES))
    scipy.misc.imsave(os.path.join(RESULTS, 'samples.png'), image)


@parsable.command
def compress_sequential():
    '''
    Compress image via sequential initialization.
    '''
    assert os.path.exists(SAMPLES), 'first create dataset'
    print 'sequential start'
    model = ImageModel()
    mixture = ImageModel.Mixture()
    mixture.init(model)
    scores = numpy.zeros(1, dtype=numpy.float32)

    # Each point is assigned exactly once, sampled from its predictive
    # posterior given all previously assigned points.
    for xy in json_stream_load(SAMPLES):
        scores.resize(len(mixture))
        mixture.score_value(model, xy, scores)
        groupid = sample_discrete_log(scores)
        mixture.add_value(model, groupid, xy)

    print 'sequential found {} components'.format(len(mixture))
    image = synthesize_image(model, mixture)
    scipy.misc.imsave(os.path.join(RESULTS, 'sequential.png'), image)


@parsable.command
def compress_gibbs(passes=PASSES):
    '''
    Compress image via Gibbs sampling.
    '''
    assert passes >= 0
    assert os.path.exists(SAMPLES), 'first create dataset'
    print 'prior+gibbs start {} passes'.format(passes)
    model = ImageModel()
    mixture = ImageModel.Mixture()
    mixture.init(model)
    scores = numpy.zeros(1, dtype=numpy.float32)
    assignments = {}

    # Initialize by sampling assignments from the clustering prior alone,
    # ignoring the likelihood of each point.
    for i, xy in enumerate(json_stream_load(SAMPLES)):
        scores.resize(len(mixture))
        mixture.clustering.score_value(model.clustering, scores)
        groupid = sample_discrete_log(scores)
        mixture.add_value(model, groupid, xy)
        assignments[i] = mixture.id_tracker.packed_to_global(groupid)

    print 'prior+gibbs init with {} components'.format(len(mixture))

    # Full Gibbs sweeps: remove each point, rescore it against the current
    # mixture, and resample its assignment.
    for _ in xrange(passes):
        for i, xy in enumerate(json_stream_load(SAMPLES)):
            groupid = mixture.id_tracker.global_to_packed(assignments[i])
            mixture.remove_value(model, groupid, xy)
            scores.resize(len(mixture))
            mixture.score_value(model, xy, scores)
            groupid = sample_discrete_log(scores)
            mixture.add_value(model, groupid, xy)
            assignments[i] = mixture.id_tracker.packed_to_global(groupid)

    print 'prior+gibbs found {} components'.format(len(mixture))
    image = synthesize_image(model, mixture)
    scipy.misc.imsave(os.path.join(RESULTS, 'prior_gibbs.png'), image)


@parsable.command
def compress_seq_gibbs(passes=PASSES):
    '''
    Compress image via sequentially-initialized Gibbs sampling.
    '''
    assert passes >= 1
    assert os.path.exists(SAMPLES), 'first create dataset'
    print 'seq+gibbs start {} passes'.format(passes)
    model = ImageModel()
    mixture = ImageModel.Mixture()
    mixture.init(model)
    scores = numpy.zeros(1, dtype=numpy.float32)
    assignments = {}

    # Initialize sequentially, scoring each point against the mixture fit
    # to all previously assigned points; this counts as one pass.
    for i, xy in enumerate(json_stream_load(SAMPLES)):
        scores.resize(len(mixture))
        mixture.score_value(model, xy, scores)
        groupid = sample_discrete_log(scores)
        mixture.add_value(model, groupid, xy)
        assignments[i] = mixture.id_tracker.packed_to_global(groupid)

    print 'seq+gibbs init with {} components'.format(len(mixture))

    for _ in xrange(passes - 1):
        for i, xy in enumerate(json_stream_load(SAMPLES)):
            groupid = mixture.id_tracker.global_to_packed(assignments[i])
            mixture.remove_value(model, groupid, xy)
            scores.resize(len(mixture))
            mixture.score_value(model, xy, scores)
            groupid = sample_discrete_log(scores)
            mixture.add_value(model, groupid, xy)
            assignments[i] = mixture.id_tracker.packed_to_global(groupid)

    print 'seq+gibbs found {} components'.format(len(mixture))
    image = synthesize_image(model, mixture)
    scipy.misc.imsave(os.path.join(RESULTS, 'seq_gibbs.png'), image)


def json_loop_load(filename):
    while True:
        for i, item in enumerate(json_stream_load(filename)):
            yield i, item


def annealing_schedule(passes):
    passes = float(passes)
    assert passes >= 1
    add_rate = passes
    remove_rate = passes - 1
    state = add_rate
    while True:
        if state >= 0:
            state -= remove_rate
            yield True
        else:
            state += add_rate
            yield False


@parsable.command
def compress_annealing(passes=PASSES):
    '''
    Compress image via subsample annealing.
    '''
    assert passes >= 1
    assert os.path.exists(SAMPLES), 'first create dataset'
    print 'annealing start {} passes'.format(passes)
    model = ImageModel()
    mixture = ImageModel.Mixture()
    mixture.init(model)
    scores = numpy.zeros(1, dtype=numpy.float32)
    assignments = {}

    to_add = json_loop_load(SAMPLES)
    to_remove = json_loop_load(SAMPLES)

    for next_action_is_add in annealing_schedule(passes):
        if next_action_is_add:
            i, xy = to_add.next()
            if i in assignments:
                # The add cursor has wrapped around to a point that is
                # still assigned, so the whole dataset is in the model.
                break
            scores.resize(len(mixture))
            mixture.score_value(model, xy, scores)
            groupid = sample_discrete_log(scores)
            mixture.add_value(model, groupid, xy)
            assignments[i] = mixture.id_tracker.packed_to_global(groupid)
        else:
            i, xy = to_remove.next()
            groupid = mixture.id_tracker.global_to_packed(assignments.pop(i))
            mixture.remove_value(model, groupid, xy)

    print 'annealing found {} components'.format(len(mixture))
    image = synthesize_image(model, mixture)
    scipy.misc.imsave(os.path.join(RESULTS, 'annealing.png'), image)


@parsable.command
def clean():
    '''
    Clean out dataset and results.
    '''
    for dirname in [DATA, RESULTS]:
        if os.path.exists(dirname):
            shutil.rmtree(dirname)


@parsable.command
def run(sample_count=SAMPLE_COUNT, passes=PASSES):
    '''
    Generate all datasets and run all algorithms.
    See index.html for results.
    '''
    create_dataset(sample_count)

    # Run the four compression strategies in parallel processes.
    procs = [
        Process(target=compress_sequential),
        Process(target=compress_gibbs, args=(passes,)),
        Process(target=compress_annealing, args=(passes,)),
        Process(target=compress_seq_gibbs, args=(passes,)),
    ]
    for proc in procs:
        proc.start()
    for proc in procs:
        proc.join()


if __name__ == '__main__':
    parsable.dispatch()
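
# Example usage (assuming this script is saved as main.py next to fox.png):
#
#   python main.py create_dataset
#   python main.py compress_annealing
#   python main.py run
#
# parsable exposes each @parsable.command function as a CLI subcommand.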