Passed
Pull Request — master (#499)
by Fernando
01:18
created

TestRescaleIntensity.test_ct()   A

Complexity

Conditions 1

Size

Total Lines 15
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 14
dl 0
loc 15
rs 9.7
c 0
b 0
f 0
cc 1
nop 1
1
import torch
2
import torchio as tio
3
import numpy as np
4
from ...utils import TorchioTestCase
5
6
7
class TestRescaleIntensity(TorchioTestCase):
    """Tests for :class:`tio.RescaleIntensity` class."""

    def test_rescale_to_same_intensity(self):
        """Rescaling to the image's own min/max leaves intensities unchanged."""
        min_t1 = float(self.sample_subject.t1.data.min())
        max_t1 = float(self.sample_subject.t1.data.max())
        transform = tio.RescaleIntensity(out_min_max=(min_t1, max_t1))
        transformed = transform(self.sample_subject)
        assert np.allclose(
            transformed.t1.data,
            self.sample_subject.t1.data,
            rtol=0,
            atol=1e-05,
        )

    def test_min_max(self):
        """Output extremes match the requested ``out_min_max``."""
        transform = tio.RescaleIntensity(out_min_max=(0, 1))
        transformed = transform(self.sample_subject)
        self.assertEqual(transformed.t1.data.min(), 0)
        self.assertEqual(transformed.t1.data.max(), 1)

    def test_percentiles(self):
        """Voxels outside the 5th-95th percentile band map to the bounds."""
        low_quantile = np.percentile(self.sample_subject.t1.data, 5)
        high_quantile = np.percentile(self.sample_subject.t1.data, 95)
        low_indices = (self.sample_subject.t1.data < low_quantile).nonzero(
            as_tuple=True)
        high_indices = (self.sample_subject.t1.data > high_quantile).nonzero(
            as_tuple=True)
        rescale = tio.RescaleIntensity(out_min_max=(0, 1), percentiles=(5, 95))
        transformed = rescale(self.sample_subject)
        assert (transformed.t1.data[low_indices] == 0).all()
        assert (transformed.t1.data[high_indices] == 1).all()

    def test_masking_using_label(self):
        """Percentiles are computed only over voxels where ``label > 0``."""
        transform = tio.RescaleIntensity(
            out_min_max=(0, 1), percentiles=(5, 95), masking_method='label')
        transformed = transform(self.sample_subject)
        mask = self.sample_subject.label.data > 0
        low_quantile = np.percentile(self.sample_subject.t1.data[mask], 5)
        high_quantile = np.percentile(self.sample_subject.t1.data[mask], 95)
        low_indices = (self.sample_subject.t1.data < low_quantile).nonzero(
            as_tuple=True)
        high_indices = (self.sample_subject.t1.data > high_quantile).nonzero(
            as_tuple=True)
        self.assertEqual(transformed.t1.data.min(), 0)
        self.assertEqual(transformed.t1.data.max(), 1)
        assert (transformed.t1.data[low_indices] == 0).all()
        assert (transformed.t1.data[high_indices] == 1).all()

    def test_ct(self):
        """Values outside ``in_min_max`` map outside ``out_min_max``.

        The input range (air to bone) is narrower than the data range, so
        voxels darker than air must land below -1 and voxels brighter than
        bone above 1, i.e. the output is not clipped to ``out_min_max``.
        """
        ct_max = 1500
        ct_min = -2000
        shape = 1, 30, 30, 30
        num_voxels = 30 ** 3
        # Deterministic ramp that always contains both extremes; the
        # previous torch.rand-based tensor never reached ct_max and made
        # the test input change from run to run.
        tensor = torch.linspace(ct_min, ct_max, num_voxels).reshape(shape)
        ct = tio.ScalarImage(tensor=tensor)
        ct_air = -1000
        ct_bone = 1000
        rescale = tio.RescaleIntensity(
            out_min_max=(-1, 1),
            in_min_max=(ct_air, ct_bone),
        )
        rescaled = rescale(ct)
        assert rescaled.data.min() < -1
        assert rescaled.data.max() > 1

    def test_out_min_higher_than_out_max(self):
        """An inverted ``out_min_max`` is rejected at construction."""
        with self.assertRaises(ValueError):
            tio.RescaleIntensity(out_min_max=(1, 0))

    def test_too_many_values_for_out_min_max(self):
        """``out_min_max`` with more than two values is rejected."""
        with self.assertRaises(ValueError):
            tio.RescaleIntensity(out_min_max=(1, 2, 3))

    def test_wrong_out_min_max_type(self):
        """A non-numeric ``out_min_max`` is rejected."""
        with self.assertRaises(ValueError):
            tio.RescaleIntensity(out_min_max='wrong')

    def test_min_percentile_higher_than_max_percentile(self):
        """Inverted ``percentiles`` are rejected at construction."""
        with self.assertRaises(ValueError):
            tio.RescaleIntensity(out_min_max=(0, 1), percentiles=(1, 0))

    def test_too_many_values_for_percentiles(self):
        """``percentiles`` with more than two values is rejected."""
        with self.assertRaises(ValueError):
            tio.RescaleIntensity(out_min_max=(0, 1), percentiles=(1, 2, 3))

    def test_wrong_percentiles_type(self):
        """A non-numeric ``percentiles`` argument is rejected."""
        with self.assertRaises(ValueError):
            tio.RescaleIntensity(out_min_max=(0, 1), percentiles='wrong')
95