Passed
Push — master ( b9ac52...6aebda )
by Fernando
10:37 queued 20s
created

RemapLabels.__init__()   A

Complexity

Conditions 1

Size

Total Lines 11
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 10
nop 4
dl 0
loc 11
rs 9.9
c 0
b 0
f 0
1
from typing import Dict
2
3
from ....data import LabelMap
4
from ...transform import Transform, TypeMaskingMethod
5
6
7
class RemapLabels(Transform):
    r"""Remap the integer ids of labels in a :class:`~torchio.LabelMap`.

    Every voxel inside the mask whose value is a key of ``remapping`` is set
    to the corresponding value. All comparisons are made against the
    original (un-remapped) data, so chained mappings such as ``{1: 2, 2: 3}``
    are applied simultaneously rather than one after another.

    This transformation may not be invertible if two labels are combined by
    the remapping. A masking method can be used to correctly split the label
    into two during the `inverse transformation <invertibility>`_ (see
    example).

    Args:
        remapping: Dictionary that specifies how labels should be remapped.
            The keys are the old label ids, and the corresponding values
            replace them.
        masking_method: Defines a mask for where the label remapping is
            applied. It can be one of:

            - ``None``: the mask image is all ones, i.e. all values in the
              image are used.

            - A string: key to a :class:`torchio.LabelMap` in the subject
              which is used as a mask, OR an anatomical label: ``'Left'``,
              ``'Right'``, ``'Anterior'``, ``'Posterior'``, ``'Inferior'``,
              ``'Superior'`` which specifies a side of the mask volume to be
              ones.

            - A function: the mask image is computed as a function of the
              label map data. The function must receive and return a
              :class:`torch.Tensor`; the returned tensor is interpreted as
              a boolean mask (non-zero means "remap here").

        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Example:
        >>> import torchio as tio
        >>> # Target label map has the following labels:
        >>> # {'left_ventricle': 1, 'right_ventricle': 2, 'left_caudate': 3, 'right_caudate': 4,
        >>> #  'left_putamen': 5, 'right_putamen': 6, 'left_thalamus': 7, 'right_thalamus': 8}
        >>> transform = tio.RemapLabels({2: 1, 4: 3, 6: 5, 8: 7})
        >>> # Merge right side labels with left side labels
        >>> transformed = transform(subject)
        >>> # Undesired behavior: The inverse transform will remap ALL left side labels to right side labels
        >>> # so the label map only has right side labels.
        >>> inverse_transformed = transformed.apply_inverse_transform()
        >>> # Here's the *right* way to do it with masking:
        >>> transform = tio.RemapLabels({2: 1, 4: 3, 6: 5, 8: 7}, masking_method="Right")
        >>> # Remap the labels on the right side only (no difference yet).
        >>> transformed = transform(subject)
        >>> # Apply the inverse on the right side only. The labels are correctly split into left/right.
        >>> inverse_transformed = transformed.apply_inverse_transform()
    """

    def __init__(
            self,
            remapping: Dict[int, int],
            masking_method: TypeMaskingMethod = None,
            **kwargs
            ):
        super().__init__(**kwargs)
        # Kept so that ``inverse`` can build a transform with the same
        # base-class options.
        self.kwargs = kwargs
        self.remapping = remapping
        self.masking_method = masking_method
        self.args_names = ('remapping', 'masking_method',)

    def apply_transform(self, subject):
        """Remap label ids in every selected :class:`LabelMap` of ``subject``.

        Images that are not label maps are skipped. Each label map's data is
        replaced by a remapped copy; ``subject`` is returned.
        """
        images = subject.get_images(
            intensity_only=False,
            include=self.include,
            exclude=self.exclude,
        )
        for image in images:
            if not isinstance(image, LabelMap):
                continue

            original = image.data
            new_data = original.clone()
            mask = Transform.get_mask(self.masking_method, subject, new_data)
            # A callable masking method may return a non-boolean tensor.
            # ``int_mask & bool_tensor`` yields an integer tensor, which
            # PyTorch treats as integer (index) selection rather than a mask,
            # silently overwriting whole slices; a float mask would raise.
            mask = mask.bool()
            for old_id, new_id in self.remapping.items():
                # Compare against the untouched original data so chained
                # remappings (e.g. {1: 2, 2: 3}) do not cascade.
                new_data[mask & (original == old_id)] = new_id
            image.data = new_data

        return subject

    def is_invertible(self):
        """Always ``True``; note the inverse is lossy if labels were merged."""
        return True

    def inverse(self):
        """Return a :class:`RemapLabels` mapping new ids back to old ids.

        The same masking method and keyword arguments are reused. If several
        old ids were mapped to the same new id, the inverse can map that new
        id back to only one of them: the last such key in ``remapping``'s
        iteration order wins.
        """
        inverse_remapping = {v: k for k, v in self.remapping.items()}
        inverse_transform = RemapLabels(
            inverse_remapping,
            masking_method=self.masking_method,
            **self.kwargs,
        )
        return inverse_transform