Issues (119)

blocks/utils/testing.py (1 issue)

1
import logging
2
import os
3
import sys
4
import time
5
from six import wraps
6
from importlib import import_module
7
from unittest.case import SkipTest
8
9
from six import StringIO
10
11
import blocks
12
from blocks.algorithms import TrainingAlgorithm
13
from blocks.config import config
14
from blocks.main_loop import MainLoop
15
from fuel.datasets import IterableDataset
16
17
18
def silence_printing(test):
    """Decorate a test so it runs with printing and logging suppressed.

    While the wrapped test runs, ``sys.stdout`` is redirected to an
    in-memory buffer and the level of the ``blocks`` package logger is
    raised to ``logging.ERROR``. Both are restored afterwards, even if the
    test raises.

    Parameters
    ----------
    test : callable
        The test function to wrap.

    Returns
    -------
    callable
        A wrapper taking the same arguments as `test` and returning
        whatever `test` returns.

    """
    @wraps(test)
    def wrapper(*args, **kwargs):
        # Look the logger up before touching stdout so nothing can fail
        # between swapping stdout and entering the ``try`` below.
        logger = logging.getLogger(blocks.__name__)
        old_level = logger.level
        stdout = sys.stdout
        sys.stdout = StringIO()
        logger.setLevel(logging.ERROR)
        try:
            # Propagate the result so the decorator is transparent to
            # callers that use the test's return value.
            return test(*args, **kwargs)
        finally:
            sys.stdout = stdout
            logger.setLevel(old_level)
    return wrapper
32
33
34
def skip_if_not_available(modules=None, datasets=None, configurations=None):
    """Raises a SkipTest exception when requirements are not met.

    Parameters
    ----------
    modules : list
        A list of strings of module names. If one of the modules fails to
        import, the test will be skipped. For ``'bokeh'`` the test is also
        skipped when a ``get`` request to the Bokeh session's base URL
        raises ``requests.exceptions.ConnectionError``.
    datasets : list
        A list of strings of folder names. If the data path is not
        configured, or the folder does not exist, the test is skipped.
    configurations : list
        A list of strings of configuration names. If this configuration
        is not set and does not have a default, the test will be skipped.

    Raises
    ------
    SkipTest
        If any of the requirements above is not met. The exception carries
        a short message naming the missing requirement.

    """
    if modules is None:
        modules = []
    if datasets is None:
        datasets = []
    if configurations is None:
        configurations = []
    for module in modules:
        try:
            import_module(module)
        except Exception:
            # Deliberately broad: a module that fails to import for any
            # reason (not only ImportError) makes the test unrunnable.
            raise SkipTest('module {} is not available'.format(module))
        if module == 'bokeh':
            # Bound to a distinct name so the builtin ConnectionError is
            # not shadowed inside this function.
            RequestsConnectionError = import_module(
                'requests.exceptions').ConnectionError
            session = import_module('bokeh.session').Session()
            try:
                session.execute('get', session.base_url)
            except RequestsConnectionError:
                raise SkipTest('Bokeh server is not reachable')

    # ``datasets`` is checked first so ``config`` is only consulted when
    # datasets were actually requested.
    if datasets and not hasattr(config, 'data_path'):
        raise SkipTest('data_path is not configured')
    for dataset in datasets:
        if not os.path.exists(os.path.join(config.data_path, dataset)):
            raise SkipTest('dataset {} was not found'.format(dataset))
    for configuration in configurations:
        if not hasattr(config, configuration):
            raise SkipTest(
                'configuration {} is not set'.format(configuration))
78
79
80
def skip_if_configuration_set(configuration, value, message=None):
    """Skip the current test when a configuration option equals `value`.

    Parameters
    ----------
    configuration : str
        Name of the attribute of `blocks.config` to inspect.
    value : str
        If ``blocks.config.<configuration>`` compares equal to this value,
        a `SkipTest` is raised.
    message : str, optional
        Reason attached to the `SkipTest`. When omitted, the exception is
        raised without arguments.

    Raises
    ------
    SkipTest
        When the configuration option equals `value`.

    """
    if getattr(config, configuration) != value:
        return
    # Only forward the message when one was supplied, so an omitted
    # message yields an argument-less SkipTest.
    skip_args = () if message is None else (message,)
    raise SkipTest(*skip_args)
99
100
101
class MockAlgorithm(TrainingAlgorithm):
    """Training algorithm stub that only records the batches it receives.

    `initialize` checks (via ``assert``) that it runs at most once per
    instance. `process_batch` stores the batch and then sleeps for
    ``delay_time`` seconds to simulate work.

    Parameters
    ----------
    delay_time : int or float, optional
        Seconds passed to ``time.sleep`` on each `process_batch` call.
        Defaults to 0.

    """
    def __init__(self, delay_time=0):
        self.delay_time = delay_time
        self._initialized = False

    def initialize(self):
        """Mark the algorithm as initialized; assert it was not already."""
        already_initialized = self._initialized
        assert not already_initialized
        self._initialized = True

    def process_batch(self, batch):
        """Keep `batch` on ``self.batch``, then sleep ``delay_time``."""
        self.batch = batch
        time.sleep(self.delay_time)
118
119
120
class MockMainLoop(MainLoop):
    """Main loop preconfigured with a mock algorithm and a tiny data stream.

    Can be used with `main_loop = MagicMock(wraps=MockMainLoop())` to check
    which calls were made.

    Parameters
    ----------
    delay_time : int or float, optional
        Forwarded to the default `MockAlgorithm`. It has no effect when an
        ``algorithm`` keyword argument is supplied.
    **kwargs
        Passed on to `MainLoop`. ``data_stream`` defaults to an example
        stream over ``range(10)`` and ``algorithm`` defaults to
        ``MockAlgorithm(delay_time)``; explicit values take precedence.

    """
    def __init__(self, delay_time=0, **kwargs):
        # Both defaults are built eagerly, in this order, exactly as before;
        # setdefault keeps any value the caller passed explicitly.
        fallbacks = {
            'data_stream': IterableDataset(range(10)).get_example_stream(),
            'algorithm': MockAlgorithm(delay_time),
        }
        for name, fallback in fallbacks.items():
            kwargs.setdefault(name, fallback)
        super(MockMainLoop, self).__init__(**kwargs)
132