1 | #!/usr/bin/env python |
||
2 | # -*- coding: utf-8 -*- |
||
3 | |||
4 | """ A command line utility for The Cannon. """ |
||
5 | |||
6 | import argparse |
||
7 | import logging |
||
8 | import os |
||
9 | from collections import OrderedDict |
||
10 | from numpy import ceil, loadtxt, zeros, nan, diag, ones |
||
11 | from subprocess import check_output |
||
12 | from six.moves import cPickle as pickle |
||
13 | from tempfile import mkstemp |
||
14 | from time import sleep |
||
15 | |||
16 | |||
def _fit_and_save_chunk(model, fluxes, ivars, output_filenames, initial_labels,
    fit_velocity, delete_meta_keys, logger):
    """
    Fit one chunk of loaded spectra and pickle each result to its
    corresponding output filename, then empty the accumulator lists
    in-place so the caller can start collecting the next chunk.
    """
    results, covs, metas = model.fit(fluxes, ivars,
        initial_labels=initial_labels, model_redshift=fit_velocity,
        full_output=True)

    for result, cov, meta, output_filename \
    in zip(results, covs, metas, output_filenames):

        # Strip bulky metadata entries (e.g. Jacobians) to save disk space.
        for key in delete_meta_keys:
            if key in meta:
                del meta[key]

        with open(output_filename, "wb") as fp:
            pickle.dump((result, cov, meta), fp, 2) # For legacy.
        logger.info("Saved output to {}".format(output_filename))

    # Clear in-place: the caller holds references to these same lists.
    del output_filenames[0:], fluxes[0:], ivars[0:]


def fit(model_filename, spectrum_filenames, threads, clobber, from_filename,
    **kwargs):
    """
    Fit a series of spectra.

    :param model_filename:
        The path of a trained Cannon model.

    :param spectrum_filenames:
        Paths of pickled (flux, ivar) spectra to fit, or -- when
        `from_filename` is True -- a single path to a text file that lists
        the spectrum paths, one per line.

    :param threads:
        The number of threads to use.

    :param clobber:
        Overwrite existing output files.

    :param from_filename:
        Read the spectrum filenames from the file given as the first entry
        of `spectrum_filenames`.
    """

    import AnniesLasso as tc

    model = tc.load_model(model_filename, threads=threads)
    logger = logging.getLogger("AnniesLasso")
    assert model.is_trained

    # Only batch spectra into multi-spectrum chunks when running in parallel.
    chunk_size = kwargs.pop("parallel_chunks", 1000) if threads > 1 else 1
    fluxes = []
    ivars = []
    output_filenames = []
    failures = 0

    fit_velocity = kwargs.pop("fit_velocity", False)

    # MAGIC HACK: initial labels are read from a fixed file in the current
    # working directory.
    delete_meta_keys = ("fjac", ) # To save space...
    initial_labels = loadtxt("initial_labels.txt")

    if from_filename:
        with open(spectrum_filenames[0], "r") as fp:
            spectrum_filenames = list(map(str.strip, fp.readlines()))

    output_suffix = kwargs.get("output_suffix", None)
    output_suffix = "result" if output_suffix is None else str(output_suffix)
    N = len(spectrum_filenames)
    for i, filename in enumerate(spectrum_filenames):
        logger.info("At spectrum {0}/{1}: {2}".format(i + 1, N, filename))

        basename, _ = os.path.splitext(filename)
        output_filename = "-".join([basename, output_suffix]) + ".pkl"

        if os.path.exists(output_filename) and not clobber:
            logger.info("Output filename {} already exists and not clobbering."\
                .format(output_filename))
            continue

        try:
            # NOTE(review): unpickling is only safe on trusted spectrum files.
            with open(filename, "rb") as fp:
                flux, ivar = pickle.load(fp)
            fluxes.append(flux)
            ivars.append(ivar)

            output_filenames.append(output_filename)

        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # are no longer swallowed.
            logger.exception("Error occurred loading {}".format(filename))
            failures += 1

        else:
            if len(output_filenames) >= chunk_size:
                _fit_and_save_chunk(model, fluxes, ivars, output_filenames,
                    initial_labels, fit_velocity, delete_meta_keys, logger)

    # Fit whatever partial chunk remains.
    if len(output_filenames) > 0:
        _fit_and_save_chunk(model, fluxes, ivars, output_filenames,
            initial_labels, fit_velocity, delete_meta_keys, logger)

    logger.info("Number of failures: {}".format(failures))
    logger.info("Number of successes: {}".format(N - failures))

    return None
