Completed
Push — master ( c9afee...7a23b6 )
by Ben
01:16
created

Axes.update()   A

Complexity

Conditions 3

Size

Total Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 3
c 1
b 0
f 0
dl 0
loc 11
rs 9.4285
1
"""Provide the basic Axes class."""
2
import json
3
import os
4
from collections import OrderedDict
5
from functools import wraps
6
7
import h5py
8
import matplotlib.pyplot as plt
9
10
11
class HDF5IO(object):

    """
    Store recorded plotting calls in an HDF5 file through `h5py`.

    A plot object is stored as a group holding an 'args' and a 'kwargs'
    subgroup; the special groups 'style' and 'rcParams' hold the style
    dictionary and the matplotlib rc parameters.
    """

    def __init__(self, data_file):
        """
        Arguments:

        `data_file` -- path of the HDF5 file to read from and write to
        """
        self.data_file = data_file

    @staticmethod
    def _read_group(group):
        """Return the datasets of `group` as an OrderedDict sorted by key."""
        return OrderedDict(sorted(((str(key), value[()])
                                   for key, value in group.items()),
                                  key=lambda item: item[0]))

    @staticmethod
    def _store(group, key, value):
        """Write `value` under `key` in `group`, replacing any existing entry."""
        # h5py refuses to create a dataset whose name already exists, so an
        # entry saved a second time (e.g. updated rcParams) must be removed.
        if key in group:
            del group[key]
        group[key] = value

    def read(self):
        """
        Read the hdf5 file.

        Returns an `OrderedDict` mapping each top-level group name except
        'rcParams' (in sorted order) to a dict of its subgroups, each read as
        an `OrderedDict` of dataset values sorted by key. If present, the
        'rcParams' group is added last as a plain dict of its dataset values.
        """
        with h5py.File(self.data_file, 'r') as h5_file:
            entries = []
            for plot_function, args in h5_file.items():
                if plot_function == 'rcParams':
                    continue
                entries.append((plot_function,
                                {arg_name: self._read_group(arg)
                                 for arg_name, arg in args.items()}))
            dct = OrderedDict(sorted(entries, key=lambda entry: entry[0]))
            if 'rcParams' in h5_file:
                dct['rcParams'] = {key: val[()] for key, val in
                                   h5_file['rcParams'].items()}
        return dct

    def save(self, name, plot_object):
        """
        Save a plot object to the hdf5 file.

        `name` -- group name: 'rcParams', 'style' or a plot-object key
        `plot_object` -- for 'rcParams' a flat mapping; for 'style' a mapping
                         of style names to mappings; otherwise a mapping with
                         'args' and 'kwargs' mappings

        Datasets that already exist under the same name are replaced.
        """
        # Explicit append mode: create the file if needed, keep its content.
        # (h5py >= 3.0 opens files read-only when no mode is given.)
        with h5py.File(self.data_file, 'a') as h5_file:
            plot_object_group = h5_file.require_group(name)
            if name == 'rcParams':
                for param, value in plot_object.items():
                    self._store(plot_object_group, param, value)

            elif name == 'style':
                for style, dct in plot_object.items():
                    style_group = plot_object_group.require_group(style)
                    for key, value in dct.items():
                        self._store(style_group, key, value)
            else:
                args_group = plot_object_group.require_group('args')
                kwargs_group = plot_object_group.require_group('kwargs')
                for key, arg in plot_object['args'].items():
                    self._store(args_group, str(key), arg)
                for key, value in plot_object['kwargs'].items():
                    self._store(kwargs_group, str(key), value)
