Test Failed
Push — master ( 3b6ed6...9f250d )
by Ben
02:11
created

Axes._corner()   A

Complexity

Conditions 1

Size

Total Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 1
c 1
b 0
f 0
dl 0
loc 2
rs 10
1
"""Provide the basic Axes class."""
2
import json
3
import os
4
from collections import OrderedDict
5
from functools import wraps
6
7
import h5py
8
import matplotlib.pyplot as plt
9
10
11
class HDF5IO(object):
    """Store recorded plot calls in an HDF5 file (via h5py).

    Every saved name is a top-level group.  A plot-call record holds an
    ``args`` and a ``kwargs`` sub-group with one dataset per argument;
    ``style`` holds one sub-group per style name; ``rcParams`` holds one
    dataset per parameter.
    """

    def __init__(self, data_file):
        """Remember *data_file*, the path of the HDF5 file to use."""
        self.data_file = data_file

    def read(self):
        """Read the hdf5 file.

        Returns an ``OrderedDict`` mapping every top-level group name except
        ``'rcParams'`` (sorted by name) to
        ``{sub_group_name: OrderedDict(sorted dataset_name -> value)}``,
        plus an ``'rcParams'`` entry (a plain dict of parameter -> value)
        when that group exists.  Errors opening the file propagate from
        h5py (the Axes constructor expects an IOError when it is missing).
        """
        with h5py.File(self.data_file, 'r') as f:
            records = []
            for plot_function, args in f.items():
                if plot_function == 'rcParams':
                    continue
                members = {}
                for arg_name, arg in args.items():
                    # ``v[()]`` loads the whole dataset into memory.
                    members[arg_name] = OrderedDict(
                        sorted((str(k), v[()]) for k, v in arg.items()))
                records.append((plot_function, members))
            d = OrderedDict(sorted(records, key=lambda item: item[0]))
            if 'rcParams' in f:
                d['rcParams'] = dict((key, val[()])
                                     for key, val in f['rcParams'].items())
        return d

    def save(self, name, plot_object):
        """Save a plot object to the hdf5 file, replacing any previous one.

        For ``name == 'rcParams'`` *plot_object* maps parameter -> value;
        for ``name == 'style'`` it maps style name -> {keyword: value};
        otherwise it is a record with ``'args'`` and ``'kwargs'`` mappings
        whose keys are stored as strings.
        """
        # Open in append mode explicitly: since h5py 3.0 the default mode of
        # ``h5py.File`` is read-only, which made every save fail.
        with h5py.File(self.data_file, 'a') as f:
            # Replace the whole entry (as JsonIO does): h5py refuses to create
            # a dataset whose name already exists, and stale datasets from an
            # earlier save of the same name would otherwise survive.
            if name in f:
                del f[name]
            plot_object_group = f.create_group(name)
            if name == 'rcParams':
                for param, v in plot_object.items():
                    plot_object_group[param] = v

            elif name == 'style':
                for style, dct in plot_object.items():
                    style_group = plot_object_group.require_group(style)
                    for k, v in dct.items():
                        style_group[k] = v
            else:
                args_group = plot_object_group.require_group('args')
                kwargs_group = plot_object_group.require_group('kwargs')
                for k, arg in plot_object['args'].items():
                    args_group[str(k)] = arg
                for k, v in plot_object['kwargs'].items():
                    kwargs_group[str(k)] = v
56
57
58
class JsonIO(object):
    """Store recorded plot calls in a JSON file.

    The file holds a single JSON object mapping each saved name (a plot-call
    key, ``'style'`` or ``'rcParams'``) to its data.
    """

    def __init__(self, data_file):
        """Remember *data_file*, the path of the JSON file to use."""
        self.data_file = data_file

    @staticmethod
    def _to_builtin(obj):
        """``json`` fallback: convert objects exposing ``tolist()``.

        Called by the encoder for any value it cannot serialize natively, at
        any nesting depth (e.g. numpy arrays inside a list argument).
        Anything else raises ``TypeError`` as the json module would.
        """
        tolist = getattr(obj, 'tolist', None)
        if tolist is None:
            raise TypeError('Object of type {} is not JSON serializable'
                            .format(type(obj).__name__))
        return tolist()

    def read(self):
        """Return the file content as an ``OrderedDict`` (file key order).

        Raises ``IOError`` when the file cannot be opened.
        """
        with open(self.data_file) as f:
            return json.load(f, object_pairs_hook=OrderedDict)

    def save(self, name, dct_obj):
        """Store *dct_obj* under *name*, replacing any previous entry.

        The rest of the file content is kept (a file that cannot be opened
        counts as empty) and the file is rewritten with sorted keys.  Values
        with a ``tolist()`` method are converted on the fly; *dct_obj* itself
        is not modified.
        """
        try:
            data = self.read()
        except IOError:
            data = {}
        data[name] = dct_obj
        # Serialize before opening for writing: a value json cannot encode
        # must not leave a truncated, half-written file behind.
        text = json.dumps(data, sort_keys=True, indent=0,
                          separators=(',', ': '), default=self._to_builtin)
        with open(self.data_file, 'w') as f:
            f.write(text)
