Passed
Push — master ( ddc71b...32d696 )
by Fernando
01:11
created

Transform.parse_sample()   A

Complexity

Conditions 2

Size

Total Lines 8
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 6
nop 1
dl 0
loc 8
rs 10
c 0
b 0
f 0
1
import copy
2
import numbers
3
import warnings
4
from abc import ABC, abstractmethod
5
from typing import Optional, Union, Tuple, List
6
7
import torch
8
import numpy as np
9
import nibabel as nib
10
import SimpleITK as sitk
11
12
from .. import TypeData, DATA, AFFINE, TypeNumber
13
from ..data.subject import Subject
14
from ..data.image import Image, ScalarImage
15
from ..data.dataset import ImagesDataset
16
from ..utils import nib_to_sitk, sitk_to_nib
17
from .interpolation import Interpolation
18
19
20
# Key under which a single-image input (torchio Image, NiBabel image or
# SimpleITK image) is stored in the temporary Subject built by Transform.
IMAGE_NAME = 'image'
21
22
23
class Transform(ABC):
    """Abstract class for all TorchIO transforms.

    All classes used to transform a sample from an
    :py:class:`~torchio.ImagesDataset` should subclass it.
    All subclasses should override
    :py:meth:`torchio.transforms.Transform.apply_transform`,
    which takes a sample, applies some transformation and returns the result.

    Args:
        p: Probability that this transform will be applied.
        copy: Make a shallow copy of the input before applying the transform.
        keys: If the input is a dictionary, the corresponding values will be
            converted to :py:class:`torchio.ScalarImage` so that the transform
            is applied to them only.
    """
    def __init__(
            self,
            p: float = 1,
            copy: bool = True,
            keys: Optional[List[str]] = None,
            ):
        self.probability = self.parse_probability(p)
        self.copy = copy
        self.keys = keys
        # Key of the image in the temporary Subject built from tensor and
        # array inputs (see _get_subject_from_tensor)
        self.image_name = 'default_image_name'

    def __call__(self, data: Union[Subject, torch.Tensor, np.ndarray]):
        """Transform a sample and return the result.

        The input is wrapped in a :py:class:`~torchio.Subject`, passed to
        :py:meth:`apply_transform` and converted back to the input's kind.

        Args:
            data: Instance of :py:class:`~torchio.Subject`,
                :py:class:`~torchio.Image`, 4D :py:class:`torch.Tensor` or 4D
                NumPy array with dimensions :math:`(C, D, H, W)`, where
                :math:`C` is the number of channels and :math:`D, H, W` are
                the spatial dimensions, :py:class:`SimpleITK.Image`,
                :py:class:`nibabel.Nifti1Image`, or a dictionary (in which
                case ``keys`` must have been given at instantiation). If the
                input is a tensor, the affine matrix is an identity and a
                tensor will be also returned.

        Returns:
            The transformed data, of the same kind as the input. For
            dictionaries, :py:class:`~torchio.Image` values are replaced by
            their data.

        Raises:
            RuntimeError: if ``data`` is a dictionary and ``keys`` is None,
                if ``data`` has an unsupported type, or if a multichannel
                image is given as a NiBabel image.
            ValueError: if a tensor or array input is not 4D.
        """
        # torch.rand samples from [0, 1), so ">=" makes p=0 never apply the
        # transform and p=1 always apply it
        if torch.rand(1).item() >= self.probability:
            return data

        is_tensor = is_array = is_dict = is_image = is_sitk = is_nib = False

        # NOTE(review): keep the Subject check before the dict check in case
        # Subject is a dict subclass
        if isinstance(data, nib.Nifti1Image):
            tensor = data.get_fdata(dtype=np.float32)
            nib_image = ScalarImage(tensor=tensor, affine=data.affine)
            sample = self._get_subject_from_image(nib_image)
            is_nib = True
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            sample = self.parse_tensor(data)
            is_array = isinstance(data, np.ndarray)
            is_tensor = True
        elif isinstance(data, Image):
            sample = self._get_subject_from_image(data)
            is_image = True
        elif isinstance(data, Subject):
            sample = data
        elif isinstance(data, sitk.Image):
            sample = self._get_subject_from_sitk_image(data)
            is_sitk = True
        elif isinstance(data, dict):  # e.g. Eisen or MONAI dicts
            if self.keys is None:
                message = (
                    'If input is a dictionary, a value for "keys" must be'
                    ' specified when instantiating the transform'
                )
                raise RuntimeError(message)
            sample = self._get_subject_from_dict(data, self.keys)
            is_dict = True
        else:
            # Unsupported type: let parse_sample raise an informative error
            # instead of failing below with an UnboundLocalError
            sample = data
        self.parse_sample(sample)

        if self.copy:
            sample = copy.copy(sample)

        # Turn NumPy floating-point warnings into exceptions during the
        # transform
        with np.errstate(all='raise'):
            transformed = self.apply_transform(sample)

        for image in transformed.get_images(intensity_only=False):
            ndim = image[DATA].ndim
            assert ndim == 4, f'Output of {self.name} is {ndim}D'

        # Convert the transformed Subject back to the kind of the input
        if is_tensor:
            image = transformed[self.image_name]
            transformed = image[DATA]
            if is_array:
                transformed = transformed.numpy()
        elif is_sitk:
            # _get_subject_from_sitk_image stores the image under IMAGE_NAME,
            # not under self.image_name
            image = transformed[IMAGE_NAME]
            transformed = nib_to_sitk(image[DATA], image[AFFINE])
        elif is_image:
            transformed = transformed[IMAGE_NAME]
        elif is_dict:
            transformed = dict(transformed)
            for key, value in transformed.items():
                if isinstance(value, Image):
                    transformed[key] = value.data
        elif is_nib:
            image = transformed[IMAGE_NAME]
            output_data = image[DATA]
            if len(output_data) > 1:
                message = (
                    'Multichannel images not supported for input of type'
                    ' nibabel.nifti.Nifti1Image'
                )
                raise RuntimeError(message)
            transformed = nib.Nifti1Image(
                output_data[0].numpy(),
                image[AFFINE],
            )
        return transformed

    @abstractmethod
    def apply_transform(self, sample: Subject):
        """Apply the transform to ``sample`` and return the result."""
        raise NotImplementedError

    @staticmethod
    def parse_range(
            nums_range: Union[TypeNumber, Tuple[TypeNumber, TypeNumber]],
            name: str,
            min_constraint: TypeNumber = None,
            max_constraint: TypeNumber = None,
            type_constraint: type = None,
            ) -> Tuple[TypeNumber, TypeNumber]:
        r"""Adapted from ``torchvision.transforms.RandomRotation``.

        Args:
            nums_range: Tuple of two numbers :math:`(n_{min}, n_{max})`,
                where :math:`n_{min} \leq n_{max}`.
                If a single non-negative number :math:`n` is provided,
                :math:`n_{min} = -n` and :math:`n_{max} = n`, unless
                :attr:`min_constraint` is given, in which case
                :math:`n_{min} = n_{max} = n`.
            name: Name of the parameter, so that an informative error message
                can be printed.
            min_constraint: Minimal value that :math:`n_{min}` can take,
                default is None, i.e. there is no minimal value.
            max_constraint: Maximal value that :math:`n_{max}` can take,
                default is None, i.e. there is no maximal value.
            type_constraint: Precise type that :math:`n_{max}` and
                :math:`n_{min}` must take.

        Returns:
            A tuple of two numbers :math:`(n_{min}, n_{max})`.

        Raises:
            ValueError: if :attr:`nums_range` is negative
            ValueError: if :math:`n_{max}` or :math:`n_{min}` is not a number
            ValueError: if :math:`n_{max} \lt n_{min}`
            ValueError: if :attr:`min_constraint` is not None and
                :math:`n_{min}` is smaller than :attr:`min_constraint`
            ValueError: if :attr:`max_constraint` is not None and
                :math:`n_{max}` is greater than :attr:`max_constraint`
            ValueError: if :attr:`type_constraint` is not None and
                :math:`n_{min}` and :math:`n_{max}` are not of type
                :attr:`type_constraint`.
        """
        if isinstance(nums_range, numbers.Number):
            if nums_range < 0:
                raise ValueError(
                    f'If {name} is a single number,'
                    f' it must be positive, not {nums_range}')
            if min_constraint is not None and nums_range < min_constraint:
                raise ValueError(
                    f'If {name} is a single number, it must be greater'
                    f' than {min_constraint}, not {nums_range}'
                )
            if max_constraint is not None and nums_range > max_constraint:
                raise ValueError(
                    f'If {name} is a single number, it must be smaller'
                    f' than {max_constraint}, not {nums_range}'
                )
            if type_constraint is not None:
                if not isinstance(nums_range, type_constraint):
                    raise ValueError(
                        f'If {name} is a single number, it must be of'
                        f' type {type_constraint}, not {nums_range}'
                    )
            # A symmetric range is only built when there is no lower bound
            min_range = -nums_range if min_constraint is None else nums_range
            return (min_range, nums_range)

        try:
            n_min, n_max = nums_range
        except (TypeError, ValueError):
            raise ValueError(
                f'If {name} is not a single number, it must be'
                f' a sequence of len 2, not {nums_range}'
            )

        min_is_number = isinstance(n_min, numbers.Number)
        max_is_number = isinstance(n_max, numbers.Number)
        if not min_is_number or not max_is_number:
            message = (
                f'{name} values must be numbers, not {nums_range}')
            raise ValueError(message)

        if n_min > n_max:
            raise ValueError(
                f'If {name} is a sequence, the second value must be'
                f' equal or greater than the first, but it is {nums_range}')

        if min_constraint is not None and n_min < min_constraint:
            raise ValueError(
                f'If {name} is a sequence, the first value must be greater'
                f' than {min_constraint}, but it is {n_min}'
            )

        if max_constraint is not None and n_max > max_constraint:
            raise ValueError(
                f'If {name} is a sequence, the second value must be smaller'
                f' than {max_constraint}, but it is {n_max}'
            )

        if type_constraint is not None:
            min_type_ok = isinstance(n_min, type_constraint)
            max_type_ok = isinstance(n_max, type_constraint)
            if not min_type_ok or not max_type_ok:
                # Report the types of the values, not of the container
                raise ValueError(
                    f'If "{name}" is a sequence, its values must be of'
                    f' type "{type_constraint}", but they are'
                    f' "{type(n_min)}" and "{type(n_max)}"'
                )
        # Always return a tuple, whatever sequence type the caller passed
        return n_min, n_max

    @staticmethod
    def parse_probability(probability: float) -> float:
        """Return ``probability`` if it is a number in [0, 1].

        Raises:
            ValueError: if ``probability`` is not a number in [0, 1].
        """
        is_number = isinstance(probability, numbers.Number)
        if not (is_number and 0 <= probability <= 1):
            message = (
                'Probability must be a number in [0, 1],'
                f' not {probability}'
            )
            raise ValueError(message)
        return probability

    @staticmethod
    def parse_sample(sample: Subject) -> None:
        """Raise a RuntimeError if ``sample`` is not a Subject."""
        if not isinstance(sample, Subject):
            message = (
                'Input to a transform must be a tensor or an instance'
                f' of torchio.Subject, not "{type(sample)}"'
            )
            raise RuntimeError(message)

    def parse_tensor(self, data: TypeData) -> Subject:
        """Wrap a 4D tensor or array in a Subject.

        Raises:
            ValueError: if ``data`` is not 4D.
        """
        if data.ndim != 4:
            message = (
                'The input must be a 4D tensor with dimensions'
                f' (channels, x, y, z) but it has shape {tuple(data.shape)}'
            )
            raise ValueError(message)
        return self._get_subject_from_tensor(data)

    @staticmethod
    def parse_interpolation(
            interpolation: Union[str, Interpolation],
            ) -> Interpolation:
        """Convert a case-insensitive string to an Interpolation member.

        Passing an :py:class:`Interpolation` member directly still works but
        emits a ``FutureWarning``.

        Raises:
            AttributeError: if the string is not a supported interpolation.
            TypeError: if ``interpolation`` is neither a string nor an
                Interpolation member.
        """
        if isinstance(interpolation, Interpolation):
            message = (
                'Interpolation of type torchio.Interpolation'
                ' is deprecated, please use a string instead'
            )
            warnings.warn(message, FutureWarning)
        elif isinstance(interpolation, str):
            interpolation = interpolation.lower()
            supported_values = [key.name.lower() for key in Interpolation]
            if interpolation in supported_values:
                interpolation = getattr(Interpolation, interpolation.upper())
            else:
                message = (
                    f'Interpolation "{interpolation}" is not among'
                    f' the supported values: {supported_values}'
                )
                raise AttributeError(message)
        else:
            message = (
                'image_interpolation must be a string,'
                f' not {type(interpolation)}'
            )
            raise TypeError(message)
        return interpolation

    def _get_subject_from_tensor(self, tensor: torch.Tensor) -> Subject:
        """Wrap a channels-first tensor in a Subject under self.image_name."""
        image = ScalarImage(tensor=tensor, channels_last=False)
        return Subject({self.image_name: image})

    @staticmethod
    def _get_subject_from_image(image: Image) -> Subject:
        """Wrap a single image in a Subject under IMAGE_NAME."""
        subject = Subject({IMAGE_NAME: image})
        return subject

    @staticmethod
    def _get_subject_from_dict(
            data: dict,
            image_keys: List[str],
            ) -> Subject:
        """Build a Subject from a dict, wrapping ``image_keys`` values.

        Values whose key is in ``image_keys`` become ScalarImage instances;
        all other values are copied unchanged.
        """
        subject_dict = {
            key: ScalarImage(tensor=value) if key in image_keys else value
            for key, value in data.items()
        }
        return Subject(subject_dict)

    @staticmethod
    def _get_subject_from_sitk_image(image):
        """Wrap a SimpleITK image in a Subject under IMAGE_NAME."""
        tensor, affine = sitk_to_nib(image)
        return Subject({IMAGE_NAME: ScalarImage(tensor=tensor, affine=affine)})

    @staticmethod
    def nib_to_sitk(data: TypeData, affine: TypeData) -> sitk.Image:
        """Thin wrapper around the module-level ``nib_to_sitk`` helper."""
        return nib_to_sitk(data, affine)

    @staticmethod
    def sitk_to_nib(image: sitk.Image) -> Tuple[torch.Tensor, np.ndarray]:
        """Thin wrapper around the module-level ``sitk_to_nib`` helper."""
        return sitk_to_nib(image)

    @property
    def name(self):
        """Name of the transform's class, used in error messages."""
        return self.__class__.__name__