1
|
|
|
import torch |
2
|
|
|
import torchio |
3
|
|
|
from torchio.data import WeightedSampler |
4
|
|
|
from ...utils import TorchioTestCase |
5
|
|
|
|
6
|
|
|
|
7
|
|
|
class TestWeightedSampler(TorchioTestCase):
    """Tests for `WeightedSampler` class."""

    def test_weighted_sampler(self):
        """Sample one patch from a map with a single non-zero voxel.

        The probability map built by ``get_sample`` is zero everywhere
        except at voxel (3, 3, 3) of a 7x7x7 volume. The test expects the
        patch of size 5 drawn from it to start at index (1, 1, 1).
        """
        subject = self.get_sample((1, 7, 7, 7))
        sampler = WeightedSampler(5, 'prob')
        patch = next(iter(sampler(subject)))
        self.assertEqual(tuple(patch['index_ini']), (1, 1, 1))

    def get_sample(self, image_shape):
        """Build a subject with a random image and a one-hot map.

        Args:
            image_shape: Shape of both tensors. It must be 4D and large
                enough for index ``[0, 3, 3, 3]`` to be valid, because
                that position is hard-coded below.

        Returns:
            The first (and only) item of a ``SubjectsDataset`` wrapping
            the subject, whose images are ``t1`` (random values) and
            ``prob`` (all zeros except a single 1).
        """
        t1 = torch.rand(*image_shape)
        prob = torch.zeros_like(t1)
        # Single non-zero entry of the probability map.
        prob[0, 3, 3, 3] = 1
        subject = torchio.Subject(
            t1=torchio.ScalarImage(tensor=t1),
            prob=torchio.ScalarImage(tensor=prob),
        )
        # NOTE(review): presumably fetched through the dataset so the
        # subject is prepared the same way as in real use -- confirm.
        subject = torchio.SubjectsDataset([subject])[0]
        return subject

    def test_inconsistent_shape(self):
        """Check that sampling works when image shapes differ.

        The two images have the same last three dimensions (4, 5, 6) but
        differ in the first one (1 vs. 2). This is a smoke test: it makes
        no assertion and passes if drawing one patch does not raise.
        """
        # https://github.com/fepegar/torchio/issues/234#issuecomment-675029767
        subject = torchio.Subject(
            im1=torchio.ScalarImage(tensor=torch.rand(1, 4, 5, 6)),
            im2=torchio.ScalarImage(tensor=torch.rand(2, 4, 5, 6)),
        )
        patch_size = 2
        sampler = WeightedSampler(patch_size, 'im1')
        # Same access pattern as in test_weighted_sampler.
        next(iter(sampler(subject)))
36
|
|
|
|