1 | import logging |
||
2 | import os |
||
3 | import sys |
||
4 | import time |
||
5 | from six import wraps |
||
6 | from importlib import import_module |
||
7 | from unittest.case import SkipTest |
||
8 | |||
9 | from six import StringIO |
||
10 | |||
11 | import blocks |
||
12 | from blocks.algorithms import TrainingAlgorithm |
||
13 | from blocks.config import config |
||
14 | from blocks.main_loop import MainLoop |
||
15 | from fuel.datasets import IterableDataset |
||
16 | |||
17 | |||
def silence_printing(test):
    """Decorate a test so it runs with stdout and Blocks logging silenced.

    While the wrapped test runs, ``sys.stdout`` is replaced by an in-memory
    ``StringIO`` buffer and the level of the ``blocks`` package logger is
    raised to ``ERROR``. Both are restored afterwards, even if the test
    raises.

    Parameters
    ----------
    test : callable
        The test function to wrap.

    Returns
    -------
    callable
        A wrapper carrying `test`'s metadata (via ``wraps``) that forwards
        all arguments and returns whatever `test` returns.

    """
    @wraps(test)
    def wrapper(*args, **kwargs):
        logger = logging.getLogger(blocks.__name__)
        old_level = logger.level
        original_stdout = sys.stdout
        # Swap stdout last so that nothing between the swap and the `try`
        # can fail and leave stdout redirected.
        logger.setLevel(logging.ERROR)
        sys.stdout = StringIO()
        try:
            # Propagate the result instead of silently dropping it.
            return test(*args, **kwargs)
        finally:
            sys.stdout = original_stdout
            logger.setLevel(old_level)
    return wrapper
||
32 | |||
33 | |||
def skip_if_not_available(modules=None, datasets=None, configurations=None):
    """Raise a SkipTest exception when requirements are not met.

    Parameters
    ----------
    modules : list, optional
        A list of strings of module names. If one of the modules fails to
        import, the test will be skipped. For ``'bokeh'`` the test is also
        skipped when a GET request to the Bokeh session's base URL raises
        ``requests.exceptions.ConnectionError``.
    datasets : list, optional
        A list of strings of folder names. If the data path is not
        configured, or the folder does not exist, the test is skipped.
    configurations : list, optional
        A list of strings of configuration names. If one of these is not
        available on the Blocks config object, the test will be skipped.

    Raises
    ------
    SkipTest
        When any of the requirements above is not met. The exception
        message names the missing requirement.

    """
    if modules is None:
        modules = []
    if datasets is None:
        datasets = []
    if configurations is None:
        configurations = []
    for module in modules:
        try:
            import_module(module)
        except Exception:
            # Any failure while importing (not only ImportError) makes the
            # module unusable for the test, so skip instead of erroring.
            raise SkipTest('module {!r} is not available'.format(module))
        if module == 'bokeh':
            # Deliberately not named `ConnectionError`, which would shadow
            # the Python 3 built-in exception of that name.
            requests_connection_error = import_module(
                'requests.exceptions').ConnectionError
            session = import_module('bokeh.session').Session()
            try:
                session.execute('get', session.base_url)
            except requests_connection_error:
                raise SkipTest('Bokeh session base URL is not reachable')

    if datasets and not hasattr(config, 'data_path'):
        raise SkipTest('data_path is not configured')
    for dataset in datasets:
        if not os.path.exists(os.path.join(config.data_path, dataset)):
            raise SkipTest('dataset {!r} not found'.format(dataset))
    for configuration in configurations:
        if not hasattr(config, configuration):
            raise SkipTest(
                'configuration {!r} is not set'.format(configuration))
||
78 | |||
79 | |||
def skip_if_configuration_set(configuration, value, message=None):
    """Skip the current test when a configuration option equals `value`.

    Parameters
    ----------
    configuration : str
        Name of the attribute of `blocks.config` to inspect.
    value : str
        Value which, when equal to that option, causes a `SkipTest` to be
        raised.
    message : str, optional
        Reason for skipping the test; passed to `SkipTest` when given.

    """
    # Written as `not (a == b)` rather than `a != b` to keep exactly the
    # same comparison protocol as an equality test.
    if not getattr(config, configuration) == value:
        return
    skip_args = () if message is None else (message,)
    raise SkipTest(*skip_args)
||
99 | |||
100 | |||
class MockAlgorithm(TrainingAlgorithm):
    """A training algorithm that only records the batches it receives.

    Calling :meth:`initialize` more than once fails an assertion, so tests
    can verify that initialization happens exactly once.

    Parameters
    ----------
    delay_time : float, optional
        Number of seconds passed to ``time.sleep`` after each batch is
        stored. Defaults to 0.

    """
    def __init__(self, delay_time=0):
        self.delay_time = delay_time
        self._initialized = False

    def initialize(self):
        """Mark the algorithm as initialized, asserting it was not before."""
        was_initialized = self._initialized
        assert not was_initialized
        self._initialized = True

    def process_batch(self, batch):
        """Store `batch` as ``self.batch``, then sleep ``delay_time``."""
        self.batch = batch
        pause = self.delay_time
        time.sleep(pause)
||
118 | |||
119 | |||
class MockMainLoop(MainLoop):
    """Mock main loop with mock algorithm and simple data stream.

    Can be used with `main_loop = MagicMock(wraps=MockMainLoop())` to check
    which calls were made.

    Parameters
    ----------
    delay_time : float, optional
        Forwarded to the default ``MockAlgorithm``. It is ignored when an
        `algorithm` keyword argument is given.
    **kwargs
        Passed to ``MainLoop``. Unless supplied, `data_stream` defaults to
        ``IterableDataset(range(10)).get_example_stream()`` and `algorithm`
        defaults to ``MockAlgorithm(delay_time)``.

    """
    def __init__(self, delay_time=0, **kwargs):
        # Build defaults only when they are actually needed.
        # `kwargs.setdefault(key, expr)` evaluates `expr` eagerly, creating
        # (and discarding) a stream and an algorithm even when the caller
        # provides their own.
        if 'data_stream' not in kwargs:
            kwargs['data_stream'] = (
                IterableDataset(range(10)).get_example_stream())
        if 'algorithm' not in kwargs:
            kwargs['algorithm'] = MockAlgorithm(delay_time)
        super(MockMainLoop, self).__init__(**kwargs)
||
132 |
It is generally discouraged to redefine built-ins as this makes code very hard to read.