Passed
Pull Request — master (#527)
by Fernando
01:25
created

torchio.transforms.preprocessing.intensity.mask   A

Complexity

Total Complexity 6

Size/Duplication

Total Lines 70
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 6
eloc 41
dl 0
loc 70
rs 10
c 0
b 0
f 0

4 Methods

Rating   Name   Duplication   Size   Complexity  
A Mask.mask() 0 11 2
A Mask.__init__() 0 12 1
A Mask.apply_transform() 0 9 2
A Mask.apply_masking() 0 3 1
1
from typing import Optional, Sequence
2
3
import torch
4
import numpy as np
5
6
from ....data.image import LabelMap
7
from ....data.subject import Subject
8
from ....transforms.transform import TypeMaskingMethod
9
from ... import IntensityTransform
10
11
12
class Mask(IntensityTransform):
    """Set voxels outside of mask to a constant value.

    Args:
        masking_method: See
            :class:`~torchio.transforms.preprocessing.intensity.NormalizationTransform`.
        outside_value: Value to set for all voxels outside of the mask.
        labels: If a label map is used to generate the mask,
            sequence of labels to consider.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.Colin27()
        >>> subject
        Colin27(Keys: ('t1', 'head', 'brain'); images: 3)
        >>> mask = tio.Mask(masking_method='brain')  # Use "brain" image to mask
        >>> transformed = mask(subject)  # Set values outside of the brain to 0
    """  # noqa: E501
    def __init__(
            self,
            masking_method: TypeMaskingMethod,
            outside_value: float = 0,
            labels: Optional[Sequence[int]] = None,
            **kwargs,
            ):
        super().__init__(**kwargs)
        self.masking_method = masking_method
        self.masking_labels = labels
        # Alias named like the constructor argument, so that ``labels`` can be
        # listed in ``args_names``. NOTE(review): this assumes the base
        # Transform reads each entry of ``args_names`` back with ``getattr``
        # to rebuild/describe the transform. Confirm against Transform.
        self.labels = labels
        self.outside_value = outside_value
        # All constructor arguments that change the result are listed, so a
        # reproduced transform keeps the same outside value and labels.
        self.args_names = ('masking_method', 'outside_value', 'labels')

    def apply_transform(self, subject: Subject) -> Subject:
        """Mask every image yielded by ``get_images`` and return ``subject``.

        The mask for each image is obtained from
        ``get_mask_from_masking_method`` using ``self.masking_method``.
        Images are updated through ``set_data``.
        """
        for image in self.get_images(subject):
            label_map = self.get_mask_from_masking_method(
                self.masking_method,
                subject,
                image.data,
            )
            self.apply_masking(image, label_map)
        return subject

    def apply_masking(self, image: LabelMap, label_map: torch.Tensor) -> None:
        """Replace ``image``'s data with its masked version.

        Args:
            image: Image whose data is masked and overwritten via ``set_data``.
                NOTE(review): annotated as ``LabelMap``, but it receives
                whatever ``get_images`` yields (presumably intensity images).
                Consider widening the hint.
            label_map: Mask or label map used to decide which voxels are kept.
        """
        masked = self.mask(image.data, label_map, self.masking_labels)
        image.set_data(masked)

    def mask(
            self,
            tensor: torch.Tensor,
            label_map: torch.Tensor,
            labels: Optional[Sequence[int]] = None,
            ) -> torch.Tensor:
        """Return a copy of ``tensor`` with voxels outside the mask replaced.

        Args:
            tensor: Data to mask. It is not modified.
            label_map: Mask or label map. Its shape must be broadcastable to
                the shape of ``tensor``.
            labels: If given, only voxels of ``label_map`` whose value is in
                ``labels`` are inside the mask. Otherwise, every nonzero voxel
                of ``label_map`` is inside the mask.

        Returns:
            A new tensor with the same dtype and device as ``tensor``, where
            voxels outside the mask are set to ``self.outside_value``.
        """
        if labels is None:
            # Explicit bool conversion: inverting an integer array with ``~``
            # would compute a bitwise NOT (and fail for floats), which selects
            # the wrong voxels instead of "outside the mask".
            inside = label_map.bool()
        else:
            # ``list`` so that sets or other iterables work with ``np.isin``.
            # ``detach().cpu()`` because NumPy only sees host memory.
            in_labels = np.isin(
                label_map.detach().cpu().numpy(),
                list(labels),
            )
            inside = torch.as_tensor(in_labels)
        inside = inside.to(tensor.device)
        masked = tensor.clone()
        masked.masked_fill_(~inside, self.outside_value)
        return masked
70