@@ 251-286 (lines=36) @@ | ||
248 | scipy.misc.imsave(os.path.join(RESULTS, 'prior_gibbs.png'), image) |
|
249 | ||
250 | ||
@parsable.command
def compress_seq_gibbs(passes=PASSES):
    '''
    Compress image via sequentially-initialized gibbs sampling.

    The initial sequential sweep counts as the first of `passes`: each
    sample is scored against the mixture built so far and added to a
    component sampled from those scores.  The remaining `passes - 1` sweeps
    are gibbs sweeps that remove each sample from its component and
    re-sample where it goes.  The synthesized image is saved as
    RESULTS/seq_gibbs.png.
    '''
    assert passes >= 1
    assert os.path.exists(SAMPLES), 'first create dataset'
    print('seq+gibbs start {} passes'.format(passes))

    model = ImageModel()
    mixture = ImageModel.Mixture()
    mixture.init(model)
    # Single score buffer, resized in place to the current component count
    # before every scoring call.
    scores = numpy.zeros(1, dtype=numpy.float32)
    # sample index -> global group id.  Ids are stored in global form and
    # converted back to packed form on lookup, presumably because packed ids
    # can shift as groups are created/removed -- confirm in id_tracker docs.
    assignments = dict()

    # Sequential initialization (pass 1 of `passes`).
    for index, xy in enumerate(json_stream_load(SAMPLES)):
        scores.resize(len(mixture))
        mixture.score_value(model, xy, scores)
        packed = sample_discrete_log(scores)
        mixture.add_value(model, packed, xy)
        assignments[index] = mixture.id_tracker.packed_to_global(packed)

    print('seq+gibbs init with {} components'.format(len(mixture)))

    # Remaining gibbs sweeps.
    for _sweep in range(passes - 1):
        for index, xy in enumerate(json_stream_load(SAMPLES)):
            # Pull the sample out of the component it currently sits in...
            current = mixture.id_tracker.global_to_packed(assignments[index])
            mixture.remove_value(model, current, xy)
            # ...then draw a new component from the updated mixture.
            scores.resize(len(mixture))
            mixture.score_value(model, xy, scores)
            chosen = sample_discrete_log(scores)
            mixture.add_value(model, chosen, xy)
            assignments[index] = mixture.id_tracker.packed_to_global(chosen)

    print('seq+gibbs found {} components'.format(len(mixture)))
    image = synthesize_image(model, mixture)
    # NOTE(review): scipy.misc.imsave is deprecated/removed in newer SciPy
    # releases -- confirm the pinned SciPy version still provides it.
    scipy.misc.imsave(os.path.join(RESULTS, 'seq_gibbs.png'), image)
|
287 | ||
288 | ||
289 | def json_loop_load(filename): |
|
@@ 213-248 (lines=36) @@ | ||
210 | scipy.misc.imsave(os.path.join(RESULTS, 'sequential.png'), image) |
|
211 | ||
212 | ||
@parsable.command
def compress_gibbs(passes=PASSES):
    '''
    Compress image via gibbs sampling.

    Every sample is first placed in a component sampled from the
    clustering score alone (`mixture.clustering.score_value`), without
    consulting the sample's value.  Then `passes` gibbs sweeps remove each
    sample from its component and re-sample its placement with the full
    mixture score.  passes=0 is allowed and keeps only the initialization.
    The synthesized image is saved as RESULTS/prior_gibbs.png.
    '''
    assert passes >= 0
    assert os.path.exists(SAMPLES), 'first create dataset'
    print('prior+gibbs start {} passes'.format(passes))

    model = ImageModel()
    mixture = ImageModel.Mixture()
    mixture.init(model)
    # Single score buffer, resized in place to the current component count
    # before every scoring call.
    scores = numpy.zeros(1, dtype=numpy.float32)
    # sample index -> global group id.  Ids are stored in global form and
    # converted back to packed form on lookup, presumably because packed ids
    # can shift as groups are created/removed -- confirm in id_tracker docs.
    assignments = dict()

    # Initialization: placement depends only on the clustering score.
    for index, xy in enumerate(json_stream_load(SAMPLES)):
        scores.resize(len(mixture))
        mixture.clustering.score_value(model.clustering, scores)
        packed = sample_discrete_log(scores)
        mixture.add_value(model, packed, xy)
        assignments[index] = mixture.id_tracker.packed_to_global(packed)

    print('prior+gibbs init with {} components'.format(len(mixture)))

    # Gibbs sweeps using the full (data-aware) mixture score.
    for _sweep in range(passes):
        for index, xy in enumerate(json_stream_load(SAMPLES)):
            # Pull the sample out of the component it currently sits in...
            current = mixture.id_tracker.global_to_packed(assignments[index])
            mixture.remove_value(model, current, xy)
            # ...then draw a new component from the updated mixture.
            scores.resize(len(mixture))
            mixture.score_value(model, xy, scores)
            chosen = sample_discrete_log(scores)
            mixture.add_value(model, chosen, xy)
            assignments[index] = mixture.id_tracker.packed_to_global(chosen)

    print('prior+gibbs found {} components'.format(len(mixture)))
    image = synthesize_image(model, mixture)
    # NOTE(review): scipy.misc.imsave is deprecated/removed in newer SciPy
    # releases -- confirm the pinned SciPy version still provides it.
    scipy.misc.imsave(os.path.join(RESULTS, 'prior_gibbs.png'), image)
|
249 | ||
250 | ||
251 | @parsable.command |