Passed
Push — master ( b9ac52...6aebda )
by Fernando
10:37 queued 20s
created

SequentialLabels.apply_transform()   A

Complexity

Conditions 3

Size

Total Lines 23
Code Lines 18

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 18
nop 2
dl 0
loc 23
rs 9.5
c 0
b 0
f 0
1
import torch
2
3
from ....data import LabelMap
4
from ...transform import Transform, TypeMaskingMethod
5
from .remap_labels import RemapLabels
6
7
8
class SequentialLabels(Transform):
    r"""Remap the integer IDs of labels in a LabelMap to be sequential.

    For example, if a label map has 6 labels with IDs (3, 5, 9, 15, 16, 23),
    then this will apply a :class:`~torchio.RemapLabels` transform with
    ``remapping={3: 1, 5: 2, 9: 3, 15: 4, 16: 5, 23: 6}``.
    The background label ``0`` is never remapped, whether or not it appears
    in the label map; every other label is renumbered ``1, 2, ...`` in
    ascending order of its original ID.
    This transformation is always `fully invertible <invertibility>`_.

    Args:
        masking_method: See
            :class:`~torchio.RemapLabels`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.
    """
    def __init__(
            self,
            masking_method: TypeMaskingMethod = None,
            **kwargs
            ):
        super().__init__(**kwargs)
        self.masking_method = masking_method
        # NOTE(review): masking_method is not listed in args_names; presumably
        # the RemapLabels transforms applied below are what get recorded for
        # reproducibility/inversion — confirm against Transform.
        self.args_names = []

    def apply_transform(self, subject):
        """Renumber each selected label map so its labels are consecutive.

        Every :class:`LabelMap` selected by ``include``/``exclude`` gets its
        own :class:`RemapLabels` transform, restricted to that image through
        ``include=name``. Non-label images are skipped.

        Args:
            subject: Subject whose label maps will be remapped.

        Returns:
            The subject returned by the last applied :class:`RemapLabels`
            transform, or the input subject if no label map was selected.
        """
        images_dict = subject.get_images_dict(
            intensity_only=False,
            include=self.include,
            exclude=self.exclude,
        )
        for name, image in images_dict.items():
            if not isinstance(image, LabelMap):
                continue

            remapping = self._get_sequential_remapping(image.data)
            transform = RemapLabels(
                remapping=remapping,
                masking_method=self.masking_method,
                include=name,
            )
            subject = transform(subject)

        return subject

    @staticmethod
    def _get_sequential_remapping(data):
        """Map each nonzero label value in ``data`` to ``1, 2, ...``.

        Labels are numbered in ascending order of their original value. The
        background value ``0`` is excluded from the mapping, so it is left
        unchanged.

        Args:
            data: Tensor of label values.

        Returns:
            Dict from original label value to its new sequential ID.
        """
        unique_labels = torch.unique(data)  # sorted ascending by default
        # Drop background explicitly instead of assuming it is the first
        # unique value: if 0 were absent, the smallest label would keep its
        # ID and collide with a renumbered label (e.g. 3 -> 3 and 15 -> 3).
        foreground = unique_labels[unique_labels != 0].tolist()
        return {
            label: new_id
            for new_id, label in enumerate(foreground, start=1)
        }
54