join_results()   F
last analyzed

Complexity      Conditions: 24
Size            Total Lines: 107
Duplication     Lines: 0   Ratio: 0 %
Importance      Changes: 4   Bugs: 1   Features: 0

Metric   Value
cc       24
c        4
b        1
f        0
dl       0
loc      107
rs       2

How to fix

Long Method

Small methods make your code easier to understand, particularly when combined with a good name. And if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a sign that you should extract the commented part into a new method, using the comment as a starting point for naming it.

Commonly applied refactorings include Extract Method, sketched below for this function.
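For example, the loop body in join_results() below that loads one result file and turns it into a table row is a natural candidate for extraction. A minimal sketch, assuming a hypothetical helper name _result_row and the (labels, Sigma, meta) pickle layout used in the source:

from numpy import nan, diag, ones
from six.moves import cPickle as pickle


def _result_row(filename, meta_keys, errors=False, cov=False):
    # Hypothetical helper: load one result file and build a row for the output table.
    with open(filename, "rb") as fp:
        labels, Sigma, meta = pickle.load(fp)

    # Pad a missing covariance matrix with NaNs, as the original loop does.
    if Sigma is None:
        Sigma = nan * ones((labels.size, labels.size))

    row = [filename] + list(labels)
    if errors:
        row.extend(diag(Sigma)**0.5)
    if cov:
        row.append(Sigma)
    row += [meta.get(k, v) for k, v in meta_keys.items()]
    return row

The loop in join_results() then only has to call the helper and append the returned row to data_dict, and the helper's name documents what that block does.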

Complexity

Complex code like join_results() often does a lot of different things. To break it down, we need to identify cohesive components within it. A common way to find such a component is to look for fields or methods that share the same prefixes or suffixes.

Once you have determined which pieces belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
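Since join_results() is a standalone function rather than a class, the closest analogue here is to pull each cohesive group of statements into its own small function. A minimal sketch, assuming a hypothetical _initialise_columns helper that owns the column set-up currently done inline:

from collections import OrderedDict


def _initialise_columns(label_names, meta_keys, errors=False, cov=False):
    # Hypothetical helper: build the ordered column -> values mapping for the output table.
    columns = OrderedDict([("FILENAME", [])])
    for label_name in label_names:
        columns[label_name] = []

    if errors:
        for label_name in label_names:
            columns["E_{}".format(label_name)] = []

    if cov:
        columns["COV"] = []

    for key in meta_keys:
        columns[key] = []
    return columns

Each group extracted this way takes its conditions with it, which is how the 24 conditions counted against join_results() come down.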

#!/usr/bin/env python
# -*- coding: utf-8 -*-

""" A command line utility for The Cannon. """

import argparse
import logging
import os
from collections import OrderedDict
from numpy import ceil, loadtxt, zeros, nan, diag, ones
from subprocess import check_output
from six.moves import cPickle as pickle
from tempfile import mkstemp
from time import sleep

def fit(model_filename, spectrum_filenames, threads, clobber, from_filename,
    **kwargs):
    """
    Fit a series of spectra.
    """

    import AnniesLasso as tc

    model = tc.load_model(model_filename, threads=threads)
    logger = logging.getLogger("AnniesLasso")
    assert model.is_trained

    chunk_size = kwargs.pop("parallel_chunks", 1000) if threads > 1 else 1
    fluxes = []
    ivars = []
    output_filenames = []
    failures = 0

    fit_velocity = kwargs.pop("fit_velocity", False)

    # MAGIC HACK
    delete_meta_keys = ("fjac", ) # To save space...
    initial_labels = loadtxt("initial_labels.txt")

    if from_filename:
        with open(spectrum_filenames[0], "r") as fp:
            _ = list(map(str.strip, fp.readlines()))
        spectrum_filenames = _

    output_suffix = kwargs.get("output_suffix", None)
    output_suffix = "result" if output_suffix is None else str(output_suffix)
    N = len(spectrum_filenames)
    for i, filename in enumerate(spectrum_filenames):
        logger.info("At spectrum {0}/{1}: {2}".format(i + 1, N, filename))

        basename, _ = os.path.splitext(filename)
        output_filename = "-".join([basename, output_suffix]) + ".pkl"

        if os.path.exists(output_filename) and not clobber:
            logger.info("Output filename {} already exists and not clobbering."\
                .format(output_filename))
            continue

        try:
            with open(filename, "rb") as fp:
                flux, ivar = pickle.load(fp)
                fluxes.append(flux)
                ivars.append(ivar)

            output_filenames.append(output_filename)

        except:
            logger.exception("Error occurred loading {}".format(filename))
            failures += 1

        else:
            if len(output_filenames) >= chunk_size:

                results, covs, metas = model.fit(fluxes, ivars,
                    initial_labels=initial_labels, model_redshift=fit_velocity,
                    full_output=True)

                for result, cov, meta, output_filename \
                in zip(results, covs, metas, output_filenames):

                    for key in delete_meta_keys:
                        if key in meta:
                            del meta[key]

                    with open(output_filename, "wb") as fp:
                        pickle.dump((result, cov, meta), fp, 2) # For legacy.
                    logger.info("Saved output to {}".format(output_filename))

                del output_filenames[0:], fluxes[0:], ivars[0:]


    if len(output_filenames) > 0:

        results, covs, metas = model.fit(fluxes, ivars, 
            initial_labels=initial_labels, model_redshift=fit_velocity,
            full_output=True)

        for result, cov, meta, output_filename \
        in zip(results, covs, metas, output_filenames):

            for key in delete_meta_keys:
                if key in meta:
                    del meta[key]

            with open(output_filename, "wb") as fp:
                pickle.dump((result, cov, meta), fp, 2) # For legacy.
            logger.info("Saved output to {}".format(output_filename))

        del output_filenames[0:], fluxes[0:], ivars[0:]


    logger.info("Number of failures: {}".format(failures))
    logger.info("Number of successes: {}".format(N - failures))

    return None



def join_results(output_filename, result_filenames, model_filename=None, 
    from_filename=False, clobber=False, errors=False, cov=False, **kwargs):
    """
    Join the test results from multiple files into a single table file.
    """

    import AnniesLasso as tc
    from astropy.table import Table, TableColumns

    meta_keys = kwargs.pop("meta_keys", {})
    meta_keys.update({
        "chi_sq": nan,
        "r_chi_sq": nan,
        "snr": nan,
    #    "redshift": nan,
    })

    logger = logging.getLogger("AnniesLasso")

    # Does the output filename already exist?
    if os.path.exists(output_filename) and not clobber:
        logger.info("Output filename {} already exists and not clobbering."\
            .format(output_filename))
        return None

    if from_filename:
        with open(result_filenames[0], "r") as fp:
            _ = list(map(str.strip, fp.readlines()))
        result_filenames = _

    # We might need the label names from the model.
    if model_filename is not None:
        model = tc.load_model(model_filename)
        assert model.is_trained
        label_names = model.vectorizer.label_names
        logger.warn(
            "Results produced from newer models do not need a model_filename "\
            "to be specified when joining results.")

    else:
        with open(result_filenames[0], "rb") as fp:
            contents = pickle.load(fp)
            if "label_names" not in contents[-1]:
                raise ValueError(
                    "cannot find label names; please provide the model used "\
                    "to produce these results")
            label_names = contents[-1]["label_names"]


    # Load results from each file.
    failed = []
    N = len(result_filenames)

    # Create an ordered dictionary of lists for all the data.
    data_dict = OrderedDict([("FILENAME", [])])
    for label_name in label_names:
        data_dict[label_name] = []

    if errors:
        for label_name in label_names:
            data_dict["E_{}".format(label_name)] = []

    if cov:
        data_dict["COV"] = []

    for key in meta_keys:
        data_dict[key] = []

    # Iterate over all the result filenames
    for i, filename in enumerate(result_filenames):
        logger.info("{}/{}: {}".format(i + 1, N, filename))

        if not os.path.exists(filename):
            logger.warn("Path {} does not exist. Continuing..".format(filename))
            failed.append(filename)
            continue

        with open(filename, "rb") as fp:
            contents = pickle.load(fp)

        assert len(contents) == 3, "You are using some old school version!"

        labels, Sigma, meta = contents

        if Sigma is None:
            Sigma = nan * ones((labels.size, labels.size))

        result = [filename] + list(labels) 
        if errors:
            result.extend(diag(Sigma)**0.5) 
        if cov:
            result.append(Sigma)
        result += [meta.get(k, v) for k, v in meta_keys.items()]

        for key, value in zip(data_dict.keys(), result):
            data_dict[key].append(value)

    # Warn of any failures.
    if failed:
        logger.warn(
            "The following {} result file(s) could not be found: \n{}".format(
                len(failed), "\n".join(failed)))

    # Construct the table.
    table = Table(TableColumns(data_dict))
    table.write(output_filename, overwrite=clobber)
    logger.info("Written to {}".format(output_filename))



def main():
    """
    The main command line interpreter. This is the console script entry point.
    """

    # Create the main parser.
    parser = argparse.ArgumentParser(
        description="The Cannon", epilog="http://TheCannon.io")

    # Create parent parser.
    parent_parser = argparse.ArgumentParser(add_help=False)
    parent_parser.add_argument("-v", "--verbose",
        dest="verbose", action="store_true", default=False, 
        help="Verbose logging mode.")
    parent_parser.add_argument("-t", "--threads",
        dest="threads", type=int, default=1,
        help="The number of threads to use.")

    # Allow for multiple actions.
    subparsers = parser.add_subparsers(title="action", dest="action",
        description="Specify the action to perform.")

    # Fitting parser.
    fit_parser = subparsers.add_parser("fit", parents=[parent_parser],
        help="Fit stacked spectra using a trained model.")
    fit_parser.add_argument("model_filename", type=str,
        help="The path of a trained Cannon model.")
    fit_parser.add_argument("spectrum_filenames", nargs="+", type=str,
        help="Paths of spectra to fit.")
    fit_parser.add_argument("--parallel-chunks", dest="parallel_chunks",
        type=int, default=1000, help="The number of spectra to fit in a chunk.")
    fit_parser.add_argument("--clobber", dest="clobber", default=False,
        action="store_true", help="Overwrite existing output files.")
    fit_parser.add_argument(
        "--output-suffix", dest="output_suffix", type=str,
        help="A string suffix that will be added to the spectrum filenames "\
             "when creating the result filename")
    fit_parser.add_argument("--from-filename", dest="from_filename",
        action="store_true", default=False, help="Read spectrum filenames from file")
    fit_parser.set_defaults(func=fit)


    # Join results parser.
    join_parser = subparsers.add_parser("join", parents=[parent_parser],
        help="Join results from individual stars into a single table.")
    join_parser.add_argument("output_filename", type=str,
        help="The path to write the output filename.")
    join_parser.add_argument("result_filenames", nargs="+", type=str,
        help="Paths of result files to include.")
    join_parser.add_argument("--from-filename", 
        dest="from_filename", action="store_true", default=False,
        help="Read result filenames from a file.")
    join_parser.add_argument(
        "--errors", dest="errors", default=False, action="store_true", 
        help="Include formal errors in destination table.")
    join_parser.add_argument(
        "--cov", dest="cov", default=False, action="store_true", 
        help="Include covariance matrix in destination table.")
    join_parser.add_argument(
        "--clobber", dest="clobber", default=False, action="store_true", 
        help="Overwrite an existing table file.")

    join_parser.set_defaults(func=join_results)

    # Parse the arguments and take care of any top-level arguments.
    args = parser.parse_args()
    if args.action is None: return

    logger = logging.getLogger("AnniesLasso")
    if args.verbose:
        logger.setLevel(logging.DEBUG)

    # Do things.
    return args.func(**args.__dict__)


if __name__ == "__main__":

    """
    Usage examples:
    # tc train model.pickle --condor --chunks 100
    # tc train model.pickle --threads 8
    # tc join model.pickle --from-filename files

    """
    _ = main()