Passed
Pull Request — master (#353)
by Fernando
01:17
created

Blur.apply_transform()   A

Complexity

Conditions 4

Size

Total Lines 16
Code Lines 15

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 15
nop 2
dl 0
loc 16
rs 9.65
c 0
b 0
f 0
1
from collections import defaultdict
2
from typing import Union, Tuple, Optional, List, Dict
3
4
import torch
5
import numpy as np
6
import scipy.ndimage as ndi
7
8
from ....utils import to_tuple
9
from ....torchio import DATA, TypeData, TypeTripletFloat, TypeSextetFloat
10
from ....data.subject import Subject
11
from ... import IntensityTransform
12
from .. import RandomTransform
13
14
15
class RandomBlur(RandomTransform, IntensityTransform):
    r"""Blur an image using a random-sized Gaussian filter.

    A separate set of standard deviations is sampled for every channel of
    every intensity image in the subject.

    Args:
        std: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` representing the
            ranges (in mm) of the standard deviations
            :math:`(\sigma_1, \sigma_2, \sigma_3)` of the Gaussian kernels used
            to blur the image along each axis, where
            :math:`\sigma_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`\sigma_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`\sigma_i \sim \mathcal{U}(0, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`\sigma_i \sim \mathcal{U}(0, x_i)`.
        p: Probability that this transform will be applied.
        keys: See :py:class:`~torchio.transforms.Transform`.
    """
    def __init__(
            self,
            std: Union[float, Tuple[float, ...]] = (0, 2),
            p: float = 1,
            keys: Optional[List[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        self.std_ranges = self.parse_params(std, None, 'std', min_constraint=0)

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample per-channel standard deviations and apply :class:`Blur`.

        Args:
            subject: Subject whose intensity images will be blurred.

        Returns:
            The subject returned by the inner :class:`Blur` transform, or the
            input subject unchanged when it has no intensity images.
        """
        arguments = defaultdict(dict)
        for name, image in self.get_images_dict(subject).items():
            # One independently sampled std triplet per channel of the image.
            stds = [self.get_params(self.std_ranges) for _ in image.data]
            arguments['std'][name] = stds
        if not arguments:
            # Nothing to blur; Blur(**{}) would fail with a missing ``std``.
            return subject
        transform = Blur(**arguments)
        transformed = transform(subject)
        return transformed

    def get_params(self, std_ranges: TypeSextetFloat) -> TypeTripletFloat:
        """Sample one standard deviation per axis from ``std_ranges``."""
        std = self.sample_uniform_sextet(std_ranges)
        return std
54
55
56
class Blur(IntensityTransform):
    r"""Blur an image using a Gaussian filter.

    Args:
        std: Tuple :math:`(\sigma_1, \sigma_2, \sigma_3)` representing the
            standard deviations (in mm) of the Gaussian kernels used to blur
            the image along each axis. A single number is used for all axes.
            A dictionary mapping image names to such values may be given
            instead; the value for an image may also be a list with one
            triplet per channel.
        keys: See :py:class:`~torchio.transforms.Transform`.
    """
    def __init__(
            self,
            std: Union[
                TypeTripletFloat,
                Dict[str, Union[TypeTripletFloat, List[TypeTripletFloat]]],
            ],
            keys: Optional[List[str]] = None,
            ):
        super().__init__(keys=keys)
        self.std = std
        self.args_names = ('std',)

    def apply_transform(self, subject: Subject) -> Subject:
        """Blur every channel of every intensity image in ``subject``.

        Args:
            subject: Subject whose intensity images are blurred in place.

        Returns:
            The same subject, with each image's data replaced.

        Raises:
            ValueError: If per-channel stds do not match the channel count.
        """
        for name, image in self.get_images_dict(subject).items():
            # Resolve the value for *this* image every iteration, so that a
            # previous image's loop variables can never leak into it.
            std = self.std[name] if self.arguments_are_dict() else self.std
            channels = image.data
            channel_stds = self._get_channel_stds(std, len(channels))
            transformed_tensors = [
                blur(tensor, image.spacing, channel_std)
                for channel_std, tensor in zip(channel_stds, channels)
            ]
            image[DATA] = torch.stack(transformed_tensors)
        return subject

    @staticmethod
    def _get_channel_stds(std, num_channels: int) -> list:
        """Expand ``std`` into a list with exactly one entry per channel.

        A scalar or a 1D per-axis sequence is shared by all channels; a 2D
        sequence must already contain one per-axis entry per channel.
        """
        ndim = np.ndim(std)
        if ndim <= 1:
            # Same kernel for every channel.
            return [std] * num_channels
        channel_stds = list(std)
        if ndim == 2 and len(channel_stds) == num_channels:
            return channel_stds
        # Refuse instead of letting zip() silently drop channels.
        message = (
            f'Expected a scalar, a per-axis sequence, or one per-axis'
            f' sequence for each of the {num_channels} channels, but got'
            f' std with shape {np.shape(std)}'
        )
        raise ValueError(message)
90
91
92
def blur(
        data: TypeData,
        spacing: TypeTripletFloat,
        std_voxel: TypeTripletFloat,
        ) -> torch.Tensor:
    """Blur a single 3D volume with an (optionally anisotropic) Gaussian.

    Args:
        data: 3D array or tensor to blur.
        spacing: Voxel spacing along each axis, in the same physical units
            as ``std_voxel``.
        std_voxel: Standard deviation of the Gaussian along each axis (or a
            single value for all axes). Despite its name, which is kept for
            backward compatibility, this value is in physical units (mm),
            not voxels: it is divided by ``spacing`` to obtain voxel units.

    Returns:
        The blurred volume as a tensor created with :func:`torch.from_numpy`
        from the output of :func:`scipy.ndimage.gaussian_filter`.

    Raises:
        ValueError: If ``data`` is not 3D.
    """
    if data.ndim != 3:
        raise ValueError(f'Expected 3D data, but got {data.ndim} dimensions')
    # physical std / (physical size per voxel) = std in voxels, which is the
    # unit gaussian_filter expects for ``sigma``.
    std_in_voxels = np.array(std_voxel) / np.array(spacing)
    blurred = ndi.gaussian_filter(data, std_in_voxels)
    return torch.from_numpy(blurred)
102