Completed
Push — master ( 3e1d4c...f31f72 )
by Bart
27s
created

fuel.bin.CheckDirectoryAction   A

Complexity

Total Complexity 2

Size/Duplication

Total Lines 6
Duplicated Lines 0 %
Metric Value
dl 0
loc 6
rs 10
wmc 2

1 Method

Rating   Name   Duplication   Size   Complexity  
A CheckDirectoryAction.__call__() 0 5 2
1
#!/usr/bin/env python
2
"""Fuel dataset conversion utility."""
3
import argparse
4
import importlib
5
import os
6
import sys
7
8
import h5py
9
10
import fuel
11
from fuel import converters
12
from fuel.converters.base import MissingInputFiles
13
from fuel.datasets import H5PYDataset
14
15
16
class CheckDirectoryAction(argparse.Action):
    """Argparse action that only accepts paths to existing directories.

    A valid path is stored on the namespace unchanged under ``self.dest``.
    Any other value is rejected through argparse's own error reporting.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        if not os.path.isdir(values):
            # ArgumentError (unlike a bare ValueError) is caught by
            # argparse, which turns it into a usage message instead of a
            # traceback -- matching how main() reports other user errors.
            raise argparse.ArgumentError(
                self, '{} is not an existing directory'.format(values))
        setattr(namespace, self.dest, values)
22
23
24
def _collect_converters():
    """Return a mapping of dataset name -> subparser-filling function.

    Starts from the built-in ``converters.all_converters`` and merges in
    the ``all_converters`` of every module named in
    ``fuel.config.extra_converters`` (each is passed to ``dict()``, so it
    must be a mapping or an iterable of ``(name, function)`` pairs).

    Raises
    ------
    ValueError
        If an extra converter reuses a name that is already registered.

    """
    datasets = dict(converters.all_converters)
    for module_name in fuel.config.extra_converters or ():
        extra_datasets = dict(
            importlib.import_module(module_name).all_converters)
        if any(key in datasets for key in extra_datasets):
            raise ValueError('extra converters conflict in name with '
                             'built-in converters')
        datasets.update(extra_datasets)
    return datasets


def _build_parser(datasets):
    """Build the ``fuel-convert`` parser with one subcommand per dataset.

    Parameters
    ----------
    datasets : dict
        Maps a dataset name to a function that adds that dataset's
        options to a subparser and returns its conversion function.

    Returns
    -------
    parser : argparse.ArgumentParser
        The top-level parser.
    convert_functions : dict
        Maps each dataset name to its conversion function.

    """
    parser = argparse.ArgumentParser(
        description='Conversion script for built-in datasets.')
    subparsers = parser.add_subparsers()
    parent_parser = argparse.ArgumentParser(add_help=False)
    parent_parser.add_argument(
        "-d", "--directory", help="directory in which input files reside",
        type=str, default=os.getcwd())
    convert_functions = {}
    for name, fill_subparser in datasets.items():
        subparser = subparsers.add_parser(
            name, parents=[parent_parser],
            help='Convert the {} dataset'.format(name))
        subparser.add_argument(
            "-o", "--output-directory", help="where to save the dataset",
            type=str, default=os.getcwd(), action=CheckDirectoryAction)
        subparser.add_argument(
            "-r", "--output_filename", help="new name of the created dataset",
            type=str, default=None)
        # Allows the parser to know which subparser was called.
        subparser.set_defaults(which_=name)
        convert_functions[name] = fill_subparser(subparser)
    return parser, convert_functions


def _tag_output_files(output_paths, command):
    """Store provenance attributes on every converted HDF5 file.

    Each file receives the H5PYDataset interface version, the converters'
    version and the given command line, all encoded as UTF-8 bytes.

    """
    interface_version = H5PYDataset.interface_version.encode('utf-8')
    fuel_convert_version = converters.__version__.encode('utf-8')
    encoded_command = command.encode('utf-8')
    for output_path in output_paths:
        # The context manager guarantees the file is closed even if
        # writing one of the attributes fails.
        with h5py.File(output_path, 'a') as h5file:
            h5file.attrs['h5py_interface_version'] = interface_version
            h5file.attrs['fuel_convert_version'] = fuel_convert_version
            h5file.attrs['fuel_convert_command'] = encoded_command
            h5file.flush()


def main(args=None):
    """Entry point for `fuel-convert` script.

    This function can also be imported and used from Python.

    Parameters
    ----------
    args : iterable, optional (default: None)
        A list of arguments that will be passed to Fuel's conversion
        utility. If this argument is not specified, `sys.argv[1:]` will
        be used.

    Notes
    -----
    Every file returned by the selected converter is tagged with the
    H5PYDataset interface version, the converters' version and the
    command line that produced it. When `args` is given, those arguments
    (not `sys.argv[1:]`) are what gets recorded.

    """
    if args is not None:
        # Materialize once so the same arguments are parsed and recorded.
        args = list(args)
    parser, convert_functions = _build_parser(_collect_converters())

    args_dict = vars(parser.parse_args(args))
    if 'which_' not in args_dict:
        # Subcommands are optional under Python 3; without this check a
        # missing dataset name would crash with a KeyError below.
        parser.error('please specify the dataset to convert')

    output_filename = args_dict.pop('output_filename')
    if output_filename is not None:
        if os.path.splitext(output_filename)[1] not in ('.hdf5', '.hdf',
                                                        '.h5'):
            output_filename += '.hdf5'
        args_dict['output_filename'] = output_filename

    convert_function = convert_functions[args_dict.pop('which_')]
    try:
        output_paths = convert_function(**args_dict)
    except MissingInputFiles as e:
        intro = "The following required files were not found:\n"
        message = "\n".join([intro] + ["   * " + f for f in e.filenames])
        message += "\n\nDid you forget to run fuel-download?"
        parser.error(message)

    # Tag the newly-created file(s) with H5PYDataset version and command-line
    # options
    cli_args = sys.argv[1:] if args is None else args
    command = [os.path.basename(sys.argv[0])] + cli_args
    _tag_output_files(output_paths, ' '.join(command))
99
100
101
if __name__ == '__main__':
    # Run the conversion CLI only when executed as a script, not on import.
    main()
103