|
1
|
|
|
from typing import Dict |
|
2
|
|
|
|
|
3
|
|
|
from ....data import LabelMap |
|
4
|
|
|
from ...transform import Transform, TypeMaskingMethod |
|
5
|
|
|
|
|
6
|
|
|
|
|
7
|
|
|
class RemapLabels(Transform):
    r"""Remap the integer ids of labels in a LabelMap.

    This transformation may not be invertible if two labels are combined by the
    remapping.
    A masking method can be used to correctly split the label into two during
    the `inverse transformation <invertibility>`_ (see example).

    Args:
        remapping: Dictionary that specifies how labels should be remapped.
            The keys are the old label ids, and the corresponding values replace
            them. All entries are applied simultaneously with respect to the
            original label values, so chained entries such as ``{1: 2, 2: 3}``
            send ``1`` to ``2`` and ``2`` to ``3`` (``1`` does not end up as ``3``).
        masking_method: Defines a mask for where the label remapping is applied. It can be one of:

            - ``None``: the mask image is all ones, i.e. all values in the image are used.

            - A string: key to a :class:`torchio.LabelMap` in the subject which is used as a mask,
              OR an anatomical label: ``'Left'``, ``'Right'``, ``'Anterior'``, ``'Posterior'``,
              ``'Inferior'``, ``'Superior'`` which specifies a side of the mask volume to be ones.

            - A function: the mask image is computed as a function of the data
              of the label map being remapped. The function must receive and
              return a :class:`torch.Tensor`. The returned tensor is converted
              to boolean, so any nonzero value marks a voxel to be remapped.

        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Example:
        >>> import torchio as tio
        >>> # Target label map has the following labels:
        >>> # {'left_ventricle': 1, 'right_ventricle': 2, 'left_caudate': 3, 'right_caudate': 4,
        >>> #  'left_putamen': 5, 'right_putamen': 6, 'left_thalamus': 7, 'right_thalamus': 8}
        >>> transform = tio.RemapLabels({2:1, 4:3, 6:5, 8:7})
        >>> # Merge right side labels with left side labels
        >>> transformed = transform(subject)
        >>> # Undesired behavior: The inverse transform will remap ALL left side labels to right side labels
        >>> # so the label map only has right side labels.
        >>> inverse_transformed = transformed.apply_inverse_transform()
        >>> # Here's the *right* way to do it with masking:
        >>> transform = tio.RemapLabels({2:1, 4:3, 6:5, 8:7}, masking_method="Right")
        >>> # Remap the labels on the right side only (no difference yet).
        >>> transformed = transform(subject)
        >>> # Apply the inverse on the right side only. The labels are correctly split into left/right.
        >>> inverse_transformed = transformed.apply_inverse_transform()
    """

    def __init__(
        self,
        remapping: Dict[int, int],
        masking_method: TypeMaskingMethod = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Kept so that ``inverse`` can build an equivalent transform.
        self.kwargs = kwargs
        self.remapping = remapping
        self.masking_method = masking_method
        self.args_names = ('remapping', 'masking_method',)

    def apply_transform(self, subject):
        """Remap label ids in every selected :class:`LabelMap` of ``subject``.

        Images that are not label maps are left untouched. Each label map's
        data is replaced by a remapped copy; the subject is returned.
        """
        images = subject.get_images(
            intensity_only=False,
            include=self.include,
            exclude=self.exclude,
        )
        for image in images:
            if not isinstance(image, LabelMap):
                continue

            original_data = image.data
            new_data = original_data.clone()
            mask = Transform.get_mask(self.masking_method, subject, new_data)
            # A user-supplied masking function may return a non-boolean tensor
            # (e.g. a 0/1 float mask). ``&`` fails on floats, and an integer
            # result would be used as an index tensor instead of a mask.
            mask = mask.bool()
            for old_id, new_id in self.remapping.items():
                # Compare against the original values so that all entries are
                # applied simultaneously instead of cascading.
                new_data[mask & (original_data == old_id)] = new_id
            image.data = new_data

        return subject

    def is_invertible(self):
        """Always report ``True``.

        The inverse is only exact when ``remapping`` is injective and no new id
        collides with a label already present in the image; otherwise merged
        labels cannot be told apart unless a masking method separates them.
        """
        return True

    def inverse(self):
        """Return a :class:`RemapLabels` that swaps keys and values.

        The same masking method and keyword arguments are reused. If several
        old ids map to the same new id, only the last such pair (in dictionary
        order) survives in the inverse mapping.
        """
        inverse_remapping = {v: k for k, v in self.remapping.items()}
        inverse_transform = RemapLabels(
            inverse_remapping,
            masking_method=self.masking_method,
            **self.kwargs,
        )
        return inverse_transform