GitHub Access Token became invalid

It seems like the GitHub access token used for retrieving details about this repository from GitHub became invalid. This might prevent certain types of inspections from being run (in particular, everything related to pull requests).
Please ask an admin of your repository to renew the access token on this website.

Mixture.score_value()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
dl 0
loc 5
rs 9.4285
c 0
b 0
f 0
# Copyright (c) 2014, Salesforce.com, Inc.  All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# - Redistributions of source code must retain the above copyright
#   notice, this list of conditions and the following disclaimer.
# - Redistributions in binary form must reproduce the above copyright
#   notice, this list of conditions and the following disclaimer in the
#   documentation and/or other materials provided with the distribution.
# - Neither the name of Salesforce.com nor the names of its contributors
#   may be used to endorse or promote products derived from this
#   software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
# OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
import shutil
import numpy
import scipy
import scipy.misc
import scipy.ndimage
from distributions.dbg.random import sample_discrete, sample_discrete_log
from distributions.lp.models import nich
from distributions.lp.clustering import PitmanYor
from distributions.lp.mixture import MixtureIdTracker
from distributions.io.stream import json_stream_load, json_stream_dump
from multiprocessing import Process
import parsable
parsable = parsable.Parsable()


# Filesystem layout: data/ holds the sampled dataset, results/ the rendered
# images produced by each compression algorithm.
ROOT = os.path.dirname(os.path.abspath(__file__))
DATA = os.path.join(ROOT, 'data')
RESULTS = os.path.join(ROOT, 'results')
SAMPLES = os.path.join(DATA, 'samples.json.gz')
# Source image, loaded eagerly at import time.  Later code unpacks
# IMAGE.shape as (width, height), so a 2-d grayscale image is assumed --
# TODO confirm fox.png is grayscale.
IMAGE = scipy.misc.imread(os.path.join(ROOT, 'fox.png'))
SAMPLE_COUNT = 10000   # default number of points sampled from the image
PASSES = 10            # default number of inference passes over the data
EMPTY_GROUP_COUNT = 10 # number of empty groups kept available in a mixture


# Ensure the data and results directories exist before any command runs.
for dirname in [DATA, RESULTS]:
    if not os.path.exists(dirname):
        os.makedirs(dirname)
class ImageModel(object):
    """A nonparametric mixture model over 2-d points sampled from an image.

    Couples a Pitman-Yor clustering prior with independent
    normal-inverse-chi-squared (nich) likelihoods on the x and y
    coordinates of each datapoint.
    """

    def __init__(self):
        # Shared hyperparameters; one clustering prior, one coordinate model
        # reused for both the x and the y feature.
        self.clustering = PitmanYor.from_dict({
            'alpha': 100.0,
            'd': 0.1,
        })
        self.feature = nich.Shared.from_dict({
            'mu': 0.0,
            'kappa': 0.1,
            'sigmasq': 0.01,
            'nu': 1.0,
        })

    class Mixture(object):
        """Mutable inference state: group counts plus x/y sufficient stats."""

        def __init__(self):
            self.clustering = PitmanYor.Mixture()
            self.feature_x = nich.Mixture()
            self.feature_y = nich.Mixture()
            self.id_tracker = MixtureIdTracker()

        def __len__(self):
            # Group count as tracked by the clustering state
            # (includes empty groups).
            return len(self.clustering)

        def init(self, model, empty_group_count=EMPTY_GROUP_COUNT):
            """Reset this mixture to `empty_group_count` empty groups."""
            assert empty_group_count >= 1
            counts = [0] * empty_group_count
            self.clustering.init(model.clustering, counts)
            assert len(self.clustering) == len(counts)
            self.id_tracker.init(len(counts))

            self.feature_x.clear()
            self.feature_y.clear()
            for _ in xrange(empty_group_count):
                self.feature_x.add_group(model.feature)
                self.feature_y.add_group(model.feature)
            self.feature_x.init(model.feature)
            self.feature_y.init(model.feature)

        def score_value(self, model, xy, scores):
            """Accumulate each group's log score for the point `xy` into
            the preallocated `scores` array."""
            x, y = xy
            self.clustering.score_value(model.clustering, scores)
            self.feature_x.score_value(model.feature, x, scores)
            self.feature_y.score_value(model.feature, y, scores)

        def add_value(self, model, groupid, xy):
            """Assign the point `xy` to group `groupid`, growing the
            bookkeeping structures whenever a fresh group was created."""
            x, y = xy
            created = self.clustering.add_value(model.clustering, groupid)
            self.feature_x.add_value(model.feature, groupid, x)
            self.feature_y.add_value(model.feature, groupid, y)
            if created:
                self.feature_x.add_group(model.feature)
                self.feature_y.add_group(model.feature)
                self.id_tracker.add_group()

        def remove_value(self, model, groupid, xy):
            """Unassign the point `xy` from group `groupid`, dropping the
            group everywhere if it became empty."""
            x, y = xy
            emptied = self.clustering.remove_value(
                model.clustering,
                groupid)
            self.feature_x.remove_value(model.feature, groupid, x)
            self.feature_y.remove_value(model.feature, groupid, y)
            if emptied:
                self.feature_x.remove_group(model.feature, groupid)
                self.feature_y.remove_group(model.feature, groupid)
                self.id_tracker.remove_group(groupid)
