Passed
Push — main ( 37036d...9cce85 )
by Douglas
04:12
created

VizResources._get_named_cmaps()   A

Complexity

Conditions 3

Size

Total Lines 14
Code Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 13
nop 1
dl 0
loc 14
rs 9.75
c 0
b 0
f 0
1
"""
0 ignored issues
show
Documentation introduced by
Empty module docstring
Loading history...
2
3
"""
4
from __future__ import annotations
5
6
import enum
7
from pathlib import Path
8
from typing import Union, Optional, Tuple, Mapping, Any, Generator, Sequence, Set
9
10
import numpy as np
0 ignored issues
show
introduced by
Unable to import 'numpy'
Loading history...
11
import pandas as pd
0 ignored issues
show
introduced by
Unable to import 'pandas'
Loading history...
12
from matplotlib import pyplot as plt
0 ignored issues
show
introduced by
Unable to import 'matplotlib'
Loading history...
13
from matplotlib.colors import LinearSegmentedColormap, Colormap, ListedColormap, to_hex
0 ignored issues
show
introduced by
Unable to import 'matplotlib.colors'
Loading history...
14
from matplotlib.figure import Figure
0 ignored issues
show
introduced by
Unable to import 'matplotlib.figure'
Loading history...
15
from pocketutils.core.dot_dict import NestedDotDict
0 ignored issues
show
introduced by
Unable to import 'pocketutils.core.dot_dict'
Loading history...
16
from pocketutils.tools.common_tools import CommonTools
0 ignored issues
show
introduced by
Unable to import 'pocketutils.tools.common_tools'
Loading history...
17
18
# noinspection PyProtectedMember
19
from seaborn.palettes import SEABORN_PALETTES
0 ignored issues
show
introduced by
Unable to import 'seaborn.palettes'
Loading history...
20
from typeddfs import TypedDfs
0 ignored issues
show
introduced by
Unable to import 'typeddfs'
Loading history...
21
22
from mandos.model.utils.setup import logger
0 ignored issues
show
Unused Code introduced by
Unused logger imported from mandos.model.utils.setup
Loading history...
23
from mandos.model.utils import CleverEnum
24
from mandos.model.utils.misc_utils import MiscUtils
25
from mandos.model.utils.resources import MandosResources
26
27
try:
28
    import seaborn as sns
29
    from matplotlib.axes import Axes
0 ignored issues
show
introduced by
Imports from package matplotlib are not grouped
Loading history...
30
    from matplotlib.figure import Figure
0 ignored issues
show
Unused Code introduced by
The import Figure was already done on line 14. You should be able to
remove this line.
Loading history...
31
except ImportError:
32
    sns = None
33
    Axes = None
34
    Figure = None
35
36
37
class DataType(CleverEnum):
    """
    The kind of values a plotted column holds; used to pick a default palette.

    ``MandosPlotStyling.guess_data_type`` chooses between these:
    ``qualitative`` for values that are not all numeric,
    ``divergent`` for numeric values with both positive and negative entries,
    and ``sequential`` for other numeric values.
    """

    # Declaration order is significant: values are generated by enum.auto()
    # through CleverEnum, whose value generation is not shown here.
    qualitative = enum.auto()
    sequential = enum.auto()
    divergent = enum.auto()
41
42
43
class VizResources:
    """
    Visualization settings loaded from the packaged ``viz`` JSON resources.

    Attributes:
        override_settings: Contents of ``viz/style_override.json``
        dims: Contents of ``viz/page_dims.json``
        named_palettes: The ``named`` section of ``viz/palettes.json``, keyed by palette name
        default_palettes: The ``defaults`` section of ``viz/palettes.json``
        named_cmaps: A matplotlib colormap for every entry of ``named_palettes``
    """

    def __init__(self):
        self.override_settings = MandosResources.json_dict("viz", "style_override.json")
        self.dims = MandosResources.json_dict("viz", "page_dims.json")
        palettes = MandosResources.json_dict("viz", "palettes.json")
        self.named_palettes = palettes["named"]
        self.default_palettes = palettes["defaults"]
        self.named_cmaps = self._get_named_cmaps()

    def _get_named_cmaps(self) -> Mapping[str, Colormap]:
        """
        Build one colormap per entry in ``named_palettes``.

        Each entry must provide ``sequence`` (a list of color strings) and
        ``categorical`` (a bool). Categorical entries become a ``ListedColormap``
        of exactly those colors. The others become a ``LinearSegmentedColormap``
        interpolated through the sequence. Optional ``nan`` / ``under`` / ``over``
        colors are applied as the bad / under / over extremes.

        Returns:
            A dict mapping each palette name to its colormap
        """
        cmaps = {}
        for name, raw_spec in self.named_palettes.items():
            spec = NestedDotDict(raw_spec)
            colors = [to_hex(c) for c in spec.req_list_as("sequence", str)]
            if spec.req_as("categorical", bool):
                cmaps[name] = ListedColormap(colors, name=name)
            else:
                linear = LinearSegmentedColormap.from_list(name, colors)
                # set_extremes leaves any extreme passed as None unchanged
                linear.set_extremes(
                    bad=spec.get("nan"), under=spec.get("under"), over=spec.get("over")
                )
                # continuous palettes must be registered too, or lookups by name miss them
                cmaps[name] = linear
        return cmaps