57
58
59
class JsonIO(object):

    """
    Store recorded plotting calls in a JSON file.

    The file holds a single JSON object mapping entry names (plot-object
    keys, 'style', 'rcParams') to their data.
    """

    def __init__(self, data_file):
        """
        Arguments:

        `data_file` -- path of the JSON file to read from and write to
        """
        self.data_file = data_file

    def read(self):
        """
        Read the JSON file.

        Returns its content with every JSON object loaded as an `OrderedDict`.
        Raises `IOError` when the file cannot be opened.
        """
        with open(self.data_file) as json_file:
            return json.load(json_file, object_pairs_hook=OrderedDict)

    @staticmethod
    def _jsonable(value):
        """Return `value.tolist()` when `value` has that method, else `value`."""
        try:
            return value.tolist()
        except AttributeError:
            return value

    def save(self, name, dct_obj):
        """
        Save the entry `name` with content `dct_obj` to the JSON file.

        For plot objects (any `name` other than 'style' and 'rcParams'), the
        values in `dct_obj['args']` and `dct_obj['kwargs']` that provide a
        `tolist` method (e.g. numpy arrays) are converted with it so they can
        be serialized. The conversion is done on copies; `dct_obj` itself is
        left unchanged. The file is rewritten with all existing entries plus
        this one (replacing a previous entry of the same name).
        """
        if name not in ('style', 'rcParams'):
            dct_obj = dict(dct_obj)
            for part in ('args', 'kwargs'):
                dct_obj[part] = {key: self._jsonable(value)
                                 for key, value in dct_obj[part].items()}

        try:
            data = self.read()
        except IOError:
            # No file yet: start a fresh document.
            data = {}
        data[name] = dct_obj
        with open(self.data_file, 'w') as json_file:
            json.dump(data, json_file, sort_keys=True,
                      indent=0, separators=(',', ': '))
