Completed
Push — master ( 3e1d4c...f31f72 )
by Bart
27s
created

fuel.datasets.Spiral   A

Complexity

Total Complexity 1

Size/Duplication

Total Lines 64
Duplicated Lines 0 %
Metric Value
dl 0
loc 64
rs 10
wmc 1

1 Method

Rating   Name   Duplication   Size   Complexity  
A Spiral.__init__() 0 23 1
1
# -*- coding: utf-8 -*-
2
3
import numpy
4
5
from collections import OrderedDict
6
7
from fuel import config
8
from fuel.datasets import IndexableDataset
9
10
11
class Spiral(IndexableDataset):
    u"""Toy dataset of points sampled along spiral arms in the 2d plane.

    Three sources are provided, in this order:

    * features -- float32 array of shape ``(num_examples, 2)`` holding the
      (x, y) coordinates of every point
    * position -- where along its arm each point sits, drawn uniformly
      from ``[0, cycles)``
    * label -- integer class label in ``[0, classes)`` identifying the
      spiral arm the point belongs to

    .. plot::

        from fuel.datasets.toy import Spiral

        ds = Spiral(classes=3)
        features, position, label = ds.get_data(None, slice(0, 500))

        plt.title("Datapoints drawn from Spiral(classes=3)")
        for l, m in enumerate(['o', '^', 'v']):
            mask = label == l
            plt.scatter(features[mask,0], features[mask,1],
                        c=position[mask], marker=m, label="label==%d"%l)
        plt.xlim(-1.2, 1.2)
        plt.ylim(-1.2, 1.2)
        plt.legend()
        plt.colorbar()
        plt.xlabel("features[:,0]")
        plt.ylabel("features[:,1]")
        plt.show()

    Parameters
    ----------
    num_examples : int
        Number of datapoints to create.
    classes : int
        Number of spiral arms.
    cycles : float
        Number of turns the arms take.
    noise : float
        Add normal distributed noise with standard deviation *noise*.
    seed : int, optional
        Seed for the random number generator, passed as a keyword
        argument. Falls back to ``config.default_seed`` when omitted.

    Any remaining keyword arguments are forwarded unchanged to the
    :class:`IndexableDataset` constructor.

    """
    def __init__(self, num_examples=1000, classes=1, cycles=1., noise=0.0,
                 **kwargs):
        # 'seed' is consumed here so it never reaches the parent constructor.
        seed = kwargs.pop('seed', config.default_seed)
        rng = numpy.random.RandomState(seed)

        # The draw order (position, label, noise) determines the generated
        # data for a given seed, so it must not be rearranged.
        pos = rng.uniform(size=num_examples, low=0, high=cycles)
        label = rng.randint(size=num_examples, low=0, high=classes)

        # Radius grows linearly with the position along the arm; each arm
        # is rotated by an equal share of the full circle.
        two_pi = 2 * numpy.pi
        radius = (2 * pos + 1) / 3.
        phase_offset = label * two_pi / classes
        angle = two_pi * pos + phase_offset

        # Coordinates are computed in double precision and then stored as
        # float32; the noise is added in place on the float32 array.
        features = numpy.column_stack(
            (radius * numpy.sin(angle), radius * numpy.cos(angle))
        ).astype('float32')
        features += noise * rng.normal(size=(num_examples, 2))

        # Source order is part of the dataset's interface.
        data = OrderedDict()
        data['features'] = features
        data['position'] = pos
        data['label'] = label

        super(Spiral, self).__init__(data, **kwargs)
75
76
77
class SwissRoll(IndexableDataset):
    """Toy dataset of points lying on a Swiss roll embedded in 3d space.

    Two sources are provided, in this order:

    * features -- float32 array of shape ``(num_examples, 3)`` holding the
      x, y and z coordinates of every point
    * position -- float32 array of shape ``(num_examples, 2)`` holding the
      coordinates on the manifold: the radial position (uniform in
      ``[0, 1)``) and the z position (uniform in ``[-1, 1)``)

    .. plot::

        from fuel.datasets.toy import SwissRoll
        import mpl_toolkits.mplot3d.axes3d as p3
        import numpy as np

        ds = SwissRoll()
        features, pos = ds.get_data(None, slice(0, 1000))

        color = pos[:,0]
        color -= color.min()
        color /= color.max()

        fig = plt.figure()
        ax = fig.gca(projection="3d")
        ax.scatter(features[:,0], features[:,1], features[:,2],
                   'x', c=color)
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)
        ax.set_zlim(-1, 1)
        ax.view_init(10., 10.)
        plt.show()

    Parameters
    ----------
    num_examples : int
        Number of datapoints to create.
    noise : float
        Add normal distributed noise with standard deviation *noise*.
    seed : int, optional
        Seed for the random number generator, passed as a keyword
        argument. Falls back to ``config.default_seed`` when omitted.

    Any remaining keyword arguments are forwarded unchanged to the
    :class:`IndexableDataset` constructor.

    """
    def __init__(self, num_examples=1000, noise=0.0, **kwargs):
        # Fixed number of turns of the roll.
        cycles = 1.5

        # 'seed' is consumed here so it never reaches the parent constructor.
        seed = kwargs.pop('seed', config.default_seed)
        rng = numpy.random.RandomState(seed)

        # The draw order (radial position, z, noise) determines the
        # generated data for a given seed, so it must not be rearranged.
        pos = rng.uniform(size=num_examples, low=0, high=1)
        z = rng.uniform(size=num_examples, low=-1, high=1)

        # Angle and radius both grow linearly with the radial position.
        unrolled = 1 + 2 * pos
        phi = cycles * numpy.pi * unrolled
        radius = unrolled / 3

        # Coordinates are computed in double precision and then stored as
        # float32; the noise is added in place on the float32 array.
        features = numpy.column_stack(
            (radius * numpy.cos(phi), radius * numpy.sin(phi), z)
        ).astype('float32')
        features += noise * rng.normal(size=(num_examples, 3))

        position = numpy.column_stack((pos, z)).astype('float32')

        # Source order is part of the dataset's interface.
        data = OrderedDict()
        data['features'] = features
        data['position'] = position

        super(SwissRoll, self).__init__(data, **kwargs)
144