|
1
|
|
|
import torch |
|
2
|
|
|
|
|
3
|
|
|
from ....data import LabelMap |
|
4
|
|
|
from ...transform import Transform, TypeMaskingMethod |
|
5
|
|
|
from .remap_labels import RemapLabels |
|
6
|
|
|
|
|
7
|
|
|
|
|
8
|
|
|
class SequentialLabels(Transform):
    r"""Remap the integer IDs of labels in a LabelMap to be sequential.

    For example, if a label map has 6 labels with IDs (3, 5, 9, 15, 16, 23),
    then this will apply a :class:`~torchio.RemapLabels` transform with
    ``remapping={3: 1, 5: 2, 9: 3, 15: 4, 16: 5, 23: 6}``.
    This transformation is always `fully invertible <invertibility>`_.

    The value 0 is never remapped and never used as a target ID, whether or
    not it is present in the label map. All other values found in the image
    are assigned the IDs 1, 2, 3, ... in ascending order of their original
    value.

    Args:
        masking_method: See
            :class:`~torchio.RemapLabels`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.
    """
    def __init__(
            self,
            masking_method: TypeMaskingMethod = None,
            **kwargs
            ):
        super().__init__(**kwargs)
        self.masking_method = masking_method
        # NOTE(review): presumably the base class reads ``args_names``;
        # it is intentionally left empty here -- confirm against Transform.
        self.args_names = []

    def apply_transform(self, subject):
        """Apply a sequential :class:`~torchio.RemapLabels` to each label map.

        Images that are not instances of ``LabelMap`` are skipped. For every
        label map, a remapping is computed from the values currently present
        in its data and applied through a ``RemapLabels`` transform that is
        restricted to that image.

        Args:
            subject: Subject whose label maps will be remapped.

        Returns:
            The subject returned by the last applied ``RemapLabels``
            transform, or the input subject if no label map was found.
        """
        images_dict = subject.get_images_dict(
            intensity_only=False,
            include=self.include,
            exclude=self.exclude,
        )
        for name, image in images_dict.items():
            if not isinstance(image, LabelMap):
                continue

            # torch.unique returns the values sorted in ascending order.
            # Zero is filtered out explicitly instead of assuming it sits at
            # index 0: when a label map has no zero-valued voxels, skipping
            # the first entry would leave the smallest label unmapped and
            # let a larger label be remapped onto it, merging two labels.
            labels = [
                label
                for label in torch.unique(image.data).tolist()
                if label != 0
            ]
            remapping = {
                label: new_id
                for new_id, label in enumerate(labels, start=1)
            }
            transform = RemapLabels(
                remapping=remapping,
                masking_method=self.masking_method,
                # A list, not a bare string: ``include`` is a collection of
                # image names, and a string would be matched by substring.
                include=[name],
            )
            subject = transform(subject)

        return subject
|
54
|
|
|
|