Completed
Push — master ( f31f72...51e8f0 )
by Bart
26s
created

fuel/converters/base.py (1 issue)

1
import os
2
import sys
3
from contextlib import contextmanager
4
from six import wraps
0 ignored issues
show
The name wraps does not seem to exist in module six. (Likely a false positive: six documents `six.wraps` as a wrapper around `functools.wraps`.)
Loading history...
5
6
import numpy
7
from progressbar import (ProgressBar, Percentage, Bar, ETA)
8
9
from fuel.datasets import H5PYDataset
10
from ..exceptions import MissingInputFiles
11
12
13
def check_exists(required_files):
    """Build a decorator that verifies input files before running a function.

    Parameters
    ----------
    required_files : list of str
        Names of regular files (directories do not count) that must be
        present in the input directory, which is passed as the first
        argument of the decorated function.

    Returns
    -------
    wrapper : function
        A decorator. Applying it to a function yields a new function that
        first looks up every name of `required_files` inside `directory`
        and only then delegates to the original function.

    Raises
    ------
    MissingInputFiles
        From the decorated function, when at least one required file is
        absent. It is raised with the message ``'Required files missing'``
        and the list of missing names.

    Notes
    -----
    The directory holding the input files is expected to be provided as
    the first argument, with the argument name `directory`.

    """
    def decorator(f):
        @wraps(f)
        def checked(directory, *args, **kwargs):
            def present(filename):
                # Only regular files count; a directory with that name fails.
                return os.path.isfile(os.path.join(directory, filename))
            missing = [name for name in required_files if not present(name)]
            if missing:
                raise MissingInputFiles('Required files missing', missing)
            return f(directory, *args, **kwargs)
        return checked
    return decorator
48
49
50
def fill_hdf5_file(h5file, data):
    """Write split/source arrays into an HDF5 file for H5PYDataset.

    For every source, the arrays of all of its splits are stacked along
    the first axis into a single dataset named after the source. The row
    range that each split occupies is recorded in the file's ``'split'``
    attribute via :meth:`H5PYDataset.create_split_array`.

    Parameters
    ----------
    h5file : :class:`h5py.File`
        Handle of the HDF5 file to fill. Datasets are created in it and
        its ``'split'`` attribute is set.
    data : tuple of tuple
        One entry per split/source pair, laid out as
        ``(split_name, source_name, data_array[, comment])``:

        * ``split_name`` -- string naming the split
        * ``source_name`` -- string naming the source
        * ``data_array`` -- :class:`numpy.ndarray` holding the data for
          this split/source pair
        * ``comment`` -- optional comment string for this pair

    Raises
    ------
    ValueError
        If the sources of one split differ in length, or if the splits of
        one source differ in dtype or in shape beyond the first axis.

    """
    # A split is only consistent if all of its sources have the same
    # number of examples.
    split_names = {entry[0] for entry in data}
    for split_name in split_names:
        sizes = [len(entry[2]) for entry in data if entry[0] == split_name]
        if any(size != sizes[0] for size in sizes):
            message = "split '{}' has sources that ".format(split_name)
            raise ValueError(message + "vary in length")

    split_dict = {split_name: {} for split_name in split_names}

    # Each source becomes one dataset holding its splits back to back, in
    # the order they appear in `data`.
    source_names = {entry[1] for entry in data}
    for source_name in source_names:
        entries = [entry for entry in data if entry[1] == source_name]
        arrays = [entry[2] for entry in entries]
        reference = arrays[0]
        if any(array.dtype != reference.dtype for array in arrays):
            message = "source '{}' has splits that ".format(source_name)
            raise ValueError(message + "vary in dtype")
        if any(array.shape[1:] != reference.shape[1:] for array in arrays):
            message = "source '{}' has splits that ".format(source_name)
            raise ValueError(message + "vary in shapes")
        total_rows = sum(len(array) for array in arrays)
        dataset = h5file.create_dataset(
            source_name, (total_rows,) + reference.shape[1:],
            dtype=reference.dtype)
        dataset[...] = numpy.concatenate(arrays, axis=0)
        # Row offsets at which each split starts and ends in the dataset.
        offsets = numpy.cumsum([0] + [len(array) for array in arrays])
        for position, entry in enumerate(entries):
            bounds = (offsets[position], offsets[position + 1])
            if len(entry) == 4:
                # The fourth element is the optional per-pair comment.
                bounds += (None, entry[3])
            split_dict[entry[0]][source_name] = bounds
    h5file.attrs['split'] = H5PYDataset.create_split_array(split_dict)
104
105
106
@contextmanager
def progress_bar(name, maxval, prefix='Converting'):
    """Display a console progress bar for the duration of a conversion.

    Parameters
    ----------
    name : str
        Name of the file being converted. It is shown in the bar's label.
    maxval : int
        Total number of steps for the conversion.
    prefix : str, optional
        Text placed before `name` in the label. Defaults to
        ``'Converting'``.

    Yields
    ------
    bar : :class:`progressbar.ProgressBar`
        The started progress bar, writing to standard output.

    Notes
    -----
    On exit, whether the body completed or raised, the bar is updated to
    `maxval` and then finished.

    """
    widgets = [
        '{} {}: '.format(prefix, name),
        Percentage(),
        ' ',
        Bar(marker='=', left='[', right=']'),
        ' ',
        ETA(),
    ]
    bar = ProgressBar(widgets=widgets, max_value=maxval, fd=sys.stdout)
    bar = bar.start()
    try:
        yield bar
    finally:
        # Runs even when the conversion raised, so the bar always ends full.
        bar.update(maxval)
        bar.finish()
126