94
95
96
class Axes(object):
97
98
    """
99
    Save matplotlib command for later reuse.
100
101
    Holds and `matplotlib.axes.Axes` object which saves the operations done on
102
    it for latter re-plotting.
103
    """
104
105
    def __init__(self, data_file, ax=None, file_type='json', style=None,
                 rcParams=None, erase=True):
        """
        Save matplotlib command for later reuse.

        Arguments:

        `data_file` -- the file on which the plotting functions and
                       data are stored

        `ax` -- an `matplotlib.axes.Axes` instance (default: the
                       current `matplotlib.axes.Axes` instance)

        `file_type` -- storage backend, 'json' or 'hdf5'; anything else
                       raises `NotImplementedError`

        `style` -- mapping of style name to keyword arguments; entries
                       override styles already stored in `data_file`

        `rcParams` -- matplotlib rcParams applied to `plt.rcParams`;
                       entries override those stored in `data_file`

        `erase` -- delete `data_file` first, starting a fresh record
        """
        self._action_number = 0
        self.file_type = file_type
        self.data_file = data_file
        # Copy instead of keeping the caller's (or a shared default) dict:
        # the style/rcParams setters update these dicts in place.
        self._style = dict(style) if style else {}
        self._rcParams = dict(rcParams) if rcParams else {}

        if erase:
            try:
                os.remove(self.data_file)
            except OSError:
                # Nothing to erase (e.g. the file does not exist yet).
                pass

        if file_type == 'json':
            self.io = JsonIO(self.data_file)
        elif file_type == 'hdf5':
            self.io = HDF5IO(self.data_file)
        else:
            raise NotImplementedError(self.file_type)

        try:
            stored = self.io.read()
        except IOError:
            stored = {}

        # Values already stored in the file form the base; entries passed
        # to the constructor override them.
        if 'style' in stored:
            style = stored['style']
            style.update(self._style)
            self._style = style
        if 'rcParams' in stored:
            rc_params = stored['rcParams']
            rc_params.update(self._rcParams)
            self._rcParams = rc_params

        plt.rcParams.update(self.rcParams)
        self._ax = ax if ax is not None else plt.gca()
154
155
    @property
156
    def action_number(self):
157
        """Number of action called."""
158
        self._action_number += 1
159
        return '{:03d}'.format(self._action_number)
160
161
    def __getattr__(self, attr):
162
        """Pass the plotting function to the parser."""
163
        if attr[1] == '_':
164
            raise AttributeError
165
        try:
166
            return self.parse_func(getattr(self._ax, attr))
167
        except AttributeError:
168
            return self.parse_func(getattr(self, '_' + attr))
169
170
    @property
171
    def style(self):
172
        return self._style
173
174
    @style.setter
175
    def style(self, dct):
176
        self._style.update(dct)
177
        self.io.save('style', self.style)
178
179
    @property
180
    def rcParams(self):
181
        return self._rcParams
182
183
    @rcParams.setter
184
    def rcParams(self, dct):
185
        self._rcParams.update(dct)
186
        self.io.save('rcParams', self.rcParams)
187
188
    def replot(self):
        """Replot using the recorded plotting functions.

        Replays every stored record except 'style' and 'rcParams', in the
        order the backend returns them.  Styles are re-applied and the
        target is called directly: an axes method, or a method of this
        object when the axes lacks it.  Finishes with
        ``plt.draw_if_interactive()``.
        """
        for key, plot_object in self.io.read().items():
            if key in ('style', 'rcParams'):
                continue
            # Keys look like '001_plot': drop the action-number prefix.
            attr = '_'.join(key.split('_')[1:])
            kwargs = plot_object['kwargs']
            self.apply_style(kwargs, attr)
            # Positional args are stored under their index as a string.  Both
            # backends may hand them back in lexicographic order ('10' before
            # '2'), so restore the numeric order.
            stored_args = plot_object['args']
            args = [stored_args[k] for k in sorted(stored_args, key=int)]
            # Resolve the target before calling it, so an AttributeError
            # raised *inside* the plotting call is not mistaken for a missing
            # method (which used to replay the call a second time).
            try:
                func = getattr(self._ax, attr)
            except AttributeError:
                func = getattr(self, attr)
            func(*args, **kwargs)

        plt.draw_if_interactive()
203
204
    def _to_dict(self, name, *args, **kwargs):
205
        return self.action_number + '_' + name, {'args': dict(enumerate(args)),
206
                                                 'kwargs': kwargs}
207
208
    def parse_func(self, func):
209
        """Create and save a plot object by parsing the ploting function."""
210
        @wraps(func)
211
        def wrapper(*args, **kwargs):
212
            self.io.save(*self._to_dict(func.__name__, *args, **kwargs))
213
            self.apply_style(kwargs, func.__name__)
214
            return func(*args, **kwargs)
215
        return wrapper
216
217
    def apply_style(self, kwargs, extra_styles=None):
218
        if extra_styles:
219
            kwargs['style'] = (extra_styles + ' ' +
220
                               kwargs.get('style', ' ')).strip()
221
        try:
222
            for style in kwargs.pop('style').split(' '):
223
                try:
224
                    kwargs.update(self.style[style])
225
                except KeyError:
226
                    pass
227
228
        except KeyError:
229
            pass
230