def sample_from_image(image, sample_count):
    """Yield `sample_count` (x, y) points in [-1, 1]^2, drawn with
    probability proportional to image darkness (dark pixels are likely).

    `image` is not modified; a negated, shifted copy is used as the pmf.
    """
    # Invert brightness so that dark pixels get large mass, then shift so
    # that all entries are nonnegative.
    darkness = -1.0 * image
    darkness -= darkness.min()

    # Marginal pmf over rows, and per-row conditional pmfs over columns.
    x_pmf = darkness.sum(axis=1)
    y_pmfs = darkness.copy()
    y_pmfs /= (y_pmfs.sum(axis=1)[:, numpy.newaxis] + 1e-8)

    # Affine maps from pixel indices to the square [-1, 1]^2.
    x_scale = 2.0 / (darkness.shape[0] - 1)
    y_scale = 2.0 / (darkness.shape[1] - 1)

    for _ in xrange(sample_count):
        row = sample_discrete(x_pmf)
        col = sample_discrete(y_pmfs[row])
        yield (row * x_scale - 1.0, col * y_scale - 1.0)
def synthesize_image(model, mixture):
    """Render the mixture's score at every pixel as an inverted uint8 image.

    High-probability pixels come out dark, matching the convention that
    samples were drawn from dark regions of the source image.
    """
    width, height = IMAGE.shape
    rendered = numpy.zeros((width, height))
    scores = numpy.zeros(len(mixture), dtype=numpy.float32)
    x_scale = 2.0 / (width - 1)
    y_scale = 2.0 / (height - 1)
    for row in xrange(width):
        for col in xrange(height):
            point = (row * x_scale - 1.0, col * y_scale - 1.0)
            mixture.score_value(model, point, scores)
            # exp in place, then sum over groups for the total mass here.
            rendered[row, col] = numpy.exp(scores, out=scores).sum()

    # Normalize to [0, 1], then flip and scale into [0, 255].
    rendered /= rendered.max()
    rendered -= 1.0
    rendered *= -255
    return rendered.astype(numpy.uint8)
def visualize_dataset(samples):
    """Render an iterable of (x, y) points in [-1, 1]^2 as a smoothed,
    inverted uint8 histogram image with the same shape as IMAGE."""
    width, height = IMAGE.shape
    x_scale = 2.0 / (width - 1)
    y_scale = 2.0 / (height - 1)
    counts = numpy.zeros((width, height))
    # Accumulate a 2-d histogram by mapping each point back to pixel indices.
    for x, y in samples:
        row = int(round((x + 1.0) / x_scale))
        col = int(round((y + 1.0) / y_scale))
        counts[row, col] += 1
    smoothed = scipy.ndimage.gaussian_filter(counts, sigma=1)
    # Flip sign while scaling so dense regions render dark, then shift
    # everything back up to nonnegative values.
    smoothed *= -255.0 / smoothed.max()
    smoothed -= smoothed.min()
    return smoothed.astype(numpy.uint8)
@parsable.command
178
def create_dataset(sample_count=SAMPLE_COUNT):
179
    '''
180
    Extract dataset from image.
181
    '''