66
67
68
# Module-level instance created at import time: VizResources.__init__ reads the
# viz JSON resources (style_override.json, page_dims.json, palettes.json) and
# builds the named colormaps once for the whole module.
VIZ_RESOURCES = VizResources()
69
70
71
class MandosPlotStyling:
    """
    Class-level helpers for choosing palettes, figure sizes, and matplotlib rc settings.
    """

    @classmethod
    def list_named_palettes(cls) -> Set[str]:
        """
        Return every palette name that can be resolved.

        This is the union of the palettes in ``palettes.json``, seaborn's named
        palettes, and the colormaps registered with matplotlib.
        """
        return {
            *VIZ_RESOURCES.named_palettes.keys(),
            *SEABORN_PALETTES,
            *plt.colormaps(),
        }

    @classmethod
    def choose_palette(
        cls,
        data: pd.DataFrame,
        col: Optional[str],
        palette: Optional[str],
    ) -> Union[None, Colormap, Mapping[str, str]]:
        """
        Choose a palette for coloring ``data`` by column ``col``.

        Args:
            data: The data being plotted
            col: The column to color by; if None, no palette is needed
            palette: A palette name (see :meth:`list_named_palettes`), or None
                     for the configured default for the column's data type

        Returns:
            None if ``col`` is None.
            For qualitative data, a dict mapping each unique value of ``col`` to a hex color.
            Otherwise, the colormap itself.

        Raises:
            TypeError: If the data are qualitative but the palette is not a ListedColormap
            ValueError: If a qualitative palette has fewer colors than ``col`` has unique values
        """
        if col is None:
            return None
        values = data[col]
        unique = values.unique()
        # infer the type from the column being colored, not the whole frame
        dtype = cls.guess_data_type(values)
        # resolve an explicit name (or the default, when None) to an actual colormap
        cmap = cls.get_palette(palette, dtype)
        if dtype is DataType.qualitative:
            if not isinstance(cmap, ListedColormap):
                raise TypeError(f"{cmap.name} is not a valid choice for {dtype}")
            if len(unique) > len(cmap.colors):
                raise ValueError(
                    f"Palette (N={len(cmap.colors)}) too small for {len(unique)} items"
                )
            # zip stops at the shorter input, so surplus palette colors are simply unused
            return dict(zip(unique, map(to_hex, cmap.colors)))
        return cmap

    @classmethod
    def get_palette(cls, name: Optional[str], data_type: Union[DataType, str]) -> Colormap:
        """
        Resolve a palette name to a colormap.

        Args:
            name: A palette name, or None for the configured default for ``data_type``
            data_type: The kind of data, as a DataType or its name

        Returns:
            A colormap. Palettes from ``palettes.json`` take precedence over
            seaborn/matplotlib palettes with the same name.
        """
        data_type = DataType.of(data_type)
        if name is None:
            name = VIZ_RESOURCES.default_palettes[data_type.name]
        if name in VIZ_RESOURCES.named_cmaps:
            return VIZ_RESOURCES.named_cmaps[name]
        resolved = sns.color_palette(name, as_cmap=True)
        # seaborn can hand back a plain list of colors (e.g. for its own
        # qualitative palettes such as "deep") even with as_cmap=True
        if not isinstance(resolved, Colormap):
            resolved = ListedColormap(list(resolved), name=name)
        return resolved

    @classmethod
    def guess_data_type(cls, data: Sequence[Union[str, float]]) -> DataType:
        """
        Guess whether values are qualitative, sequential, or divergent.

        Values that cannot all be converted to float are qualitative. Numeric
        values are divergent when their finite, nonzero entries include both
        signs; otherwise they are sequential.
        """
        numerical = cls._to_numerical(data)
        if numerical is None:
            return DataType.qualitative
        if cls._are_floats_divergent(numerical):
            return DataType.divergent
        return DataType.sequential

    @classmethod
    def _to_colors(cls, data: Sequence[Union[float, str]]) -> Optional[Sequence[str]]:
        """
        Return ``data`` as hex colors if every value is a color string; otherwise None.
        """
        if not all(isinstance(d, str) for d in data):
            return None
        try:
            return [to_hex(c) for c in data]
        except ValueError:
            return None

    @classmethod
    def _to_numerical(cls, data: Sequence[Union[str, float]]) -> Optional[Sequence[float]]:
        """
        Convert every value to float; return None if any value is not numeric.
        """
        try:
            return [float(d) for d in data]
        except (TypeError, ValueError):
            # TypeError covers values such as None that float() rejects outright
            return None

    @classmethod
    def _are_floats_divergent(cls, data: Sequence[float]) -> bool:
        """
        Return True when the finite, nonzero values include both positive and negative numbers.
        """
        signs = {np.sign(d) for d in data if d != 0 and not np.isnan(d) and not np.isinf(d)}
        return len(signs) == 2

    @classmethod
    def context(
        cls, style: Union[None, str, Path], kwargs: Optional[Mapping[str, Any]]
    ) -> Generator[None, None, None]:
        """
        Apply the project's rc overrides, plus ``kwargs``, within ``plt.rc_context``.

        Starts from the ``allow_change`` section of ``style_override.json`` and
        updates it with ``kwargs`` when given. ``style`` is passed as rc_context's
        second (``fname``) argument. The settings are active while this generator
        is suspended at its ``yield``. Intended to be called once, at startup.
        """
        # NOTE(review): this is a bare generator (no @contextmanager), so it cannot be
        # used directly in a ``with`` statement; confirm how callers drive it.
        # NOTE(review): rc_context's second parameter is an rc *file*; confirm callers
        # pass a path rather than a matplotlib style name.
        new_kwargs = dict(VIZ_RESOURCES.override_settings["allow_change"])
        if kwargs is not None:
            new_kwargs.update(kwargs)
        with plt.rc_context(new_kwargs, style):
            yield

    @classmethod
    def fig_width_and_height(cls, size: Optional[str]) -> Tuple[float, float]:
        """
        Parse a figure size such as ``"6.5 by 4"`` or ``"6.5 × 4"`` into inches.

        Each part may be a number of inches, a key of the ``widths`` / ``heights``
        tables in ``page_dims.json``, or a length quantity accepted by
        ``MiscUtils.canonicalize_quantity``. A missing or empty part (or a None
        ``size``) falls back to matplotlib's current ``figure.figsize``.

        Raises:
            ValueError: If a part cannot be interpreted
        """
        default_inch = plt.rcParams["figure.figsize"]
        if size is None:
            return default_inch[0], default_inch[1]
        parts = [p.strip() for p in size.replace(" × ", " by ").split(" by ")]
        width_str = parts[0] if len(parts) > 0 else None
        height_str = parts[1] if len(parts) > 1 else None
        try:
            width = cls._to_inch(width_str, VIZ_RESOURCES.dims["widths"], default_inch[0])
            height = cls._to_inch(height_str, VIZ_RESOURCES.dims["heights"], default_inch[1])
        except ValueError as err:
            raise ValueError(f"Strange --size format in '{size}'") from err
        return width, height

    @classmethod
    def _to_inch(
        cls, s: Optional[str], standards: Mapping[str, float], default_inch: float
    ) -> float:
        """
        Convert one dimension string to inches.

        Args:
            s: A number of inches, a key of ``standards``, or a length quantity;
               None or empty means ``default_inch``
            standards: Named sizes mapped to values passed to ``canonicalize_quantity``
            default_inch: The value returned when ``s`` is None or empty
        """
        if not s:
            return default_inch
        try:
            return float(s)
        except ValueError:
            pass
        quantity = standards.get(s, s)
        return MiscUtils.canonicalize_quantity(quantity, "[length]").to("inch").magnitude
