|
1
|
|
|
import torch |
|
2
|
|
|
import torchio |
|
3
|
|
|
from torchio.data import LabelSampler |
|
4
|
|
|
from ...utils import TorchioTestCase |
|
5
|
|
|
|
|
6
|
|
|
|
|
7
|
|
|
class TestLabelSampler(TorchioTestCase):
    """Tests for `LabelSampler` class."""

    def test_label_sampler(self):
        """Every sampled patch must be centred on a voxel labelled 1.

        ``self.sample`` is presumably built by ``TorchioTestCase`` setup and
        contains a ``'label'`` image (assumption: verify against the base
        class).
        """
        patch_size = 5
        # Derive the centre from the patch size instead of hard-coding 2, so
        # the two values cannot drift apart.
        center = patch_size // 2
        sampler = LabelSampler(patch_size)
        for patch in sampler(self.sample, num_patches=10):
            # Index 0 is assumed to be the channel dimension, followed by
            # three spatial dimensions (TODO confirm against torchio.Image).
            patch_center = patch['label'][torchio.DATA][
                0, center, center, center
            ]
            self.assertEqual(patch_center, 1)

    def test_label_probabilities(self):
        """The probability map must follow the requested label weights."""
        labels = torch.Tensor((0, 0, 1, 1, 2, 1, 0)).reshape(1, 1, -1)
        subject = torchio.Subject(
            label=torchio.Image(tensor=labels, type=torchio.LABEL),
        )
        sample = torchio.ImagesDataset([subject])[0]
        probs_dict = {0: 0, 1: 50, 2: 25, 3: 25}
        sampler = LabelSampler(5, 'label', label_probabilities=probs_dict)
        probabilities = sampler.get_probability_map(sample)
        # The fixture encodes each label's share of the total weight (100)
        # spread evenly over that label's voxels:
        #   label 0 -> weight 0            -> 0
        #   label 1 -> 50/100 over 3 voxels -> 2/12 each
        #   label 2 -> 25/100 over 1 voxel  -> 3/12
        # Label 3 does not occur in the image, so no voxel receives its share.
        # Build the expected tensor in the map's dtype: torch.allclose
        # rejects mixed dtypes.
        expected = torch.tensor(
            (0, 0, 2 / 12, 2 / 12, 3 / 12, 2 / 12, 0),
            dtype=probabilities.dtype,
        )
        # Compare with a tolerance; exact equality on floats such as 2/12 is
        # fragile across dtypes and platforms.
        self.assertTrue(torch.allclose(probabilities.squeeze(), expected))
|
27
|
|
|
|