Passed
Push — master ( af8682...a51248 )
by Fernando
01:26
created

torchio.transforms.preprocessing.intensity.mask   A

Complexity

Total Complexity 5

Size/Duplication

Total Lines 69
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 5
eloc 38
dl 0
loc 69
rs 10
c 0
b 0
f 0

3 Methods

Rating   Name   Duplication   Size   Complexity  
A Mask.apply_masking() 0 3 1
A Mask.__init__() 0 11 1
A Mask.apply_transform() 0 10 2

1 Function

Rating   Name   Duplication   Size   Complexity  
A mask() 0 9 1
1
from typing import Optional, Sequence
2
3
import torch
4
5
from ....data.image import LabelMap
6
from ....data.subject import Subject
7
from ....transforms.transform import TypeMaskingMethod
8
from ... import IntensityTransform
9
10
11
class Mask(IntensityTransform):
    """Set voxels outside of a mask to a constant value.

    Args:
        masking_method: See
            :class:`~torchio.transforms.preprocessing.intensity.NormalizationTransform`.
        outside_value: Value assigned to every voxel that lies outside of
            the mask.
        labels: If a label map is used to generate the mask, sequence of
            labels to consider.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.Colin27()
        >>> subject
        Colin27(Keys: ('t1', 'head', 'brain'); images: 3)
        >>> mask = tio.Mask(masking_method='brain')  # Use "brain" image to mask
        >>> transformed = mask(subject)  # Set voxels outside of the brain to 0

    """  # noqa: E501
    def __init__(
            self,
            masking_method: TypeMaskingMethod,
            outside_value: float = 0,
            labels: Optional[Sequence[int]] = None,
            **kwargs,
            ):
        super().__init__(**kwargs)
        # The constructor argument ``labels`` is kept as ``masking_labels``
        # and later forwarded to ``get_mask_from_masking_method``.
        self.outside_value = outside_value
        self.masking_labels = labels
        self.masking_method = masking_method

    def apply_transform(self, subject: Subject) -> Subject:
        """Mask each image yielded by ``get_images`` and return ``subject``.

        A mask is computed separately for every image (its own data is
        passed to ``get_mask_from_masking_method``), then applied to it.
        """
        for current in self.get_images(subject):
            current_mask = self.get_mask_from_masking_method(
                self.masking_method,
                subject,
                current.data,
                self.masking_labels,
            )
            self.apply_masking(current, current_mask)
        return subject

    def apply_masking(self, image: LabelMap, mask_data: torch.Tensor) -> None:
        """Replace ``image``'s data with a masked copy.

        Voxels outside ``mask_data`` are set to ``self.outside_value``.
        """
        # NOTE(review): hinted as LabelMap, but this receives whatever
        # get_images() yields in apply_transform (presumably intensity
        # images) — confirm and consider widening the annotation.
        image.set_data(mask(image.data, mask_data, self.outside_value))
58
59
60
def mask(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        outside_value: float,
        ) -> torch.Tensor:
    """Return a copy of ``tensor`` with voxels outside ``mask`` overwritten.

    Args:
        tensor: Data to mask. It is not modified.
        mask: Voxels where this is nonzero / ``True`` are kept. Any dtype is
            accepted and interpreted as boolean. It must be broadcastable to
            the shape of ``tensor``; for example, a single-channel mask can
            be applied to a multi-channel tensor.
        outside_value: Value written to every voxel where ``mask`` is
            zero / ``False``.

    Returns:
        A new tensor with the same shape, dtype and device as ``tensor``.
    """
    # Cast explicitly to bool: ``~`` on an integer mask is a bitwise NOT
    # (e.g. ``~1 == 254`` for uint8), which would yield an invalid index
    # instead of the complement of the mask.
    keep = torch.as_tensor(mask).to(device=tensor.device, dtype=torch.bool)
    # Broadcast so that a mask such as (1, W, H, D) selects voxels in every
    # channel of a (C, W, H, D) tensor. This is a no-op when shapes match.
    keep = keep.expand_as(tensor)
    # Stay in torch (no NumPy round trip) so non-CPU tensors work too.
    masked = tensor.clone()
    masked[~keep] = outside_value
    return masked
69