186
187
188
class MandosPlotUtils:
    """
    Helpers for writing finished figures to disk.
    """

    @classmethod
    def save(cls, figure: Figure, path: Path) -> None:
        """
        Save ``figure`` to ``path``, creating missing parent directories, then clear it.

        The figure is cleared (its contents removed) but not closed.
        """
        destination_dir = path.parent
        destination_dir.mkdir(exist_ok=True, parents=True)
        destination = str(path)
        figure.savefig(destination)
        figure.clear()
194
195
196
def _string_keyed_style_df(name: str, *columns: str):
    """
    Build a typeddfs DataFrame type called ``name`` that requires ``columns`` with dtype str.

    Every style DataFrame here is built by the same chain:
    ``require(..., dtype=str)``, then ``strict(cols=False)``, then ``secure()``.
    See the typeddfs documentation for the meaning of the last two calls.
    """
    builder = TypedDfs.typed(name).require(*columns, dtype=str)
    builder = builder.strict(cols=False).secure()
    return builder.build()


# Per-compound styling, keyed by InChI Key
CompoundStyleDf = _string_keyed_style_df("CompoundStyleDf", "inchikey")

# Styling keyed by (predicate, object) pairs
PredicateObjectStyleDf = _string_keyed_style_df("PredicateObjectStyleDf", "predicate", "object")

# Styling keyed by (phi, psi) pairs
PhiPsiStyleDf = _string_keyed_style_df("PhiPsiStyleDf", "phi", "psi")
210
211
212
# Public API of this module. PhiPsiStyleDf is exported alongside its sibling
# style DataFrames, and DataType is exported because MandosPlotStyling.get_palette
# accepts it as an argument.
__all__ = [
    "sns",
    "plt",
    "Figure",
    "Axes",
    "DataType",
    "MandosPlotStyling",
    "MandosPlotUtils",
    "CompoundStyleDf",
    "PredicateObjectStyleDf",
    "PhiPsiStyleDf",
    "VIZ_RESOURCES",
]
223