Completed
Push — master ( 6ce243...b034e1 )
by Alexandre M.
52s
created

hansel.tests.test_parameter_grid()   Rating: F

Complexity

Conditions 17

Size

Total Lines 39

Duplication

Lines 0
Ratio 0 %
Metric                  Value
cc  (conditions)        17
dl  (duplicated lines)  0
loc (total lines)       39
rs                      2.7205

How to fix: Complexity

Complex units of code like hansel.tests.test_parameter_grid() (here a test function rather than a class) often do several different things. To break such a unit down, identify a cohesive component within it. In a class, a common way to find such a component is to look for fields or methods that share the same prefix or suffix.

Once you have determined which fields belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster.
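For a test function such as test_parameter_grid(), the rough counterpart of Extract Class is to split each scenario it checks into its own focused test and move the shared helper to module level. The sketch below shows that split under one stated assumption: because hansel may not be installed, it substitutes a hypothetical SimpleGrid class that reproduces only the behavior the original test exercises (length, iteration, indexing, the empty-grid cases). SimpleGrid and the new test names are illustrative and are not part of hansel.

```python
from collections.abc import Iterable, Sized
from itertools import product


class SimpleGrid:
    """Hypothetical stand-in for hansel.utils.ParameterGrid (illustration only).

    Accepts a dict of lists, or a list of such dicts, and enumerates every
    combination: dicts are visited in order and keys are sorted within each.
    """

    def __init__(self, param_grid):
        if isinstance(param_grid, dict):
            param_grid = [param_grid]
        self._points = []
        for grid in param_grid:
            keys = sorted(grid)
            # product() of zero iterables yields one empty tuple, so {} -> [{}].
            for combo in product(*(grid[k] for k in keys)):
                self._points.append(dict(zip(keys, combo)))

    def __len__(self):
        return len(self._points)

    def __iter__(self):
        return iter(self._points)

    def __getitem__(self, index):
        return self._points[index]


# Shared helper, moved out of the original test to module level.
def assert_grid_iter_equals_getitem(grid):
    assert list(grid) == [grid[i] for i in range(len(grid))]


def test_single_parameter_grid():
    grid = SimpleGrid({"foo": [1, 2, 3]})
    assert isinstance(grid, Iterable)
    assert isinstance(grid, Sized)
    assert len(grid) == 3
    assert_grid_iter_equals_getitem(grid)


def test_two_parameter_grid_is_reiterable():
    params = {"foo": [4, 2], "bar": ["ham", "spam", "eggs"]}
    grid = SimpleGrid(params)
    assert len(grid) == 6
    expected = set(product(params["bar"], params["foo"]))
    for _ in range(2):  # iterating twice must give the same points
        assert {(p["bar"], p["foo"]) for p in grid} == expected
    assert_grid_iter_equals_getitem(grid)


def test_empty_grid_has_one_empty_point():
    empty = SimpleGrid({})
    assert len(empty) == 1
    assert list(empty) == [{}]
    assert_grid_iter_equals_getitem(empty)
    try:
        empty[1]
    except IndexError:
        pass
    else:
        raise AssertionError("expected IndexError for empty[1]")


def test_grid_list_with_empty_dict():
    grid = SimpleGrid([{"C": [1, 10]}, {}, {"C": [0.5]}])
    assert len(grid) == 4
    assert list(grid) == [{"C": 1}, {"C": 10}, {}, {"C": 0.5}]
    assert_grid_iter_equals_getitem(grid)
```

Each scenario now passes or fails on its own, which makes a regression easier to locate. Under pytest, the try/except block is more idiomatically written as `with pytest.raises(IndexError): empty[1]`; it is spelled out here only so the sketch runs without pytest.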

# -*- coding: utf-8 -*-
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:


import pytest

# The ABCs live in collections.abc; importing them from collections
# fails on Python 3.10+.
from collections.abc import Iterable, Sized
from itertools import chain, product

from hansel.utils import remove_duplicates, ParameterGrid


@pytest.fixture(scope="module")
def values():
    return list(range(3))


def test_remove_duplicates(values):
    assert remove_duplicates(values * 10) == sorted(values)
    assert remove_duplicates(values) == sorted(values)
    assert remove_duplicates(values) == sorted(remove_duplicates(values))


def test_parameter_grid():
    # Taken from sklearn and converted to pytest.
    # Test basic properties of ParameterGrid.

    def assert_grid_iter_equals_getitem(grid):
        assert list(grid) == [grid[i] for i in range(len(grid))]

    params1 = {"foo": [1, 2, 3]}
    grid1 = ParameterGrid(params1)
    assert isinstance(grid1, Iterable)
    assert isinstance(grid1, Sized)
    assert len(grid1) == 3
    assert_grid_iter_equals_getitem(grid1)

    params2 = {"foo": [4, 2],
               "bar": ["ham", "spam", "eggs"]}
    grid2 = ParameterGrid(params2)
    assert len(grid2) == 6

    # Loop to assert we can iterate over the grid multiple times.
    for _ in range(2):
        # tuple + chain transforms {"a": 1, "b": 2} to ("a", 1, "b", 2)
        points = set(tuple(chain(*(sorted(p.items())))) for p in grid2)
        assert points == set(("bar", x, "foo", y)
                             for x, y in product(params2["bar"], params2["foo"]))

    assert_grid_iter_equals_getitem(grid2)

    # Special case: empty grid (useful to get default estimator settings)
    empty = ParameterGrid({})
    assert len(empty) == 1
    assert list(empty) == [{}]
    assert_grid_iter_equals_getitem(empty)
    with pytest.raises(IndexError):
        empty[1]

    has_empty = ParameterGrid([{'C': [1, 10]}, {}, {'C': [.5]}])
    assert len(has_empty) == 4
    assert list(has_empty) == [{'C': 1}, {'C': 10}, {}, {'C': .5}]
    assert_grid_iter_equals_getitem(has_empty)
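The comment inside the loop of test_parameter_grid() says that tuple + chain turns a point such as {"a": 1, "b": 2} into ("a", 1, "b", 2). The idiom can be checked in isolation; flatten_point is a hypothetical helper name used only for this illustration.

```python
from itertools import chain


def flatten_point(point):
    # Hypothetical helper. sorted() orders the (key, value) pairs by key,
    # chain(*...) splices the pairs into one sequence, and tuple() makes
    # the result hashable so it can be stored in a set.
    return tuple(chain(*sorted(point.items())))


print(flatten_point({"b": 2, "a": 1}))  # ('a', 1, 'b', 2)
```

Sorting first is what lets the test compare against tuples of the fixed shape ("bar", x, "foo", y), regardless of the order in which the grid builds each dict.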