Passed
Push — main ( a79885...50e5ea )
by Douglas
02:25
created

fig_tools.FigureTools.plot1d()   A

Complexity

Conditions 1

Size

Total Lines 33
Code Lines 16

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 16
nop 8
dl 0
loc 33
rs 9.6
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent as you need more or different data.

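There are several approaches to avoid long parameter lists, for example: grouping related parameters into a single object (Introduce Parameter Object), passing an existing object instead of several of its fields (Preserve Whole Object), or letting the method obtain a value itself instead of receiving it as a parameter (Replace Parameter with Method Call).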

1
from __future__ import annotations
2
3
import logging
4
from contextlib import contextmanager
5
from copy import copy
6
from pathlib import Path
7
from typing import Callable, Generator, Iterator, Mapping, Optional, Sequence, Iterable
8
from typing import Tuple as Tup
9
from typing import Union
10
11
import matplotlib
12
import matplotlib.pyplot as plt
13
import numpy as np
14
import pandas as pd
15
from matplotlib import colors as mcolors
16
from matplotlib.axes import Axes
17
from matplotlib.figure import Figure
18
19
from pocketutils.plotting.corners import Corner
20
from pocketutils.tools.common_tools import CommonTools
21
22
# Module-wide logger, shared under the package name rather than the module name.
logger = logging.getLogger("pocketutils")

# Anything that can stand for one or more figures: a single Figure, an iterator of
# Figures, an iterator of (name, Figure) pairs, or a mapping from names to Figures.
# (Not referenced in this part of the file; presumably used by callers elsewhere.)
FigureSeqLike = Union[
    Figure,
    Iterator[Figure],
    Iterator[Tup[str, Figure]],
    Mapping[str, Figure],
]

# File extensions (without a leading dot) treated as known image formats.
# NOTE(review): not referenced in this part of the file; confirm its consumers.
KNOWN_EXTENSIONS = ["jpg", "png", "pdf", "svg", "eps", "ps"]
25
26
27
class FigureTools:
    """
    A namespace of stateless helpers (all classmethods) for creating, annotating,
    hiding, and closing matplotlib figures.
    """

    # 1 inch is defined as exactly 2.54 cm.
    _CM_PER_INCH = 2.54

    @classmethod
    def cm2in(cls, tup: Union[float, Iterable[float]]) -> Union[float, Sequence[float]]:
        """
        Converts centimeters to inches.

        Args:
            tup: A single length in centimeters, or a collection of them.
                 Whether it is treated as a collection is decided by
                 ``CommonTools.is_true_iterable``.

        Returns:
            A float for a single value; otherwise a list with one converted value per element
        """
        if CommonTools.is_true_iterable(tup):
            return [x / cls._CM_PER_INCH for x in tup]
        return float(tup) / cls._CM_PER_INCH

    @classmethod
    def in2cm(cls, tup: Union[float, Iterable[float]]) -> Union[float, Sequence[float]]:
        """
        Converts inches to centimeters.

        Args:
            tup: A single length in inches, or a collection of them.
                 Whether it is treated as a collection is decided by
                 ``CommonTools.is_true_iterable``.

        Returns:
            A float for a single value; otherwise a list with one converted value per element
        """
        if CommonTools.is_true_iterable(tup):
            return [x * cls._CM_PER_INCH for x in tup]
        return float(tup) * cls._CM_PER_INCH

    @classmethod
    def open_figs(cls) -> Sequence[Figure]:
        """
        Returns all currently open figures, in ascending order of figure number.

        Each figure is fetched with ``plt.figure(num=...)``, which also makes it the
        current figure, so the last one returned ends up current.
        """
        return [plt.figure(num=i) for i in plt.get_fignums()]

    @classmethod
    def open_fig_map(cls) -> Mapping[str, Figure]:
        """
        Returns all currently open figures as a dict mapping each figure's label
        (``Figure.get_label()``) to the figure itself.

        Labels are often empty in practice and need not be unique. When several open
        figures share a label, only the one with the highest figure number is kept.

        Returns:
            A new dict of label -> Figure
        """
        # Read labels off the existing figures rather than calling plt.figure(label=...).
        # That keyword is forwarded as a Figure property and does not look up an
        # existing figure: it creates a new one (or raises on older matplotlib versions).
        return {fig.get_label(): fig for fig in cls.open_figs()}

    @classmethod
    @contextmanager
    def clearing(cls, yes: bool = True) -> Generator[None, None, None]:
        """
        Context manager that clears and closes all figures created during its lifespan.

        When the context exits, including exit through an exception, it calls ``clf``
        and ``plt.close`` on every open figure whose number was not open on entry.
        Figures that were already open on entry are left alone.

        Args:
            yes: If False, does nothing

        Yields:
            None
        """
        # A set gives O(1) membership tests when comparing against the figures open at exit.
        before = set(plt.get_fignums())
        try:
            yield
        finally:
            if yes:
                # get_fignums() returns a fresh list, so closing figures while looping is safe.
                for num in plt.get_fignums():
                    if num not in before:
                        fig = plt.figure(num=num)
                        fig.clf()
                        plt.close(fig)

    @classmethod
    @contextmanager
    def hiding(cls, yes: bool = True) -> Generator[None, None, None]:
        """
        Context manager that hides figure display by calling ``plt.interactive(False)``.

        The previous interactive state is restored on exit, including exit through an exception.

        Args:
            yes: If False, does nothing

        Yields:
            None
        """
        was_interactive = plt.isinteractive()
        if yes:
            plt.interactive(False)
        try:
            yield
        finally:
            if yes:
                plt.interactive(was_interactive)

    @classmethod
    def plot1d(
        cls,
        values: np.ndarray,
        figsize: Optional[Tup[float, float]] = None,
        x0: Optional[float] = None,
        y0: Optional[float] = None,
        x1: Optional[float] = None,
        y1: Optional[float] = None,
        **kwargs,
    ) -> Axes:
        """
        Plots a 1D array on a new figure with a single Axes and returns the Axes.

        Args:
            values: The data, passed as the only positional argument to ``Axes.plot``,
                    so it is plotted against its indices
            figsize: ``(width, height)`` of the new figure in inches; None uses the rc default
            x0: Lower x limit; None keeps the automatic limit
            y0: Lower y limit; None keeps the automatic limit
            x1: Upper x limit; None keeps the automatic limit
            y1: Upper y limit; None keeps the automatic limit
            **kwargs: Passed to ``Axes.plot``

        Returns:
            The newly created Axes
        """
        figure = plt.figure(figsize=figsize)
        ax = figure.add_subplot(1, 1, 1)
        ax.plot(values, **kwargs)
        # set_xlim/set_ylim leave a bound unchanged when it is None.
        ax.set_xlim(x0, x1)
        ax.set_ylim(y0, y1)
        return ax

    @classmethod
    def despine(cls, ax: Axes) -> Axes:
        """
        Removes all spines, ticks, and tick labels from an Axes, in place.

        Args:
            ax: The Axes to modify

        Returns:
            The same Axes, for chaining
        """
        # Clear tick positions before tick labels so the empty label lists match
        # the (now empty, fixed) tick lists.
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        for spine in ax.spines.values():
            spine.set_visible(False)
        return ax

    @classmethod
    def clear(cls) -> int:
        """
        Clears and closes all open matplotlib figures, dropping pyplot's references to them.
        It exists because the right incantation is hard to remember.
        Logs an error if some figures remain open afterward.

        Returns:
            The number of figures that were open before the call
        """
        nums = plt.get_fignums()
        n = len(nums)
        # Clear every open figure, not only the current one. plt.clf() would touch only
        # the current figure, and would create a new figure if none were open.
        for num in nums:
            plt.figure(num=num).clf()
        plt.close("all")
        m = len(plt.get_fignums())
        if m == 0:
            logger.debug("Cleared %s figure%s.", n, "" if n == 1 else "s")
        else:
            logger.error("Failed to close figures. Cleared %s; %s remain.", n - m, m)
        return n

    @classmethod
    def font_paths(cls) -> Sequence[Path]:
        """
        Returns the paths of system fonts found by
        ``matplotlib.font_manager.findSystemFonts`` with its default search settings.
        """
        # Import explicitly: `matplotlib.font_manager` is only available as an attribute
        # of the `matplotlib` package after that submodule has been imported.
        from matplotlib import font_manager

        return [Path(p) for p in font_manager.findSystemFonts(fontpaths=None)]

    @classmethod
    def text_matrix(
        cls,
        ax: Axes,
        data: pd.DataFrame,
        color_fn: Optional[Callable[[str], str]] = None,
        adjust_x: float = 0,
        adjust_y: float = 0,
        **kwargs,
    ) -> None:
        """
        Adds a matrix of text, one ``ax.text`` call per non-empty cell of ``data``.

        Args:
            ax: The Axes to draw on
            data: The matrix of values. Each value is converted with ``str``; cells
                  whose string form is empty are skipped.
            color_fn: Optional function mapping each (pre-conversion) value to a color
            adjust_x: Added to every x coordinate
            adjust_y: Added to every y coordinate
            **kwargs: Passed to ``ax.text`` (must not include ``color``)
        """
        n_rows, n_cols = data.shape
        for r in range(n_rows):
            for c in range(n_cols):
                value = data.iat[r, c]
                text = str(value)
                if text == "":
                    continue
                # NOTE(review): the row position is used as x and the column position as y.
                # That is transposed relative to how imshow lays out a matrix
                # (row -> y, column -> x). Kept for backward compatibility; confirm intent.
                ax.text(
                    r + adjust_x,
                    c + adjust_y,
                    text,
                    color=None if color_fn is None else color_fn(value),
                    **kwargs,
                )

    @classmethod
    def add_note_01_coords(cls, ax: Axes, x: float, y: float, s: str, **kwargs) -> Axes:
        """
        Adds text with a fully transparent bounding box, positioned in axes coordinates.
        ``x`` and ``y`` run from 0 to 1 across the Axes (``ax.transAxes``).

        Args:
            ax: The Axes to draw on
            x: Horizontal position in axes coordinates
            y: Vertical position in axes coordinates
            s: The text
            **kwargs: Passed to ``ax.text`` (must not include ``transform``)

        Returns:
            ``ax``, for chaining
        """
        t = ax.text(x, y, s=s, transform=ax.transAxes, **kwargs)
        t.set_bbox(dict(alpha=0.0))
        return ax

    @classmethod
    def add_note_data_coords(cls, ax: Axes, x: float, y: float, s: str, **kwargs) -> Axes:
        """
        Adds text with a fully transparent bounding box, positioned in data coordinates.

        Args:
            ax: The Axes to draw on
            x: Horizontal position in data coordinates
            y: Vertical position in data coordinates
            s: The text
            **kwargs: Passed to ``ax.text``

        Returns:
            ``ax``, for chaining
        """
        t = ax.text(x, y, s=s, **kwargs)
        t.set_bbox(dict(alpha=0.0))
        return ax

    @classmethod
    def stamp(cls, ax: Axes, text: str, corner: Corner, **kwargs) -> Axes:
        """
        Adds a "stamp" (text with a transparent bounding box) at a corner of the Axes.

        Example:
            Stamping::

                FigureTools.stamp(ax, "hello", corner)  # corner: a Corner value

        Args:
            ax: The Axes to draw on
            text: The text to draw
            corner: Where to place the text; ``corner.params()`` supplies the
                    positioning keyword arguments for ``ax.text``
            **kwargs: Passed to ``ax.text``

        Returns:
            ``ax``, for chaining
        """
        return cls._text(ax, text, corner, **kwargs)

    @classmethod
    def _text(cls, ax: Axes, text: str, corner: Corner, **kwargs) -> Axes:
        """
        Draws ``text`` in axes coordinates with a transparent bounding box. The position
        keywords come from ``corner.params()``; presumably x, y and alignment, so confirm
        against ``Corner``.

        Args:
            ax: The Axes to draw on
            text: The text to draw
            corner: Supplies positioning keyword arguments via ``params()``
            **kwargs: Passed to ``ax.text`` (must not repeat keys from ``corner.params()``
                      or ``transform``)

        Returns:
            ``ax``, for chaining
        """
        t = ax.text(s=text, **corner.params(), transform=ax.transAxes, **kwargs)
        t.set_bbox(dict(alpha=0.0))
        return ax

    @classmethod
    def plot_palette(cls, values: Union[Sequence[str], str]) -> Figure:
        """
        Plots a color palette as one row of equally wide swatches on a new 8x2-inch
        figure with no spines or ticks.

        Args:
            values: A single color string (e.g. ``"#ff0000"``) or a sequence of color strings

        Returns:
            The new Figure

        Raises:
            ValueError: If ``values`` is empty (checked before any figure is created)
        """
        # For a lone color string, len() would give its character count rather than 1 color.
        colors = [values] if isinstance(values, str) else list(values)
        if len(colors) == 0:
            raise ValueError("At least one color is required")
        n = len(colors)
        figure = plt.figure(figsize=(8.0, 2.0))
        ax = figure.add_subplot(1, 1, 1)
        # One cell per color: cell i shows value i, which the listed colormap maps to colors[i].
        ax.imshow(
            np.arange(n).reshape(1, n),
            cmap=mcolors.ListedColormap(colors),
            interpolation="none",
            aspect="auto",
        )
        cls.despine(ax)
        return figure
334
335
336
# Public API of this module: star-imports expose only the FigureTools namespace class.
__all__ = [
    "FigureTools",
]
337