Completed
Push — master ( e9e8af...67a2e4 )
by Fernando
10:41 queued 09:14
created

Transform.parse_interpolation()   A

Complexity

Conditions 4

Size

Total Lines 26
Code Lines 18

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 18
nop 1
dl 0
loc 26
rs 9.5
c 0
b 0
f 0
1
import numbers
2
import warnings
3
from typing import Union
4
from copy import deepcopy
5
from abc import ABC, abstractmethod
6
7
import torch
0 ignored issues
show
introduced by
Unable to import 'torch'
Loading history...
8
import SimpleITK as sitk
0 ignored issues
show
introduced by
Unable to import 'SimpleITK'
Loading history...
9
10
from .. import TypeData, INTENSITY, DATA
11
from ..data.image import Image
12
from ..data.subject import Subject
13
from ..data.dataset import ImagesDataset
14
from ..utils import nib_to_sitk, sitk_to_nib
15
from .interpolation import Interpolation
16
17
18
class Transform(ABC):
    """Abstract class for all TorchIO transforms.

    All classes used to transform a sample from an
    :py:class:`~torchio.ImagesDataset` should subclass it.
    All subclasses should overwrite
    :py:meth:`torchio.transforms.Transform.apply_transform`,
    which takes a sample, applies some transformation and returns the result.

    Args:
        p: Probability that this transform will be applied.
    """
    def __init__(self, p: float = 1):
        self.probability = self.parse_probability(p)

    def __call__(self, data: Union[Subject, torch.Tensor]):
        """Transform a sample and return the result.

        Args:
            data: Instance of :py:class:`~torchio.Subject` or 4D
                :py:class:`torch.Tensor` with dimensions :math:`(C, D, H, W)`,
                where :math:`C` is the number of channels and :math:`D, H, W`
                are the spatial dimensions. If the input is a tensor, the affine
                matrix is an identity and a tensor will be also returned.

        Returns:
            The transformed data, of the same kind as the input: a sample
            if a sample was passed, a tensor if a tensor was passed. If the
            transform is randomly skipped, the input object itself is
            returned unchanged.
        """
        is_tensor = isinstance(data, torch.Tensor)
        sample = self.parse_tensor(data) if is_tensor else data
        self.parse_sample(sample)

        # torch.rand samples from [0, 1), so skipping when the draw is
        # >= p applies the transform with probability exactly p
        # (never for p == 0, always for p == 1).
        if torch.rand(1).item() >= self.probability:
            # Return the caller's object, not the intermediate sample, so
            # that a tensor input still yields a tensor when skipped.
            return data

        # Work on a copy so that the caller's sample is not modified.
        sample = deepcopy(sample)
        transformed = self.apply_transform(sample)
        if is_tensor:
            # Reassemble the per-channel images created by
            # _get_subject_from_tensor into a single tensor.
            num_channels = len(data)
            images = [
                transformed[f'channel_{i}'][DATA]
                for i in range(num_channels)
            ]
            transformed = torch.cat(images)
        return transformed

    @abstractmethod
    def apply_transform(self, sample: Subject):
        """Apply the transformation to ``sample`` and return the result.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    @staticmethod
    def parse_probability(probability: float) -> float:
        """Validate that ``probability`` is a number in :math:`[0, 1]`.

        Args:
            probability: Value to check.

        Returns:
            The input value, unchanged.

        Raises:
            ValueError: If the value is not a number or is outside
                :math:`[0, 1]`.
        """
        is_valid = False
        if isinstance(probability, numbers.Number):
            try:
                is_valid = 0 <= probability <= 1
            except TypeError:
                # Numbers without an ordering (e.g. complex) cannot be
                # compared; report them with the same ValueError.
                is_valid = False
        if not is_valid:
            message = (
                'Probability must be a number in [0, 1],'
                f' not {probability}'
            )
            raise ValueError(message)
        return probability

    @staticmethod
    def parse_sample(sample: Subject) -> None:
        """Check that ``sample`` is a subject flagged as a sample.

        Raises:
            RuntimeError: If ``sample`` is not a :py:class:`~torchio.Subject`
                or its ``is_sample`` attribute is falsy.
        """
        if not isinstance(sample, Subject) or not sample.is_sample:
            message = (
                'Inputs to transforms must be instances of torchio.Subject'
                f' generated by a torchio.ImagesDataset, not "{type(sample)}"'
            )
            raise RuntimeError(message)

    def parse_tensor(self, tensor: torch.Tensor) -> Subject:
        """Convert a 4D tensor into a sample with one image per channel.

        Raises:
            RuntimeError: If ``tensor`` does not have exactly 4 dimensions.
        """
        num_dimensions = tensor.dim()
        if num_dimensions != 4:
            message = (
                'The input tensor must have 4 dimensions (channels, i, j, k),'
                f' but has {num_dimensions}: {tensor.shape}'
            )
            raise RuntimeError(message)
        return self._get_subject_from_tensor(tensor)

    @staticmethod
    def parse_interpolation(
            interpolation: Union[str, Interpolation],
            ) -> Interpolation:
        """Convert an interpolation name into an ``Interpolation`` member.

        Args:
            interpolation: Case-insensitive name of a member of
                ``Interpolation``. Passing an ``Interpolation`` instance
                is still accepted but deprecated.

        Returns:
            The matching ``Interpolation`` member.

        Raises:
            AttributeError: If the string does not name a supported value.
            TypeError: If the argument is neither a string nor an
                ``Interpolation`` instance.
        """
        if isinstance(interpolation, Interpolation):
            message = (
                'Interpolation of type torchio.Interpolation'
                ' is deprecated, please use a string instead'
            )
            warnings.warn(message, FutureWarning)
            return interpolation

        if not isinstance(interpolation, str):
            message = (
                'image_interpolation must be a string,'
                f' not {type(interpolation)}'
            )
            raise TypeError(message)

        interpolation = interpolation.lower()
        supported_values = [key.name.lower() for key in Interpolation]
        if interpolation not in supported_values:
            message = (
                f'Interpolation "{interpolation}" is not among'
                f' the supported values: {supported_values}'
            )
            raise AttributeError(message)
        # Member names are compared in lower case but looked up in upper case.
        return getattr(Interpolation, interpolation.upper())

    @staticmethod
    def _get_subject_from_tensor(tensor: torch.Tensor) -> Subject:
        """Build a sample whose images are the channels of ``tensor``.

        Each channel becomes an intensity image stored under the key
        ``channel_<index>``.
        """
        subject_dict = {
            f'channel_{channel_index}': Image(
                tensor=channel_tensor,
                type=INTENSITY,
            )
            for channel_index, channel_tensor in enumerate(tensor)
        }
        subject = Subject(subject_dict)
        # The sample is obtained by indexing a one-subject dataset;
        # presumably this is what marks it as a sample for parse_sample
        # (to be confirmed against ImagesDataset).
        dataset = ImagesDataset([subject])
        return dataset[0]

    @staticmethod
    def nib_to_sitk(data: TypeData, affine: TypeData):
        """Delegate to the ``nib_to_sitk`` function imported from utils."""
        return nib_to_sitk(data, affine)

    @staticmethod
    def sitk_to_nib(image: sitk.Image):
        """Delegate to the ``sitk_to_nib`` function imported from utils."""
        return sitk_to_nib(image)

    @property
    def name(self):
        """Name of the concrete transform class."""
        return self.__class__.__name__
147