182
    scipy.misc.imsave(os.path.join(RESULTS, 'original.png'), IMAGE)
183
    print 'sampling {} points from image'.format(sample_count)
184
    samples = sample_from_image(IMAGE, sample_count)
185
    json_stream_dump(samples, SAMPLES)
186
    image = visualize_dataset(json_stream_load(SAMPLES))
187
    scipy.misc.imsave(os.path.join(RESULTS, 'samples.png'), image)
188
189
190
@parsable.command
191
def compress_sequential():
192
    '''
193
    Compress image via sequential initialization.
194
    '''
195
    assert os.path.exists(SAMPLES), 'first create dataset'
196
    print 'sequential start'
197
    model = ImageModel()
198
    mixture = ImageModel.Mixture()
199
    mixture.init(model)
200
    scores = numpy.zeros(1, dtype=numpy.float32)
201
202
    for xy in json_stream_load(SAMPLES):
203
        scores.resize(len(mixture))
204
        mixture.score_value(model, xy, scores)
205
        groupid = sample_discrete_log(scores)
206
        mixture.add_value(model, groupid, xy)
207
208
    print 'sequential found {} components'.format(len(mixture))
209
    image = synthesize_image(model, mixture)
210
    scipy.misc.imsave(os.path.join(RESULTS, 'sequential.png'), image)
211
212
213 View Code Duplication
@parsable.command
0 ignored issues
show
Duplication introduced by
This code seems to be duplicated in your project.
Loading history...
214
def compress_gibbs(passes=PASSES):
215
    '''
216
    Compress image via gibbs sampling.
217
    '''
218
    assert passes >= 0
219
    assert os.path.exists(SAMPLES), 'first create dataset'
220
    print 'prior+gibbs start {} passes'.format(passes)
221
    model = ImageModel()
222
    mixture = ImageModel.Mixture()
223
    mixture.init(model)
224
    scores = numpy.zeros(1, dtype=numpy.float32)
225
    assignments = {}
226
227
    for i, xy in enumerate(json_stream_load(SAMPLES)):
228
        scores.resize(len(mixture))
229
        mixture.clustering.score_value(model.clustering, scores)
230
        groupid = sample_discrete_log(scores)
231
        mixture.add_value(model, groupid, xy)
232
        assignments[i] = mixture.id_tracker.packed_to_global(groupid)
233
234
    print 'prior+gibbs init with {} components'.format(len(mixture))
235
236
    for _ in xrange(passes):
237
        for i, xy in enumerate(json_stream_load(SAMPLES)):
238
            groupid = mixture.id_tracker.global_to_packed(assignments[i])
239
            mixture.remove_value(model, groupid, xy)
240
            scores.resize(len(mixture))
241
            mixture.score_value(model, xy, scores)
242
            groupid = sample_discrete_log(scores)
243
            mixture.add_value(model, groupid, xy)
244
            assignments[i] = mixture.id_tracker.packed_to_global(groupid)
245
246
    print 'prior+gibbs found {} components'.format(len(mixture))
247
    image = synthesize_image(model, mixture)
248
    scipy.misc.imsave(os.path.join(RESULTS, 'prior_gibbs.png'), image)
249
250
251 View Code Duplication
@parsable.command
0 ignored issues
show
Duplication introduced by
This code seems to be duplicated in your project.
Loading history...
252
def compress_seq_gibbs(passes=PASSES):
253
    '''
254
    Compress image via sequentiall-initialized gibbs sampling.
255
    '''
256
    assert passes >= 1
257
    assert os.path.exists(SAMPLES), 'first create dataset'
258
    print 'seq+gibbs start {} passes'.format(passes)
259
    model = ImageModel()
260
    mixture = ImageModel.Mixture()
261
    mixture.init(model)
262
    scores = numpy.zeros(1, dtype=numpy.float32)
263
    assignments = {}
264
265
    for i, xy in enumerate(json_stream_load(SAMPLES)):
266
        scores.resize(len(mixture))
267
        mixture.score_value(model, xy, scores)
268
        groupid = sample_discrete_log(scores)
269
        mixture.add_value(model, groupid, xy)
270
        assignments[i] = mixture.id_tracker.packed_to_global(groupid)
