1
|
|
|
#!/usr/bin/env python |
2
|
|
|
# -*- coding: utf-8 -*- |
3
|
|
|
|
4
|
|
|
""" |
5
|
|
|
General utility functions. |
6
|
|
|
""" |
7
|
|
|
|
8
|
|
|
# Names exported by `from <module> import *`.
# NOTE(review): "label_vector" is not defined anywhere in the visible part of
# this module -- confirm it exists elsewhere, otherwise a star-import fails.
__all__ = ["label_vector", "progressbar", "short_hash"]
9
|
|
|
|
10
|
|
|
import logging |
11
|
|
|
import numpy as np |
12
|
|
|
import sys |
13
|
|
|
from time import time |
14
|
|
|
from collections import (Counter, Iterable, OrderedDict) |
15
|
|
|
from hashlib import md5 |
16
|
|
|
from itertools import combinations_with_replacement |
17
|
|
|
from six import string_types |
18
|
|
|
|
19
|
|
|
# Module-level logger; `progressbar` logs its header message through it.
logger = logging.getLogger(__name__)
20
|
|
|
|
21
|
|
|
|
22
|
|
|
def short_hash(contents):
    """
    Build a compact hash string from one item or from an iterable of items.

    :param contents:
        The item(s) to hash. A value that is not iterable is treated as a
        single item. Any iterable contributes one hash per element; note that
        a string is itself iterable, so it is hashed character by character.

    :returns:
        The first 10 hexadecimal characters of the MD5 digest of each item's
        UTF-8 encoded string form, concatenated in iteration order.
    """
    entries = contents if isinstance(contents, Iterable) else [contents]

    pieces = []
    for entry in entries:
        digest = md5(str(entry).encode("utf-8")).hexdigest()
        pieces.append(str(digest)[:10])
    return "".join(pieces)
36
|
|
|
|
37
|
|
|
|
38
|
|
|
def is_structured_label_vector(label_vector):
    """
    Check whether a label vector already has the structured, nested layout.

    A structured label vector is a list/tuple of descriptors, where each
    descriptor is a list/tuple of two-item list/tuple terms whose last item
    (the power) is an `int` or `float`. At least one descriptor must contain
    at least one term.

    :param label_vector:
        The object to inspect.

    :returns:
        `True` if the object follows the structured layout, otherwise `False`.
    """

    sequence_types = (list, tuple)

    def _is_term(candidate):
        # A term is a (label, power) pair with a numeric power.
        return isinstance(candidate, sequence_types) \
            and len(candidate) == 2 \
            and isinstance(candidate[-1], (int, float))

    if not isinstance(label_vector, sequence_types):
        return False

    for group in label_vector:
        if not isinstance(group, sequence_types):
            return False
        if not all(_is_term(entry) for entry in group):
            return False

    # Reject an empty vector, or one made up only of empty descriptors.
    return any(len(group) > 0 for group in label_vector)
60
|
|
|
|
61
|
|
|
|
62
|
|
|
def parse_label_vector(label_vector_description, columns=None, **kwargs):
    """
    Return a structured form of a label vector from unstructured,
    human-readable input.

    :param label_vector_description:
        A human-readable or structured form of a label vector. A label vector
        that is already structured is returned unchanged.

    :type label_vector_description:
        str or list

    :param columns: [optional]
        If `columns` are provided, instead of text columns being provided as
        the output parameter, the corresponding index location in `columns`
        will be given.

    :returns:
        A structured form of the label vector as a multi-level list: one list
        per additive term, each holding `(label, power)` tuples. A power that
        is written explicitly is returned as a float; an omitted power is the
        integer 1. Repeated labels within a term have their powers summed,
        and labels whose total power is zero are dropped.

    :raises TypeError:
        If an item of an unstructured description is not a string.

    :raises ValueError:
        If a power cannot be parsed or is not finite, if a label is not
        present in `columns`, or if no valid terms remain.

    :Example:

    >>> parse_label_vector("Teff^4 + logg*Teff^3 + feh + feh^0*Teff")
    [
        [
            ("Teff", 4.0),
        ],
        [
            ("logg", 1),
            ("Teff", 3.0)
        ],
        [
            ("feh", 1),
        ],
        [
            ("Teff", 1)
        ]
    ]
    """

    if is_structured_label_vector(label_vector_description):
        return label_vector_description

    # Allow for custom characters, but don't advertise it.
    # (Astronomers have bad enough habits already.)
    sep = kwargs.get("sep", "+")
    mul = kwargs.get("mul", "*")
    pow_char = kwargs.get("pow", "^")

    if isinstance(label_vector_description, string_types):
        label_vector_description = label_vector_description.split(sep)

    # Strip through the bound method rather than `str.strip`, so that every
    # string type admitted by `string_types` (e.g. Python 2 unicode) works.
    descriptions = []
    for item in label_vector_description:
        if not isinstance(item, string_types):
            raise TypeError("label vector description items must be strings")
        descriptions.append(item.strip())

    # Materialise the columns once so lookups do not rebuild the list (and so
    # a one-shot iterable of columns is not exhausted by the first lookup).
    column_names = None if columns is None else list(columns)

    def get_label(token):
        # The label is whatever precedes the power character.
        name = token.split(pow_char)[0].strip()
        return name if column_names is None else column_names.index(name)

    def get_power(token):
        # No explicit power means the label is raised to the first power.
        if pow_char not in token:
            return 1
        return float(token.split(pow_char)[1].strip())

    label_vector = []
    for description in descriptions:

        term = OrderedDict()
        for token in description.split(mul):
            label = get_label(token)
            # Sum repeat term powers.
            term[label] = term.get(label, 0) + get_power(token)

        # Prevent uses of x^0 etc clogging up the label vector.
        valid_terms = [(l, o) for l, o in term.items() if o != 0]
        if not np.all(np.isfinite([o for l, o in valid_terms])):
            raise ValueError("non-finite power provided")

        if len(valid_terms) > 0:
            label_vector.append(valid_terms)

    # Only non-empty terms are appended, so an empty list means nothing valid.
    if not label_vector:
        raise ValueError("no valid terms provided")

    return label_vector