||
117 | |||
118 | |||
119 | |||
120 | |||
def join_results(output_filename, result_filenames, model_filename=None,
    from_filename=False, clobber=False, errors=False, cov=False, **kwargs):
    """
    Join the test results from multiple files into a single table file.

    :param output_filename:
        The path to write the combined table to.

    :param result_filenames:
        Paths of individual result pickles, or -- when `from_filename` is
        True -- a single path to a text file listing them, one per line.

    :param model_filename: [optional]
        The trained model used to produce the results. Only needed when the
        result files do not record their own label names.

    :param from_filename: [optional]
        Read the result filenames from the file given as the first entry of
        `result_filenames`.

    :param clobber: [optional]
        Overwrite an existing output file.

    :param errors: [optional]
        Include formal errors (square root of the covariance diagonal).

    :param cov: [optional]
        Include the full covariance matrix in the table.
    """

    import AnniesLasso as tc
    from astropy.table import Table, TableColumns

    # Copy before updating so a caller-supplied `meta_keys` dict is not
    # mutated as a side effect.
    meta_keys = dict(kwargs.pop("meta_keys", {}))
    meta_keys.update({
        "chi_sq": nan,
        "r_chi_sq": nan,
        "snr": nan,
        # "redshift": nan,
    })

    logger = logging.getLogger("AnniesLasso")

    # Does the output filename already exist?
    if os.path.exists(output_filename) and not clobber:
        logger.info("Output filename {} already exists and not clobbering."\
            .format(output_filename))
        return None

    if from_filename:
        with open(result_filenames[0], "r") as fp:
            result_filenames = list(map(str.strip, fp.readlines()))

    # We might need the label names from the model.
    if model_filename is not None:
        model = tc.load_model(model_filename)
        assert model.is_trained
        label_names = model.vectorizer.label_names
        # `Logger.warn` is a deprecated alias; use `warning` instead.
        logger.warning(
            "Results produced from newer models do not need a model_filename "\
            "to be specified when joining results.")

    else:
        # Newer result files record the label names in their metadata.
        with open(result_filenames[0], "rb") as fp:
            contents = pickle.load(fp)
        if "label_names" not in contents[-1]:
            raise ValueError(
                "cannot find label names; please provide the model used "\
                "to produce these results")
        label_names = contents[-1]["label_names"]


    # Load results from each file.
    failed = []
    N = len(result_filenames)

    # Create an ordered dictionary of lists for all the data. The column
    # order established here must match the per-row `result` order below.
    data_dict = OrderedDict([("FILENAME", [])])
    for label_name in label_names:
        data_dict[label_name] = []

    if errors:
        for label_name in label_names:
            data_dict["E_{}".format(label_name)] = []

    if cov:
        data_dict["COV"] = []

    for key in meta_keys:
        data_dict[key] = []

    # Iterate over all the result filenames
    for i, filename in enumerate(result_filenames):
        logger.info("{}/{}: {}".format(i + 1, N, filename))

        if not os.path.exists(filename):
            logger.warning("Path {} does not exist. Continuing..".format(filename))
            failed.append(filename)
            continue

        with open(filename, "rb") as fp:
            contents = pickle.load(fp)

        assert len(contents) == 3, "You are using some old school version!"

        labels, Sigma, meta = contents

        # A missing covariance matrix is recorded as an all-NaN matrix so
        # the error/covariance columns stay rectangular.
        if Sigma is None:
            Sigma = nan * ones((labels.size, labels.size))

        result = [filename] + list(labels)
        if errors:
            result.extend(diag(Sigma)**0.5)
        if cov:
            result.append(Sigma)
        result += [meta.get(k, v) for k, v in meta_keys.items()]

        # Zip relies on `result` matching the column creation order above.
        for key, value in zip(data_dict.keys(), result):
            data_dict[key].append(value)

    # Warn of any failures.
    if failed:
        logger.warning(
            "The following {} result file(s) could not be found: \n{}".format(
                len(failed), "\n".join(failed)))

    # Construct the table.
    table = Table(TableColumns(data_dict))
    table.write(output_filename, overwrite=clobber)
    logger.info("Written to {}".format(output_filename))
