|
1
|
|
|
from .weighted import WeightedSampler |
|
2
|
|
|
|
|
3
|
|
|
|
|
4
|
|
|
class LabelSampler(WeightedSampler):
    r"""Extract random patches with labeled voxels at their center.

    This sampler yields patches whose center value is greater than 0.5
    in the :py:attr:`label_name` image (i.e. any positive label in an
    integer-valued label map).

    Args:
        patch_size: See :py:class:`~torchio.data.PatchSampler`.
        label_name: Name of the label image in the sample that will be used to
            generate the sampling probability map.

    Example:
        >>> import torchio
        >>> subject = torchio.datasets.Colin27()
        >>> subject
        Colin27(Keys: ('t1', 'head', 'brain'); images: 3)
        >>> sample = torchio.ImagesDataset([subject])[0]
        >>> sampler = torchio.data.LabelSampler(64, 'brain')
        >>> generator = sampler(sample)
        >>> for patch in generator:
        ...     print(patch.shape)

    If you want a specific number of patches from a volume, e.g. 10:

        >>> generator = sampler(sample, num_patches=10)
        >>> for patch in generator:
        ...     print(patch.shape)

    """
    def __init__(self, patch_size, label_name):
        # The label image name is forwarded as the probability map of the
        # parent sampler; get_probability_map() below looks it up through
        # self.probability_map_name (presumably set by the parent
        # constructor -- confirm against WeightedSampler).
        super().__init__(patch_size, probability_map=label_name)

    def get_probability_map(self, sample):
        """Return binarized image for sampling.

        Args:
            sample: Subject sample expected to contain the image named
                ``self.probability_map_name``.

        Returns:
            The result of comparing the label image ``data`` against 0.5,
            so that only values above the threshold are candidates for
            patch centers.

        Raises:
            KeyError: If the label image is not present in ``sample``.
        """
        name = self.probability_map_name
        if name not in sample:
            message = (
                f'Image "{name}"'
                f' not found in subject sample: {sample}'
            )
            raise KeyError(message)
        # Threshold at 0.5 rather than 0 so that only values strictly above
        # one half count as foreground.
        return sample[name].data > 0.5
|
47
|
|
|
|