145
|
|
|
|
146
|
|
|
|
147
|
|
|
def human_readable_label_vector(label_vector, **kwargs):
    """
    Return a human-readable form of the label vector provided.

    :param label_vector:
        A structured label vector: a list of descriptors, each a list of
        `(label, power)` terms.

    :param kwargs: [optional]
        Accepted for call compatibility; no keyword is read by this function.

    :returns:
        A string that starts with the constant term "1", followed by one entry
        per descriptor joined by " + ". Terms within a descriptor are joined by
        " * " and wrapped in parentheses when there is more than one. The
        power is shown as `label^power` unless it equals one.

    :raises TypeError:
        If the label vector is not structured correctly.
    """

    if not is_structured_label_vector(label_vector):
        raise TypeError("invalid label vector provided")

    theta = ["1"]
    for descriptor in label_vector:
        cross_terms = []
        for label, order in descriptor:
            if order == 0:
                continue
            # Show every power other than one, so fractional and negative
            # powers are not silently rendered as the first power.
            suffix = "^{0}".format(order) if order != 1 else ""
            cross_terms.append(str(label) + suffix)

        # A descriptor made up only of zero-order terms adds nothing beyond
        # the constant term, so do not emit an empty entry for it.
        if not cross_terms:
            continue

        term = " * ".join(cross_terms)
        template = "({0})" if len(cross_terms) > 1 else "{0}"
        theta.append(template.format(term))

    return " + ".join(theta)
168
|
|
|
|
169
|
|
|
|
170
|
|
|
def progressbar(iterable, message=None, size=100):
    """
    A progressbar.

    This is a generator: it yields every item of `iterable` unchanged and, as
    a side effect, draws a text progress bar on standard output.

    :param iterable:
        Some iterable to show progress for. If it has no length (for example
        a generator) and a bar is to be shown, it is first collected into a
        list so that the total count is known.

    :param message: [optional]
        A string message to show as the progressbar header. It is emitted
        through the module logger at INFO level.

    :param size: [optional]
        If the size given is zero or negative, then no progressbar will be
        shown and the items are passed through lazily. Any positive value
        enables the bar; the value itself does not set the drawn width.
    """

    # Nothing to draw: pass items straight through, without needing a length.
    if size <= 0:
        for item in iterable:
            yield item
        return

    # Preparerise.
    t_init = time()
    try:
        count = len(iterable)
    except TypeError:
        # Unsized iterable: collect it so a total is available for the bar.
        iterable = list(iterable)
        count = len(iterable)

    # Redraw roughly once per percent rather than once per item.
    increment = max(1, int(count / 100))

    def _update(i, t=None):
        if i % increment == 0 or i in (0, count):
            # An empty iterable is complete by definition; this also avoids
            # dividing by a zero count.
            percent = 100. * i / count if count > 0 else 100.
            sys.stdout.write("\r[{done}{not_done}] {percent:3.0f}%{t}".format(
                done="=" * int(i/increment),
                not_done=" " * int((count - i)/increment),
                percent=percent,
                t="" if t is None else " ({0:.0f}s)".format(t-t_init)))
            sys.stdout.flush()

    # Initialise.
    logger.info((message or "").rstrip())
    sys.stdout.flush()

    # Updaterise.
    for i, item in enumerate(iterable):
        yield item
        _update(i)

    # Finalise.
    _update(count, time())
    sys.stdout.write("\r\n")
    sys.stdout.flush()
214
|
|
|
|
215
|
|
|
|
216
|
|
|
def build_label_vector(labels, order, cross_term_order=0, **kwargs):
    """
    Build a human-readable label vector description from a set of labels.

    :param labels:
        The labels to use in describing the label vector.

    :param order:
        The maximum power allowed for a term made of a single label (e.g.,
        order 3 implies A^3 is a term).

    :param cross_term_order: [optional]
        The maximum power any one label may take within a term that mixes
        several labels (e.g., cross_term_order 2 implies A^2*B is a term).

    :param sep: [optional]
        The operator placed (surrounded by spaces) between terms.

    :param mul: [optional]
        The operator to use to represent multiplication in the description of
        the label vector.

    :param pow: [optional]
        The operator to use to represent exponents in the description of the
        label vector.

    :returns:
        A human-readable form of the label vector.
    """

    sep = kwargs.pop("sep", "+")
    mul = kwargs.pop("mul", "*")
    pow = kwargs.pop("pow", "^")

    def _render(label, power):
        # Only powers above one are written out explicitly.
        parts = [label, str(power)] if power > 1 else [label]
        return pow.join(parts)

    highest_degree = max(order, 1 + cross_term_order)

    terms = []
    for degree in range(1, highest_degree + 1):
        for combination in combinations_with_replacement(labels, degree):
            tally = Counter(combination)

            # Sort the labels so the output is the same on Python 2 and 3.
            ordered_labels = sorted(tally.keys())

            # Single-label terms are capped by `order`; mixed terms are capped
            # by `cross_term_order`.
            limit = order if len(tally) == 1 else cross_term_order
            if max(tally.values()) > limit:
                continue

            pieces = [_render(label, tally[label]) for label in ordered_labels]
            terms.append(mul.join(map(str, pieces)))

    return " {} ".format(sep).join(terms)
258
|
|
|
|