||
228 | |||
229 | |||
230 | |||
231 | |||
def main():
    """
    The main command line interpreter. This is the console script entry point.

    Builds the argument parser (with `fit` and `join` sub-commands), parses
    the command line, configures logging verbosity, and dispatches to the
    selected action function.
    """

    # Create the main parser.
    parser = argparse.ArgumentParser(
        description="The Cannon", epilog="http://TheCannon.io")

    # Create parent parser with options shared by every sub-command.
    parent_parser = argparse.ArgumentParser(add_help=False)
    parent_parser.add_argument("-v", "--verbose",
        dest="verbose", action="store_true", default=False,
        help="Verbose logging mode.")
    parent_parser.add_argument("-t", "--threads",
        dest="threads", type=int, default=1,
        help="The number of threads to use.")

    # Allow for multiple actions.
    subparsers = parser.add_subparsers(title="action", dest="action",
        description="Specify the action to perform.")

    # Fitting parser.
    fit_parser = subparsers.add_parser("fit", parents=[parent_parser],
        help="Fit stacked spectra using a trained model.")
    fit_parser.add_argument("model_filename", type=str,
        help="The path of a trained Cannon model.")
    fit_parser.add_argument("spectrum_filenames", nargs="+", type=str,
        help="Paths of spectra to fit.")
    fit_parser.add_argument("--parallel-chunks", dest="parallel_chunks",
        type=int, default=1000, help="The number of spectra to fit in a chunk.")
    fit_parser.add_argument("--clobber", dest="clobber", default=False,
        action="store_true", help="Overwrite existing output files.")
    fit_parser.add_argument(
        "--output-suffix", dest="output_suffix", type=str,
        help="A string suffix that will be added to the spectrum filenames "\
             "when creating the result filename")
    fit_parser.add_argument("--from-filename", dest="from_filename",
        action="store_true", default=False, help="Read spectrum filenames from file")
    fit_parser.set_defaults(func=fit)


    # Join results parser.
    join_parser = subparsers.add_parser("join", parents=[parent_parser],
        help="Join results from individual stars into a single table.")
    join_parser.add_argument("output_filename", type=str,
        help="The path to write the output filename.")
    join_parser.add_argument("result_filenames", nargs="+", type=str,
        help="Paths of result files to include.")
    join_parser.add_argument("--from-filename",
        dest="from_filename", action="store_true", default=False,
        help="Read result filenames from a file.")
    join_parser.add_argument(
        "--errors", dest="errors", default=False, action="store_true",
        help="Include formal errors in destination table.")
    join_parser.add_argument(
        "--cov", dest="cov", default=False, action="store_true",
        help="Include covariance matrix in destination table.")
    join_parser.add_argument(
        # Typo fixed: was "Ovewrite".
        "--clobber", dest="clobber", default=False, action="store_true",
        help="Overwrite an existing table file.")

    join_parser.set_defaults(func=join_results)

    # Parse the arguments and take care of any top-level arguments.
    args = parser.parse_args()
    if args.action is None: return

    logger = logging.getLogger("AnniesLasso")
    if args.verbose:
        logger.setLevel(logging.DEBUG)

    # Do things. All parsed options are forwarded to the action function.
    return args.func(**args.__dict__)
||
306 | |||
307 | |||
if __name__ == "__main__":

    # Usage examples:
    # tc train model.pickle --condor --chunks 100
    # tc train model.pickle --threads 8
    # tc join model.pickle --from-filename files
    main()
||
318 |