Completed
Pull Request — master (#1)
by Andy
01:12
created

AnniesLasso.is_structured_label_vector()   D

Complexity

Conditions 10

Size

Total Lines 22

Duplication

Lines 0
Ratio 0 %
Metric Value
dl 0
loc 22
rs 4.5957
cc 10

How to fix   Complexity   

Complexity

Complex classes like AnniesLasso.is_structured_label_vector() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
"""
5
General utility functions.
6
"""
7
8
# NOTE(review): "label_vector" is not defined anywhere in the visible part of
# this module. If it is not defined elsewhere in the file, a
# `from <module> import *` will raise AttributeError -- confirm, and either
# define it or export the real public helpers instead.
__all__ = ["label_vector", "progressbar", "short_hash"]
9
10
import logging
11
import numpy as np
12
import sys
13
from time import time
14
from collections import (Counter, Iterable, OrderedDict)
15
from hashlib import md5
16
from itertools import combinations_with_replacement
17
from six import string_types
18
19
logger = logging.getLogger(__name__)
20
21
22
def short_hash(contents):
    """
    Return a short hash string of some iterable content.

    :param contents:
        The contents to calculate a hash for. A non-iterable value is treated
        as a single item. Note that a string is itself iterable, so each of
        its characters is hashed separately.

    :returns:
        A concatenated string of 10-character length hashes for all items in the
        contents provided (an empty string if there are no items).
    """

    # The `collections.Iterable` alias was removed in Python 3.10; resolve the
    # ABC from its current home and only fall back for Python 2.
    try:
        from collections.abc import Iterable
    except ImportError:
        from collections import Iterable

    if not isinstance(contents, Iterable):
        contents = [contents]

    # Each item is hashed through its `str` representation, and only the first
    # ten hexadecimal characters of every digest are kept.
    return "".join(
        [md5(str(item).encode("utf-8")).hexdigest()[:10] for item in contents])
36
37
38
def is_structured_label_vector(label_vector):
    """
    Return whether the provided label vector is structured correctly.

    A structured label vector is a non-empty list (or tuple) of descriptors.
    Each descriptor is a list (or tuple) of terms, and each term is a
    two-element list (or tuple) whose last element is an `int` or a `float`.
    At least one term must be present across all descriptors.

    :param label_vector:
        The object to check.

    :returns:
        `True` if the label vector has the structure described above,
        otherwise `False`.
    """

    sequence_types = (list, tuple)

    if not isinstance(label_vector, sequence_types) or not len(label_vector):
        return False

    def _is_term(candidate):
        # A term is a two-item sequence ending in a numeric power.
        return isinstance(candidate, sequence_types) \
            and len(candidate) == 2 \
            and isinstance(candidate[-1], (int, float))

    total_terms = 0
    for descriptor in label_vector:
        if not isinstance(descriptor, sequence_types):
            return False
        if not all(_is_term(candidate) for candidate in descriptor):
            return False
        total_terms += len(descriptor)

    # Descriptors may individually be empty, but not all of them.
    return total_terms > 0
60
61
62
def parse_label_vector(label_vector_description, columns=None, **kwargs):
    """
    Return a structured form of a label vector from unstructured,
    human-readable input.

    :param label_vector_description:
        A human-readable or structured form of a label vector. Input that is
        already structured is returned unchanged.

    :type label_vector_description:
        str or list

    :param columns: [optional]
        If `columns` are provided, instead of text columns being provided as the
        output parameter, the corresponding index location in `columns` will be
        given.

    :returns:
        A structured form of the label vector as a multi-level list. Powers
        that were written explicitly are returned as floats, implicit powers
        are the integer 1, and terms with a total power of zero are dropped.

    :raises ValueError:
        If a power is non-finite or cannot be parsed as a number, if a label
        is not found in `columns`, or if no valid terms remain.

    :Example:

    >>> parse_label_vector("Teff^4 + logg*Teff^3 + feh + feh^0*Teff")
    [
        [
            ("Teff", 4.0),
        ],
        [
            ("logg", 1),
            ("Teff", 3.0)
        ],
        [
            ("feh", 1),
        ],
        [
            ("Teff", 1)
        ]
    ]
    """

    if is_structured_label_vector(label_vector_description):
        return label_vector_description

    # Allow for custom characters, but don't advertise it.
    # (Astronomers have bad enough habits already.)
    sep = kwargs.get("sep", "+")
    mul = kwargs.get("mul", "*")
    pow_symbol = kwargs.get("pow", "^")  # Named to avoid shadowing `pow`.

    if isinstance(label_vector_description, string_types):
        label_vector_description = label_vector_description.split(sep)
    label_vector_description = map(str.strip, label_vector_description)

    # Materialise the columns once, so that one-shot iterables work and the
    # list is not rebuilt for every label that is looked up.
    column_names = None if columns is None else list(columns)

    def get_label(token):
        # The label (or its index in `columns`) is whatever precedes the power.
        name = token.split(pow_symbol)[0].strip()
        return name if column_names is None else column_names.index(name)

    def get_power(token):
        # An omitted power is an implicit power of one.
        if pow_symbol in token:
            return float(token.split(pow_symbol)[1].strip())
        return 1

    label_vector = []
    for item in label_vector_description:

        term = OrderedDict()
        for token in item.split(mul):
            label = get_label(token)
            # Sum repeat term powers.
            term[label] = term.get(label, 0) + get_power(token)

        # Prevent uses of x^0 etc clogging up the label vector.
        valid_terms = [(l, o) for l, o in term.items() if o != 0]
        if not np.all(np.isfinite([o for l, o in valid_terms])):
            raise ValueError("non-finite power provided")

        if len(valid_terms) > 0:
            label_vector.append(valid_terms)

    # Only non-empty descriptors are ever appended, so an empty list means
    # that nothing usable was parsed.
    if not label_vector:
        raise ValueError("no valid terms provided")

    return label_vector
145
146
147
def human_readable_label_vector(label_vector, **kwargs):
    """
    Return a human-readable form of the label vector provided.

    :param label_vector:
        A structured label vector (see `is_structured_label_vector`).

    :returns:
        A string of the form "1 + A + (A * B^2) + ...". The leading "1" is
        always present, terms raised to the power zero are omitted, and
        products of more than one label are wrapped in parentheses.

    :raises TypeError:
        If the label vector is not structured correctly.
    """

    if not is_structured_label_vector(label_vector):
        raise TypeError("invalid label vector provided")

    theta = ["1"]
    for descriptor in label_vector:
        cross_terms = []
        for label, order in descriptor:
            if order == 0: continue
            # Show every exponent other than one, so that fractional and
            # negative powers are not silently displayed as the bare label.
            exponent = "" if order == 1 else "^{}".format(order)
            cross_terms.append(str(label) + exponent)

        # A descriptor made only of zero-power terms contributes nothing; skip
        # it rather than emit an empty entry ("1 +  + ...").
        if not cross_terms: continue

        term = " * ".join(cross_terms)
        theta.append("({0})".format(term) if len(cross_terms) > 1 else term)

    return " + ".join(theta)
168
        
169
170
def progressbar(iterable, message=None, size=100):
    """
    A progressbar.

    This is a generator: it yields every item of `iterable` in order and
    writes a text progress bar to standard output as items are consumed.

    :param iterable:
        Some iterable to show progress for. Iterables without a length (e.g.,
        generators) are accepted; they are read into a list first when a
        progressbar is shown, because the total count is needed.

    :param message: [optional]
        A string message to show as the progressbar header. It is emitted
        through the module logger at the INFO level.

    :param size: [optional]
        If the size given is zero or negative, then no progressbar will be
        shown. Only the sign of this value is currently used; it does not
        change the width of the bar.
    """

    show = size > 0

    # Preparerise.
    t_init = time()
    try:
        count = len(iterable)
    except TypeError:
        # No length available. The total is only needed for drawing, so only
        # pay for materialising the items when something will be drawn.
        if show:
            iterable = list(iterable)
            count = len(iterable)
        else:
            count = 0

    def _update(i, t=None):
        if not show: return
        increment = max(1, int(count / 100))
        if i % increment == 0 or i in (0, count):
            # An empty iterable is complete by definition; this also avoids
            # dividing by zero.
            percent = 100. * i/count if count else 100.
            sys.stdout.write("\r[{done}{not_done}] {percent:3.0f}%{t}".format(
                done="=" * int(i/increment),
                not_done=" " * int((count - i)/increment),
                percent=percent,
                t="" if t is None else " ({0:.0f}s)".format(t-t_init)))
            sys.stdout.flush()

    # Initialise.
    if show:
        logger.info((message or "").rstrip())
        sys.stdout.flush()

    # Updaterise.
    for i, item in enumerate(iterable):
        yield item
        _update(i)

    # Finalise.
    if show:
        _update(count, time())
        sys.stdout.write("\r\n")
        sys.stdout.flush()
214
215
216
def build_label_vector(labels, order, cross_term_order=0, **kwargs):
    """
    Build a label vector description.

    :param labels:
        The labels to use in describing the label vector.

    :param order:
        The maximum order of the terms (e.g., order 3 implies A^3 is a term).

    :param cross_term_order: [optional]
        The maximum order of the cross-terms (e.g., cross_term_order 2 implies
        A^2*B is a term).

    :param sep: [optional]
        The operator placed (surrounded by single spaces) between the terms in
        the description of the label vector.

    :param mul: [optional]
        The operator to use to represent multiplication in the description of
        the label vector.

    :param pow: [optional]
        The operator to use to represent exponents in the description of the
        label vector.

    :returns:
        A human-readable form of the label vector.
    """

    separator = kwargs.pop("sep", "+")
    times = kwargs.pop("mul", "*")
    caret = kwargs.pop("pow", "^")

    highest_degree = max(order, 1 + cross_term_order)

    terms = []
    for degree in range(1, 1 + highest_degree):
        for combination in combinations_with_replacement(labels, degree):
            counts = Counter(combination)

            # Python 2 and 3 order the counted labels differently, so always
            # work through them in sorted order.
            powers = OrderedDict()
            for name in sorted(counts.keys()):
                powers[name] = counts[name]

            # Single-label terms are limited by `order`; products of several
            # labels are limited by `cross_term_order`.
            limit = order if len(powers) == 1 else cross_term_order
            if limit < max(powers.values()):
                continue

            factors = []
            for name, exponent in powers.items():
                pieces = [name, str(exponent)] if exponent > 1 else [name]
                factors.append(caret.join(pieces))
            terms.append(times.join(map(str, factors)))

    return " {} ".format(separator).join(terms)
258