Passed
Push — master (0d2a88...eb3c35)
by Fernando
01:36
created

tests.test_utils.TestUtils.test_to_tuple()   A

Complexity

Conditions 1

Size

Total Lines 7
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 7
nop 1
dl 0
loc 7
rs 10
c 0
b 0
f 0
1
#!/usr/bin/env python
2
3
"""Tests for `utils` package."""
4
5
import unittest
6
import torch
7
import numpy as np
8
import SimpleITK as sitk
9
from torchio import LABEL, INTENSITY, RandomFlip
10
from torchio.utils import (
11
    to_tuple,
12
    get_stem,
13
    guess_type,
14
    nib_to_sitk,
15
    sitk_to_nib,
16
    apply_transform_to_file,
17
)
18
from .utils import TorchioTestCase
19
20
21
class TestUtils(TorchioTestCase):
    """Tests for `utils` module."""

    def test_to_tuple(self):
        """Check conversion of scalars and sequences by ``to_tuple``.

        Scalars are wrapped in a tuple, repeated ``length`` times when
        ``length`` is given. Tuples and lists come back as tuples with
        their original elements; ``length`` does not pad them.
        """
        assert to_tuple(1) == (1,)
        assert to_tuple((1,)) == (1,)
        assert to_tuple(1, length=3) == (1, 1, 1)
        assert to_tuple((1, 2)) == (1, 2)
        assert to_tuple((1, 2), length=3) == (1, 2)
        assert to_tuple([1, 2], length=3) == (1, 2)

    def test_get_stem(self):
        """Check that ``get_stem`` drops the directory and the extension(s).

        Double extensions such as ``.nii.gz`` are removed entirely.
        """
        assert get_stem('/home/image.nii.gz') == 'image'
        assert get_stem('/home/image.nii') == 'image'
        assert get_stem('/home/image.nrrd') == 'image'

    def test_guess_type(self):
        """Check that ``guess_type`` parses strings into Python values.

        Covers None, int, float, tuple (with and without spaces), list,
        and a plain string.
        """
        assert guess_type('None') is None
        assert isinstance(guess_type('1'), int)
        assert isinstance(guess_type('1.5'), float)
        assert isinstance(guess_type('(1, 3, 5)'), tuple)
        assert isinstance(guess_type('(1,3,5)'), tuple)
        assert isinstance(guess_type('[1,3,5]'), list)
        assert isinstance(guess_type('test'), str)

    def test_check_consistent_shape(self):
        """A consistent sample passes; an inconsistent one raises ValueError."""
        good_sample = self.sample
        bad_sample = self.get_inconsistent_sample()
        good_sample.check_consistent_shape()
        with self.assertRaises(ValueError):
            bad_sample.check_consistent_shape()

    def test_apply_transform_to_file(self):
        """Smoke test for ``apply_transform_to_file`` with ``RandomFlip``.

        Only checks that the call completes without raising.
        """
        # NOTE(review): nothing is asserted about the output file; consider
        # checking that it was written once the path semantics of
        # get_image_path('output') are confirmed.
        transform = RandomFlip()
        apply_transform_to_file(
            self.get_image_path('input'),
            transform,
            self.get_image_path('output'),
            verbose=True,
        )

    def test_sitk_to_nib(self):
        """The data returned by ``sitk_to_nib`` sums to the source array's sum."""
        data = np.random.rand(10, 10)
        image = sitk.GetImageFromArray(data)
        # Only the voxel data is checked; the affine returned alongside it
        # is deliberately ignored.
        tensor, _ = sitk_to_nib(image)
        self.assertAlmostEqual(data.sum(), tensor.sum())
68
69
class TestNibabelToSimpleITK(TorchioTestCase):
    """Tests for ``nib_to_sitk``.

    Arrays that do not have four axes are rejected with ValueError. For
    four-axis arrays, the size of the first axis becomes the number of
    components per pixel. A singleton second axis yields a 2D image unless
    ``force_3d=True`` is passed.
    """

    def setUp(self):
        super().setUp()
        self.affine = np.eye(4)

    def _check_image(self, image, dimension, size, components):
        # Shared assertions on the geometry and channel count of an image.
        assert image.GetDimension() == dimension
        assert image.GetSize() == size
        assert image.GetNumberOfComponentsPerPixel() == components

    def test_wrong_dims(self):
        """A 2D array is rejected with ValueError."""
        with self.assertRaises(ValueError):
            nib_to_sitk(np.random.rand(10, 10), self.affine)

    def test_2d_single(self):
        """One channel, singleton second axis -> 2D single-component image."""
        array = np.random.rand(1, 1, 10, 12)
        converted = nib_to_sitk(array, self.affine)
        self._check_image(converted, 2, (10, 12), 1)

    def test_2d_multi(self):
        """Five channels, singleton second axis -> 2D five-component image."""
        array = np.random.rand(5, 1, 10, 12)
        converted = nib_to_sitk(array, self.affine)
        self._check_image(converted, 2, (10, 12), 5)

    def test_2d_3d_single(self):
        """With force_3d, the singleton axis is kept in a 3D image."""
        array = np.random.rand(1, 1, 10, 12)
        converted = nib_to_sitk(array, self.affine, force_3d=True)
        self._check_image(converted, 3, (1, 10, 12), 1)

    def test_2d_3d_multi(self):
        """With force_3d, a multi-channel 2D input becomes a 3D image."""
        array = np.random.rand(5, 1, 10, 12)
        converted = nib_to_sitk(array, self.affine, force_3d=True)
        self._check_image(converted, 3, (1, 10, 12), 5)

    def test_3d_single(self):
        """One channel, non-singleton second axis -> 3D single-component image."""
        array = np.random.rand(1, 8, 10, 12)
        converted = nib_to_sitk(array, self.affine)
        self._check_image(converted, 3, (8, 10, 12), 1)

    def test_3d_multi(self):
        """Five channels, non-singleton second axis -> 3D five-component image."""
        array = np.random.rand(5, 8, 10, 12)
        converted = nib_to_sitk(array, self.affine)
        self._check_image(converted, 3, (8, 10, 12), 5)