1
|
|
|
import copy |
2
|
|
|
import torch |
3
|
|
|
import torchio as tio |
4
|
|
|
import numpy as np |
5
|
|
|
from ...utils import TorchioTestCase |
6
|
|
|
|
7
|
|
|
|
8
|
|
|
class TestRescaleIntensity(TorchioTestCase):
    """Unit tests for the :class:`tio.RescaleIntensity` transform."""

    @staticmethod
    def _split_by_quantiles(data, low_value, high_value):
        """Locate voxels strictly below one threshold and strictly above another.

        Args:
            data: Tensor whose voxels are compared against the thresholds.
            low_value: Voxels with intensity ``< low_value`` go into the
                first result.
            high_value: Voxels with intensity ``> high_value`` go into the
                second result.

        Returns:
            A pair ``(below, above)``. Each element is the ``as_tuple=True``
            output of ``Tensor.nonzero`` and can be used directly to index
            a tensor of the same shape.
        """
        below = (data < low_value).nonzero(as_tuple=True)
        above = (data > high_value).nonzero(as_tuple=True)
        return below, above

    def _check_rejected(self, **kwargs):
        """Assert that building ``tio.RescaleIntensity(**kwargs)`` raises ValueError."""
        with self.assertRaises(ValueError):
            tio.RescaleIntensity(**kwargs)

    def test_rescale_to_same_intentisy(self):
        # Mapping the image onto its own [min, max] range should leave the
        # intensities unchanged, up to a small absolute tolerance.
        t1_before = self.sample_subject.t1.data
        own_range = (float(t1_before.min()), float(t1_before.max()))
        rescale = tio.RescaleIntensity(out_min_max=own_range)
        transformed = rescale(self.sample_subject)
        assert np.allclose(
            transformed.t1.data,
            self.sample_subject.t1.data,
            rtol=0,
            atol=1e-05,
        )

    def test_min_max(self):
        rescale = tio.RescaleIntensity(out_min_max=(0, 1))
        data = rescale(self.sample_subject).t1.data
        self.assertEqual(data.min(), 0)
        self.assertEqual(data.max(), 1)

    def test_percentiles(self):
        # Voxels outside the 5th-95th percentile band of the input should be
        # pinned to the output bounds after rescaling.
        t1 = self.sample_subject.t1.data
        low_cut, high_cut = (np.percentile(t1, q) for q in (5, 95))
        below, above = self._split_by_quantiles(t1, low_cut, high_cut)
        rescale = tio.RescaleIntensity(out_min_max=(0, 1), percentiles=(5, 95))
        data = rescale(self.sample_subject).t1.data
        assert (data[below] == 0).all()
        assert (data[above] == 1).all()

    def test_masking_using_label(self):
        rescale = tio.RescaleIntensity(
            out_min_max=(0, 1),
            percentiles=(5, 95),
            masking_method='label',
        )
        data = rescale(self.sample_subject).t1.data
        # The reference percentiles come only from voxels where the label is
        # positive, matching the 'label' masking method.
        t1 = self.sample_subject.t1.data
        masked_voxels = t1[self.sample_subject.label.data > 0]
        low_cut, high_cut = (np.percentile(masked_voxels, q) for q in (5, 95))
        below, above = self._split_by_quantiles(t1, low_cut, high_cut)
        self.assertEqual(data.min(), 0)
        self.assertEqual(data.max(), 1)
        assert (data[below] == 0).all()
        assert (data[above] == 1).all()

    def test_ct(self):
        # Random intensities spread uniformly over [-2000, 1500], which is
        # wider than the input window given to the transform below.
        lowest, highest = -2000, 1500
        values = torch.rand(1, 30, 30, 30) * (highest - lowest) + lowest
        image = tio.ScalarImage(tensor=values)
        air, bone = -1000, 1000
        rescale = tio.RescaleIntensity(
            out_min_max=(-1, 1),
            in_min_max=(air, bone),
        )
        output = rescale(image).data
        # Intensities outside in_min_max must land outside out_min_max, i.e.
        # the transform is expected not to clip them.
        assert output.min() < -1
        assert output.max() > 1

    def test_out_min_higher_than_out_max(self):
        self._check_rejected(out_min_max=(1, 0))

    def test_too_many_values_for_out_min_max(self):
        self._check_rejected(out_min_max=(1, 2, 3))

    def test_wrong_out_min_max_type(self):
        self._check_rejected(out_min_max='wrong')

    def test_min_percentile_higher_than_max_percentile(self):
        self._check_rejected(out_min_max=(0, 1), percentiles=(1, 0))

    def test_too_many_values_for_percentiles(self):
        self._check_rejected(out_min_max=(0, 1), percentiles=(1, 2, 3))

    def test_wrong_percentiles_type(self):
        self._check_rejected(out_min_max=(0, 1), percentiles='wrong')

    def test_empty_mask(self):
        # Zeroing the label map leaves the mask empty; the transform should
        # warn rather than fail.
        subject = copy.deepcopy(self.sample_subject)
        label = subject.label
        label.set_data(label.data * 0)
        rescale = tio.RescaleIntensity(masking_method='label')
        with self.assertWarns(RuntimeWarning):
            rescale(subject)
103
|
|
|
|