Test Failed
Push — master ( 382fbf...8da094 )
by Ben
01:48
created

Axes.__init__()   F

Complexity

Conditions 10

Size

Total Lines 49

Duplication

Lines 0
Ratio 0 %

Importance

Changes 2
Bugs 1 Features 0
Metric Value
cc 10
c 2
b 1
f 0
dl 0
loc 49
rs 3.7894

How to fix   Complexity   

Complexity

Complex methods like Axes.__init__() often do a lot of different things. To break such a method down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

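Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster. For a single long method such as this one, Extract Method (moving each cohesive step into its own helper) is usually the most direct option.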

1
"""Provide the basic Axes class."""
2
import json
3
import os
4
from collections import OrderedDict
5
from functools import wraps
6
7
import h5py
8
import matplotlib.pyplot as plt
9
10
11
class HDF5IO(object):
    """
    Store and load recorded plot objects in an HDF5 file.

    Every top-level group of the file is one saved object: a plot action
    (with ``args`` and ``kwargs`` subgroups), the ``style`` table (one
    subgroup per style) or the ``rcParams`` table (one dataset per param).
    """

    def __init__(self, data_file):
        """
        Bind the reader/writer to a file.

        Arguments:

        `data_file` -- path of the HDF5 file to read from and write to
        """
        self.data_file = data_file

    def read(self):
        """
        Read the hdf5 file.

        Returns an ``OrderedDict`` with one entry per top-level group, sorted
        by group name.  Each entry maps a subgroup name to an ``OrderedDict``
        of ``{str(dataset name): value}`` sorted by dataset name.  A
        top-level ``rcParams`` group, when present, is added last under the
        ``'rcParams'`` key as a flat ``{name: value}`` dict.

        Callers treat an ``IOError``/``OSError`` from opening the file as
        "nothing saved yet".
        """
        with h5py.File(self.data_file, 'r') as h5_file:
            dct = OrderedDict(
                sorted([(plot_function,
                         {arg_name: OrderedDict(
                             sorted([(str(key), value[()])
                                     for key, value in arg.items()]))
                          for arg_name, arg in args.items()})
                        for plot_function, args in h5_file.items()
                        if plot_function != 'rcParams']))
            if 'rcParams' in h5_file:
                dct['rcParams'] = dict((key, val[()])
                                       for key, val in
                                       h5_file['rcParams'].items())
        return dct

    def save(self, name, plot_object):
        """
        Save a plot object to the hdf5 file.

        Arguments:

        `name` -- top-level group name: ``'rcParams'``, ``'style'`` or a
                  recorded action key
        `plot_object` -- for ``'rcParams'`` a ``{param: value}`` dict; for
                         ``'style'`` a ``{style: {key: value}}`` dict;
                         otherwise a dict with ``'args'`` and ``'kwargs'``
                         dicts

        Existing datasets with the same name are replaced, so the style and
        rcParams tables can be saved repeatedly.
        """
        # Open in append mode explicitly: h5py >= 3 defaults to read-only,
        # which would make every save fail.
        with h5py.File(self.data_file, 'a') as h5_file:
            plot_object_group = h5_file.require_group(name)
            if name == 'rcParams':
                for param, value in plot_object.items():
                    self._write_dataset(plot_object_group, param, value)
            elif name == 'style':
                for style, dct in plot_object.items():
                    style_group = plot_object_group.require_group(style)
                    for key, value in dct.items():
                        self._write_dataset(style_group, key, value)
            else:
                for part in ('args', 'kwargs'):
                    part_group = plot_object_group.require_group(part)
                    for key, value in plot_object[part].items():
                        # Positional indices are ints; HDF5 names are str.
                        self._write_dataset(part_group, str(key), value)

    @staticmethod
    def _write_dataset(group, key, value):
        """Store `value` under `key` in `group`, replacing any old entry."""
        # h5py refuses to create a link whose name already exists.
        if key in group:
            del group[key]
        group[key] = value
57
58
59
class JsonIO(object):
    """
    Store and load recorded plot objects in a JSON file.

    The file holds one JSON object mapping each saved name (a recorded
    action key, ``style`` or ``rcParams``) to its data.
    """

    def __init__(self, data_file):
        """
        Bind the reader/writer to a file.

        Arguments:

        `data_file` -- path of the JSON file to read from and write to
        """
        self.data_file = data_file

    def read(self):
        """
        Read the JSON file, preserving key order (``OrderedDict``).

        Raises ``IOError`` when the file cannot be opened.
        """
        with open(self.data_file) as json_file:
            return json.load(json_file, object_pairs_hook=OrderedDict)

    @staticmethod
    def _to_serializable(obj):
        """``json`` fallback: convert array-likes via their ``tolist()``."""
        try:
            return obj.tolist()
        except AttributeError:
            raise TypeError(
                '{!r} is not JSON serializable'.format(obj))

    def save(self, name, dct_obj):
        """
        Save `dct_obj` under `name` in the JSON file.

        Other entries already in the file are kept; an entry with the same
        `name` is replaced.  Objects exposing ``tolist()`` (e.g. numpy
        arrays and scalars) are converted during serialisation, at any
        depth, without modifying `dct_obj`.

        Raises ``TypeError`` for values that cannot be serialised; in that
        case the file is left untouched.
        """
        try:
            data = self.read()
        except IOError:
            data = {}
        data[name] = dct_obj
        # Serialise before opening the file for writing so a failure cannot
        # truncate the existing data.
        text = json.dumps(data, sort_keys=True, indent=0,
                          separators=(',', ': '),
                          default=self._to_serializable)
        with open(self.data_file, 'w') as json_file:
            json_file.write(text)
95
96
97
class Axes(object):

    """
    Save matplotlib command for later reuse.

    Holds an `matplotlib.axes.Axes` object and records the operations done
    on it for later re-plotting.
    """

    def __init__(self, data_file, ax=None, file_type='json', style=None,
                 rcParams=None, erase=True):
        """
        Save matplotlib command for later reuse.

        Arguments:

        `data_file` -- the file on which the plotting functions and
                       data are stored

        `ax` -- an `matplotlib.axes.Axes` instance (default: the
                current `matplotlib.axes.Axes` instance)

        `file_type` -- storage backend, ``'json'`` or ``'hdf5'``
                       (default: ``'json'``)

        `style` -- ``{style name: keyword arguments}``; merged over any
                   style table already stored in `data_file`

        `rcParams` -- matplotlib rcParams merged over any stored in
                      `data_file`; the result is applied to ``plt.rcParams``

        `erase` -- delete `data_file` before starting (default: True)

        Raises `NotImplementedError` for an unknown `file_type`.
        """
        io_classes = {'json': JsonIO, 'hdf5': HDF5IO}
        # Validate before touching the file system, so that a bad
        # `file_type` does not delete the existing data file.
        if file_type not in io_classes:
            raise NotImplementedError(file_type)

        self.file_type = file_type
        self.data_file = data_file
        self.io = io_classes[file_type](data_file)

        if erase:
            try:
                os.remove(data_file)
            except OSError:
                pass  # best effort: nothing to erase is fine

        # Read the file once; a missing file just means nothing was saved.
        try:
            saved = self.io.read()
        except IOError:
            saved = {}

        # Stored values first, constructor arguments override them.  New
        # dicts are built so that the setters never mutate caller-owned
        # dictionaries.
        self._style = dict(saved.get('style', {}))
        self._style.update(style or {})
        self._rcParams = dict(saved.get('rcParams', {}))
        self._rcParams.update(rcParams or {})

        # Action keys are '<number>_<name>'.  Continue after the highest
        # stored number so that, with erase=False, new actions do not
        # overwrite or collide with the ones already in the file.
        numbers = [int(key.split('_', 1)[0]) for key in saved
                   if key.split('_', 1)[0].isdigit()]
        self._action_number = max(numbers) if numbers else 0

        plt.rcParams.update(self.rcParams)
        self._ax = ax if ax is not None else plt.gca()

    @property
    def action_number(self):
        """
        Next action number as a zero-padded string ('001', '002', ...).

        Reading this property increments the counter, so every access
        returns a new number.
        """
        self._action_number += 1
        return '{:03d}'.format(self._action_number)

    def __getattr__(self, attr):
        """
        Pass the plotting function to the parser.

        Unknown attributes are looked up on the wrapped matplotlib axes
        first, then as ``'_' + attr`` on this object.  Whatever is found is
        wrapped by `parse_func`, so calling it records the action.
        """
        # Reject dunder names so protocol lookups (copy, pickle, ...) fail
        # normally.  Also reject `_ax` itself: if it is not set yet, proxying
        # it through `self._ax` would recurse forever.
        if attr.startswith('__') or attr == '_ax':
            raise AttributeError(attr)
        try:
            func = getattr(self._ax, attr)
        except AttributeError:
            try:
                func = getattr(self, '_' + attr)
            except AttributeError:
                raise AttributeError('{!r} object has no attribute {!r}'
                                     .format(type(self).__name__, attr))
        return self.parse_func(func)

    @property
    def style(self):
        """``{style name: keyword arguments}`` table used by `apply_style`."""
        return self._style

    @style.setter
    def style(self, dct):
        """Merge `dct` into the style table and save the table."""
        self._style.update(dct)
        self.io.save('style', self.style)

    @property
    def rcParams(self):
        """The recorded matplotlib rcParams."""
        return self._rcParams

    @rcParams.setter
    def rcParams(self, dct):
        """Merge `dct` into the recorded rcParams and save them."""
        self._rcParams.update(dct)
        self.io.save('rcParams', self.rcParams)

    def replot(self):
        """
        Replot using the recorded plotting functions.

        Every stored entry except ``style`` and ``rcParams`` is replayed in
        the order returned by ``self.io.read()``, with styles applied.
        """
        for key, plot_object in self.io.read().items():
            if key in ('style', 'rcParams'):
                continue
            # Keys are '<number>_<function name>' (see `_to_dict`); the
            # function name itself may contain underscores.
            attr = key.partition('_')[2]
            args = plot_object['args']
            # Positional arguments are stored under their index as a string
            # key; order them numerically so that '10' comes after '9'.
            ordered_args = [args[index] for index in sorted(args, key=int)]
            kwargs = plot_object['kwargs']
            self.apply_style(kwargs, attr)
            # Resolve the function before calling it, so that an
            # AttributeError raised while plotting is not mistaken for a
            # missing attribute (which used to trigger a second call).
            try:
                func = getattr(self._ax, attr)
            except AttributeError:
                func = getattr(self, attr)
            func(*ordered_args, **kwargs)

        plt.draw_if_interactive()

    def _to_dict(self, name, *args, **kwargs):
        """Return ``(key, plot object)`` for one call; uses a new number."""
        key = self.action_number + '_' + name
        return key, {'args': dict(enumerate(args)), 'kwargs': kwargs}

    def parse_func(self, func):
        """Create and save a plot object by parsing the plotting function."""
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Record the call as made (before styles are expanded), then
            # expand styles and forward to the real function.
            self.io.save(*self._to_dict(func.__name__, *args, **kwargs))
            self.apply_style(kwargs, func.__name__)
            return func(*args, **kwargs)
        return wrapper

    def apply_style(self, kwargs, extra_styles=None):
        """
        Expand the ``style`` entry of `kwargs` in place.

        Arguments:

        `kwargs` -- keyword arguments of a plotting call; an optional
                    ``'style'`` entry holds space-separated style names
        `extra_styles` -- style names placed before those in
                          ``kwargs['style']`` (callers pass the plotting
                          function's name)

        The ``'style'`` entry is removed, and the keyword arguments of each
        known style are merged into `kwargs` in order.  Unknown names are
        ignored.  Style values override keys already present in `kwargs`.
        """
        if extra_styles:
            kwargs['style'] = (extra_styles + ' ' +
                               kwargs.get('style', ' ')).strip()
        if 'style' not in kwargs:
            return
        for style_name in kwargs.pop('style').split(' '):
            kwargs.update(self.style.get(style_name, {}))
231