Passed
Pull Request — master (#656)
by Fernando
01:13
created

TestRescaleIntensity.test_empty_mask()   A

Complexity

Conditions 2

Size

Total Lines 6
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 6
dl 0
loc 6
rs 10
c 0
b 0
f 0
cc 2
nop 1
1
import copy
2
import torch
3
import torchio as tio
4
import numpy as np
5
from ...utils import TorchioTestCase
6
7
8
class TestRescaleIntensity(TorchioTestCase):
    """Tests for :class:`tio.RescaleIntensity` class."""

    def test_rescale_to_same_intentisy(self):
        """Rescaling to the image's own (min, max) should not change it."""
        min_t1 = float(self.sample_subject.t1.data.min())
        max_t1 = float(self.sample_subject.t1.data.max())
        transform = tio.RescaleIntensity(out_min_max=(min_t1, max_t1))
        transformed = transform(self.sample_subject)
        # Absolute tolerance only: allow for floating-point round-off.
        assert np.allclose(
            transformed.t1.data,
            self.sample_subject.t1.data,
            rtol=0,
            atol=1e-05,
        )

    def test_min_max(self):
        """Output extrema should match the requested ``out_min_max``."""
        transform = tio.RescaleIntensity(out_min_max=(0, 1))
        transformed = transform(self.sample_subject)
        self.assertEqual(transformed.t1.data.min(), 0)
        self.assertEqual(transformed.t1.data.max(), 1)

    def test_percentiles(self):
        """Voxels beyond the 5th/95th percentiles should saturate at 0/1."""
        low_quantile = np.percentile(self.sample_subject.t1.data, 5)
        high_quantile = np.percentile(self.sample_subject.t1.data, 95)
        # Record where the input lies strictly outside the percentile range.
        low_indices = (self.sample_subject.t1.data < low_quantile).nonzero(
            as_tuple=True)
        high_indices = (self.sample_subject.t1.data > high_quantile).nonzero(
            as_tuple=True)
        rescale = tio.RescaleIntensity(out_min_max=(0, 1), percentiles=(5, 95))
        transformed = rescale(self.sample_subject)
        assert (transformed.t1.data[low_indices] == 0).all()
        assert (transformed.t1.data[high_indices] == 1).all()

    def test_masking_using_label(self):
        """Percentiles taken inside the label mask should drive the output."""
        transform = tio.RescaleIntensity(
            out_min_max=(0, 1), percentiles=(5, 95), masking_method='label')
        transformed = transform(self.sample_subject)
        # Reference percentiles are computed from foreground voxels only...
        mask = self.sample_subject.label.data > 0
        low_quantile = np.percentile(self.sample_subject.t1.data[mask], 5)
        high_quantile = np.percentile(self.sample_subject.t1.data[mask], 95)
        # ...but the saturation check below covers the whole image.
        low_indices = (self.sample_subject.t1.data < low_quantile).nonzero(
            as_tuple=True)
        high_indices = (self.sample_subject.t1.data > high_quantile).nonzero(
            as_tuple=True)
        self.assertEqual(transformed.t1.data.min(), 0)
        self.assertEqual(transformed.t1.data.max(), 1)
        assert (transformed.t1.data[low_indices] == 0).all()
        assert (transformed.t1.data[high_indices] == 1).all()

    def test_ct(self):
        """With ``in_min_max`` set, out-of-range inputs should exceed
        ``out_min_max``."""
        ct_max = 1500
        ct_min = -2000
        ct_range = ct_max - ct_min
        # Random intensities spanning [ct_min, ct_max).
        tensor = torch.rand(1, 30, 30, 30) * ct_range + ct_min
        ct = tio.ScalarImage(tensor=tensor)
        # The input window is narrower than the range of the generated data.
        ct_air = -1000
        ct_bone = 1000
        rescale = tio.RescaleIntensity(
            out_min_max=(-1, 1),
            in_min_max=(ct_air, ct_bone),
        )
        rescaled = rescale(ct)
        assert rescaled.data.min() < -1
        assert rescaled.data.max() > 1

    def test_out_min_higher_than_out_max(self):
        """A descending ``out_min_max`` pair should raise ``ValueError``."""
        with self.assertRaises(ValueError):
            tio.RescaleIntensity(out_min_max=(1, 0))

    def test_too_many_values_for_out_min_max(self):
        """Three values for ``out_min_max`` should raise ``ValueError``."""
        with self.assertRaises(ValueError):
            tio.RescaleIntensity(out_min_max=(1, 2, 3))

    def test_wrong_out_min_max_type(self):
        """A string ``out_min_max`` should raise ``ValueError``."""
        with self.assertRaises(ValueError):
            tio.RescaleIntensity(out_min_max='wrong')

    def test_min_percentile_higher_than_max_percentile(self):
        """A descending ``percentiles`` pair should raise ``ValueError``."""
        with self.assertRaises(ValueError):
            tio.RescaleIntensity(out_min_max=(0, 1), percentiles=(1, 0))

    def test_too_many_values_for_percentiles(self):
        """Three values for ``percentiles`` should raise ``ValueError``."""
        with self.assertRaises(ValueError):
            tio.RescaleIntensity(out_min_max=(0, 1), percentiles=(1, 2, 3))

    def test_wrong_percentiles_type(self):
        """A string ``percentiles`` should raise ``ValueError``."""
        with self.assertRaises(ValueError):
            tio.RescaleIntensity(out_min_max=(0, 1), percentiles='wrong')

    def test_empty_mask(self):
        """An all-zero label mask should trigger a ``RuntimeWarning``."""
        # Work on a copy so the shared fixture keeps its original label.
        subject = copy.deepcopy(self.sample_subject)
        subject.label.set_data(subject.label.data * 0)
        rescale = tio.RescaleIntensity(masking_method='label')
        with self.assertWarns(RuntimeWarning):
            rescale(subject)
103