1 | import os |
||
2 | import sys |
||
3 | from contextlib import contextmanager |
||
4 | from six import wraps |
||
5 | |||
6 | import numpy |
||
7 | from progressbar import (ProgressBar, Percentage, Bar, ETA) |
||
8 | |||
9 | from fuel.datasets import H5PYDataset |
||
10 | from ..exceptions import MissingInputFiles |
||
11 | |||
12 | |||
def check_exists(required_files):
    """Build a decorator that verifies input files before a call.

    Parameters
    ----------
    required_files : list of str
        Filenames of regular files (not directories) that must be
        present in the input directory, which is the first argument
        of the decorated function.

    Returns
    -------
    wrapper : function
        A function that takes a function and returns a wrapped
        function. The wrapped function checks for the required files
        and only then delegates to the original function.

    Raises
    ------
    MissingInputFiles
        Raised by the wrapped function, with the list of missing
        filenames, when at least one required file is not a regular
        file inside the directory.

    Notes
    -----
    The directory holding the input files must be the first argument
    of the decorated function and be named `directory`.

    """
    def function_wrapper(f):
        @wraps(f)
        def checked(directory, *args, **kwargs):
            # os.path.isfile is False for directories as well as for
            # absent paths, so both are reported as missing.
            missing = [
                filename for filename in required_files
                if not os.path.isfile(os.path.join(directory, filename))
            ]
            if missing:
                raise MissingInputFiles('Required files missing', missing)
            return f(directory, *args, **kwargs)
        return checked
    return function_wrapper
||
48 | |||
49 | |||
def fill_hdf5_file(h5file, data):
    """Fills an HDF5 file in a H5PYDataset-compatible manner.

    Parameters
    ----------
    h5file : :class:`h5py.File`
        File handle for an HDF5 file.
    data : tuple of tuple
        One element per split/source pair. Each element consists of a
        tuple of (split_name, source_name, data_array, comment), where

        * 'split_name' is a string identifier for the split name
        * 'source_name' is a string identifier for the source name
        * 'data_array' is a :class:`numpy.ndarray` containing the data
          for this split/source pair
        * 'comment' is a comment string for the split/source pair

        The 'comment' element can optionally be omitted.

    Raises
    ------
    ValueError
        If the sources of a split vary in length, or if the splits of
        a source vary in dtype or in shape beyond the first axis.

    Notes
    -----
    Splits and sources are handled in the order in which they first
    appear in `data`. Validation errors, dataset creation and the
    split dictionary passed to `H5PYDataset.create_split_array` are
    therefore reproducible from run to run instead of depending on
    set iteration order.

    """
    # Local import: keeps this function self-contained with respect to
    # the module-level import block.
    from collections import OrderedDict

    # Group the entries once, remembering first-seen order.
    by_split = OrderedDict()
    by_source = OrderedDict()
    for entry in data:
        by_split.setdefault(entry[0], []).append(entry)
        by_source.setdefault(entry[1], []).append(entry)

    # Check that all sources for a split have the same length
    for split_name, entries in by_split.items():
        if len(set(len(entry[2]) for entry in entries)) > 1:
            raise ValueError("split '{}' has sources that ".format(split_name) +
                             "vary in length")

    # Initialize split dictionary
    split_dict = OrderedDict((split_name, {}) for split_name in by_split)

    # Compute total source lengths and check that splits have the same
    # dtype and trailing shape across a source
    for source_name, splits in by_source.items():
        reference = splits[0][2]
        lengths = [len(s[2]) for s in splits]
        # Consecutive pairs of `indices` are the [start, stop) rows
        # that each split occupies inside the concatenated dataset.
        indices = numpy.cumsum([0] + lengths)
        if not all(s[2].dtype == reference.dtype for s in splits):
            raise ValueError("source '{}' has splits that ".format(source_name) +
                             "vary in dtype")
        if not all(s[2].shape[1:] == reference.shape[1:] for s in splits):
            raise ValueError("source '{}' has splits that ".format(source_name) +
                             "vary in shapes")
        dataset = h5file.create_dataset(
            source_name, (sum(lengths),) + reference.shape[1:],
            dtype=reference.dtype)
        dataset[...] = numpy.concatenate([s[2] for s in splits], axis=0)
        for start, stop, s in zip(indices[:-1], indices[1:], splits):
            if len(s) == 4:
                # A comment was supplied; the third slot is left as
                # None and the comment goes in the fourth.
                split_dict[s[0]][source_name] = (start, stop, None, s[3])
            else:
                split_dict[s[0]][source_name] = (start, stop)
    h5file.attrs['split'] = H5PYDataset.create_split_array(split_dict)
||
104 | |||
105 | |||
@contextmanager
def progress_bar(name, maxval, prefix='Converting'):
    """Manages a progress bar for a conversion.

    Parameters
    ----------
    name : str
        Name of the file being converted.
    maxval : int
        Total number of steps for the conversion.
    prefix : str, optional
        Text shown before `name` in the bar's label. Defaults to
        'Converting'.

    Yields
    ------
    bar
        The started progress bar, writing to standard output.

    """
    label = '{} {}: '.format(prefix, name)
    layout = [
        label,
        Percentage(),
        ' ',
        Bar(marker='=', left='[', right=']'),
        ' ',
        ETA(),
    ]
    progress = ProgressBar(widgets=layout, max_value=maxval, fd=sys.stdout)
    progress = progress.start()
    try:
        yield progress
    finally:
        # Runs whether or not the managed body raised: the bar is
        # pushed to `maxval` and then finished.
        progress.update(maxval)
        progress.finish()
||
126 |