Completed
Push — master ( f5ee62...ed481e )
by Dmitry
01:25
created

blocks.utils.skip_if_configuration_set()   A

Complexity

Conditions 3

Size

Total Lines 19

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 3
dl 0
loc 19
rs 9.4285
1
import logging
2
import os
3
import sys
4
from functools import wraps
5
from importlib import import_module
6
from unittest.case import SkipTest
7
8
from six import StringIO
9
10
import blocks
11
from blocks.algorithms import TrainingAlgorithm
12
from blocks.config import config
13
from blocks.main_loop import MainLoop
14
from fuel.datasets import IterableDataset
15
16
17
def silence_printing(test):
    """Wrap `test` so it runs with stdout and Blocks logging silenced.

    While the wrapped callable runs, ``sys.stdout`` is replaced by an
    in-memory ``StringIO`` buffer and the level of the ``blocks`` logger
    is set to ``logging.ERROR``. Both are restored afterwards, also when
    the callable raises.

    Parameters
    ----------
    test : callable
        The function to wrap (typically a test function).

    Returns
    -------
    callable
        A wrapper carrying `test`'s metadata (via ``functools.wraps``) that
        forwards all arguments and returns whatever `test` returns.

    """
    @wraps(test)
    def wrapper(*args, **kwargs):
        logger = logging.getLogger(blocks.__name__)
        old_level = logger.level
        stdout = sys.stdout
        logger.setLevel(logging.ERROR)
        # Swap stdout last, right before ``try``, so nothing can fail
        # between the swap and the ``finally`` that undoes it.
        sys.stdout = StringIO()
        try:
            # Propagate the result instead of silently returning None.
            return test(*args, **kwargs)
        finally:
            sys.stdout = stdout
            logger.setLevel(old_level)
    return wrapper
31
32
33
def skip_if_not_available(modules=None, datasets=None, configurations=None):
    """Raises a SkipTest exception when requirements are not met.

    Parameters
    ----------
    modules : list
        A list of strings of module names. If one of the modules fails to
        import, the test will be skipped. For ``'bokeh'`` the test is also
        skipped when requesting the default Bokeh session's ``base_url``
        raises a ``requests`` connection error.
    datasets : list
        A list of strings of folder names. If the data path is not
        configured, or the folder does not exist, the test is skipped.
    configurations : list
        A list of strings of configuration names. If this configuration
        is not set and does not have a default, the test will be skipped.

    Raises
    ------
    SkipTest
        If any requirement is not met; the message names the requirement.

    """
    if modules is None:
        modules = []
    if datasets is None:
        datasets = []
    if configurations is None:
        configurations = []
    for module in modules:
        try:
            import_module(module)
        except Exception:
            # Deliberately broad: a module that fails to import for any
            # reason (not only ImportError) is treated as unavailable.
            raise SkipTest("module '{0}' is not available".format(module))
        if module == 'bokeh':
            # Named so it does not shadow the builtin ``ConnectionError``.
            connection_error = import_module(
                'requests.exceptions').ConnectionError
            # NOTE(review): relies on the legacy ``bokeh.session.Session``
            # API; confirm against the bokeh version in use.
            session = import_module('bokeh.session').Session()
            try:
                session.execute('get', session.base_url)
            except connection_error:
                raise SkipTest("could not connect to the Bokeh server at "
                               "'{0}'".format(session.base_url))

    if datasets and not hasattr(config, 'data_path'):
        raise SkipTest("'data_path' is not configured")
    for dataset in datasets:
        if not os.path.exists(os.path.join(config.data_path, dataset)):
            raise SkipTest("dataset folder '{0}' does not exist".format(
                dataset))
    for configuration in configurations:
        if not hasattr(config, configuration):
            raise SkipTest("configuration '{0}' is not set".format(
                configuration))
77
78
79
def skip_if_configuration_set(configuration, value, message=None):
    """Skip the current test when a configuration option equals `value`.

    Parameters
    ----------
    configuration : str
        Name of the attribute of ``blocks.config`` to inspect.
    value : str
        When ``getattr(config, configuration)`` compares equal to this,
        a `SkipTest` is raised.
    message : str, optional
        Reason for skipping. When omitted, `SkipTest` is raised without
        any arguments.

    Raises
    ------
    SkipTest
        When the option's current value equals `value`.

    """
    current = getattr(config, configuration)
    if not (current == value):
        return
    if message is None:
        raise SkipTest
    raise SkipTest(message)
98
99
100
class MockAlgorithm(TrainingAlgorithm):
    """Training-algorithm stub that merely records the last batch it sees.

    `initialize` asserts that it is invoked at most once per instance.

    """
    def __init__(self):
        # Set by `initialize`; used to detect a second initialization.
        self._initialized = False

    def initialize(self):
        already_initialized = self._initialized
        assert not already_initialized
        self._initialized = True

    def process_batch(self, batch):
        # Keep the batch around so tests can inspect what was fed in.
        self.batch = batch
115
116
117
class MockMainLoop(MainLoop):
    """Mock main loop with mock algorithm and simple data stream.

    Can be used with `main_loop = MagicMock(wraps=MockMainLoop())` to check
    which calls were made.

    All keyword arguments are forwarded to `MainLoop`. If ``data_stream``
    is not given, an example stream over ``IterableDataset(range(10))`` is
    used; if ``algorithm`` is not given, a new `MockAlgorithm` is used.

    """
    def __init__(self, **kwargs):
        # Build the defaults only when the caller did not supply them:
        # ``dict.setdefault`` evaluates its default argument every time,
        # which would construct and throw away a stream and an algorithm.
        if 'data_stream' not in kwargs:
            kwargs['data_stream'] = IterableDataset(
                range(10)).get_example_stream()
        if 'algorithm' not in kwargs:
            kwargs['algorithm'] = MockAlgorithm()
        super(MockMainLoop, self).__init__(**kwargs)
129