95
96
97
class Axes(object):

    """
    Save matplotlib command for later reuse.

    Holds a `matplotlib.axes.Axes` object and records the plotting calls made
    through this wrapper so that they can be replayed later with `replot`.
    """

    def __init__(self, data_file, ax=None, file_type='json', style=None,
                 rcParams=None, erase=True):
        """
        Save matplotlib command for later reuse.

        Arguments:

        `data_file` -- the file on which the plotting functions and
                       data are stored

        `ax` -- a `matplotlib.axes.Axes` instance (default: the
                current `matplotlib.axes.Axes` instance)

        `file_type` -- storage format, 'json' or 'hdf5' (default: 'json')

        `style` -- mapping of style names to keyword-argument mappings,
                   merged over any style already stored in `data_file`

        `rcParams` -- matplotlib rc parameters, merged over any stored ones
                      and applied to `matplotlib.pyplot.rcParams`

        `erase` -- delete `data_file` first (default: True)

        Raises `NotImplementedError` for an unknown `file_type`.
        """
        self._action_number = 0
        self.file_type = file_type
        self.data_file = data_file

        if erase:
            self.clean()

        if self.file_type == 'json':
            self.io = JsonIO(self.data_file)
        elif self.file_type == 'hdf5':
            self.io = HDF5IO(self.data_file)
        else:
            raise NotImplementedError(self.file_type)

        self._style = self.update('style', style)
        self._rcParams = self.update('rcParams', rcParams)

        plt.rcParams.update(self.rcParams)
        self._ax = ax if ax is not None else plt.gca()

    def clean(self):
        """
        Delete the file, ignoring the error if it cannot be removed.
        """
        try:
            os.remove(self.data_file)
        except OSError:
            pass

    @property
    def action_number(self):
        """
        Number of action called, as a zero-padded 3-digit string.

        Note: every access increments the counter.
        """
        self._action_number += 1
        return '{:03d}'.format(self._action_number)

    def __getattr__(self, attr):
        """
        Pass the plotting function to the parser.

        Called only for names not found normally. `attr` is looked up on the
        wrapped axes, then as `'_' + attr` on this object. Callables are
        wrapped by `parse_func` so the call gets recorded; other attributes
        are returned as they are.
        """
        # Never forward dunder lookups (copy/pickle probes) nor `_ax`
        # itself: forwarding `_ax` while it is unset would recurse forever.
        if attr.startswith('__') or attr == '_ax':
            raise AttributeError(attr)
        try:
            target = getattr(self._ax, attr)
        except AttributeError:
            target = getattr(self, '_' + attr)
        if not callable(target):
            return target
        return self.parse_func(target)

    @property
    def style(self):
        """Mapping of style names to keyword-argument mappings."""
        return self._style

    @style.setter
    def style(self, dct):
        """Merge `dct` into the styles and save them to the file."""
        self._style.update(dct)
        self.io.save('style', self.style)

    @property
    def rcParams(self):
        """Recorded matplotlib rc parameters."""
        return self._rcParams

    @rcParams.setter
    def rcParams(self, dct):
        """Merge `dct` into the rc parameters and save them to the file."""
        self._rcParams.update(dct)
        self.io.save('rcParams', self.rcParams)

    @staticmethod
    def _ordered_args(args):
        """Return the values of `args` sorted by their integer position key."""
        # Keys come back as strings ("0", "1", ..., "10") from the files and
        # may have been sorted as strings, so order them numerically.
        return [value for _, value in
                sorted(args.items(), key=lambda item: int(item[0]))]

    def replot(self):
        """Replot using the recorded plotting functions."""
        for key, plot_object in self.io.read().items():
            if key in ('style', 'rcParams'):
                continue
            # Keys are '<action number>_<function name>'.
            attr = key.partition('_')[2]
            kwargs = plot_object['kwargs']
            self.apply_style(kwargs, attr)
            args = self._ordered_args(plot_object['args'])
            # Resolve the function first so an AttributeError raised by the
            # plotting call itself is not mistaken for a failed lookup.
            try:
                func = getattr(self._ax, attr)
            except AttributeError:
                func = getattr(self, attr)
            func(*args, **kwargs)

        plt.draw_if_interactive()

    def _to_dict(self, name, *args, **kwargs):
        """
        Return `(key, plot_object)` for a call of function `name`.

        `key` is '<action number>_<name>' (consuming one action number);
        `plot_object` maps 'args' to {position: value} and 'kwargs' to the
        keyword arguments.
        """
        return self.action_number + '_' + name, {'args': dict(enumerate(args)),
                                                 'kwargs': kwargs}

    def parse_func(self, func):
        """Create and save a plot object by parsing the plotting function."""
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Record the call as given (before styles are expanded), then run.
            self.io.save(*self._to_dict(func.__name__, *args, **kwargs))
            self.apply_style(kwargs, func.__name__)
            return func(*args, **kwargs)
        return wrapper

    def apply_style(self, kwargs, extra_styles=None):
        """
        Expand the 'style' keyword of `kwargs` in place.

        `extra_styles` (the plotting function name when called from
        `parse_func`/`replot`) is prepended to the space-separated names in
        kwargs['style']. The 'style' key is then removed and, for each name
        present in `self.style`, its mapping is merged into `kwargs`
        (overriding keys already there). Unknown names are ignored.
        """
        if extra_styles:
            kwargs['style'] = (extra_styles + ' ' +
                               kwargs.get('style', ' ')).strip()
        try:
            styles = kwargs.pop('style')
        except KeyError:
            return
        for style in styles.split(' '):
            try:
                kwargs.update(self.style[style])
            except KeyError:
                pass

    def update(self, ex_dict, new_dict):
        """
        Update existing dictionary named `ex_dict` in database with `new_dict`.

        Returns the stored dictionary merged with `new_dict`, or a copy of
        `new_dict` (an empty dict when it is falsy) when the file or the
        entry does not exist. Nothing is written to the file.
        """
        # Copy so later in-place updates never touch the caller's dict.
        new_dict = dict(new_dict) if new_dict else {}
        try:
            dct = self.io.read()[ex_dict]
        except (IOError, KeyError):
            return new_dict
        dct.update(new_dict)
        return dct
236