271
272
    print 'seq+gibbs init with {} components'.format(len(mixture))
273
274
    for _ in xrange(passes - 1):
275
        for i, xy in enumerate(json_stream_load(SAMPLES)):
276
            groupid = mixture.id_tracker.global_to_packed(assignments[i])
277
            mixture.remove_value(model, groupid, xy)
278
            scores.resize(len(mixture))
279
            mixture.score_value(model, xy, scores)
280
            groupid = sample_discrete_log(scores)
281
            mixture.add_value(model, groupid, xy)
282
            assignments[i] = mixture.id_tracker.packed_to_global(groupid)
283
284
    print 'seq+gibbs found {} components'.format(len(mixture))
285
    image = synthesize_image(model, mixture)
286
    scipy.misc.imsave(os.path.join(RESULTS, 'seq_gibbs.png'), image)
287
288
289
def json_loop_load(filename):
    """Endlessly cycle through `filename`, yielding (index, item) pairs;
    the index restarts at 0 on each wrap-around."""
    while True:
        index = 0
        for item in json_stream_load(filename):
            yield index, item
            index += 1
def annealing_schedule(passes):
    """Yield an endless True/False stream interleaving adds and removes
    in the long-run ratio passes : (passes - 1).

    True means "add a datapoint", False means "remove one".  The stream
    opens with a burst of adds, so the active subsample grows to full
    size over roughly `passes` passes through the data.  With passes == 1
    the stream is all adds (pure sequential initialization).
    """
    passes = float(passes)
    assert passes >= 1
    add_rate = passes
    remove_rate = passes - 1
    balance = add_rate
    while True:
        add_next = balance >= 0
        if add_next:
            balance -= remove_rate
        else:
            balance += add_rate
        yield add_next
@parsable.command
311
def compress_annealing(passes=PASSES):
312
    '''
313
    Compress image via subsample annealing.
314
    '''
315
    assert passes >= 1
316
    assert os.path.exists(SAMPLES), 'first create dataset'
317
    print 'annealing start {} passes'.format(passes)
318
    model = ImageModel()
319
    mixture = ImageModel.Mixture()
320
    mixture.init(model)
321
    scores = numpy.zeros(1, dtype=numpy.float32)
322
    assignments = {}
323
324
    to_add = json_loop_load(SAMPLES)
325
    to_remove = json_loop_load(SAMPLES)
326
327
    for next_action_is_add in annealing_schedule(passes):
328
        if next_action_is_add:
329
            i, xy = to_add.next()
330
            if i in assignments:
331
                break
332
            scores.resize(len(mixture))
333
            mixture.score_value(model, xy, scores)
334
            groupid = sample_discrete_log(scores)
335
            mixture.add_value(model, groupid, xy)
336
            assignments[i] = mixture.id_tracker.packed_to_global(groupid)
337
        else:
338
            i, xy = to_remove.next()
339
            groupid = mixture.id_tracker.global_to_packed(assignments.pop(i))
340
            mixture.remove_value(model, groupid, xy)
341
342
    print 'annealing found {} components'.format(len(mixture))
343
    image = synthesize_image(model, mixture)
344
    scipy.misc.imsave(os.path.join(RESULTS, 'annealing.png'), image)
345
346
347
@parsable.command
def clean():
    '''
    Clean out dataset and results.
    '''
    # BUG FIX: the original tested `if not os.path.exists(dirname)`, so it
    # only attempted to delete directories that did NOT exist -- existing
    # data was never cleaned and rmtree raised OSError on the missing path.
    for dirname in [DATA, RESULTS]:
        if os.path.exists(dirname):
            shutil.rmtree(dirname)
@parsable.command
def run(sample_count=SAMPLE_COUNT, passes=PASSES):
    '''
    Generate all datasets and run all algorithms.
    See index.html for results.
    '''
    create_dataset(sample_count)

    # Each compression algorithm is independent, so run them in parallel.
    workers = [
        Process(target=compress_sequential),
        Process(target=compress_gibbs, args=(passes,)),
        Process(target=compress_annealing, args=(passes,)),
        Process(target=compress_seq_gibbs, args=(passes,)),
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
# Dispatch the @parsable.command functions from the command line.
if __name__ == '__main__':
    parsable.dispatch()