Passed
Pull Request — master (#222)
by Fernando
02:36
created

torchio.transforms.augmentation.intensity.random_labels_to_image   A

Complexity

Total Complexity 39

Size/Duplication

Total Lines 279
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 165
dl 0
loc 279
rs 9.28
c 0
b 0
f 0
wmc 39

8 Methods

Rating   Name   Duplication   Size   Complexity  
A RandomLabelsToImage.get_params() 0 13 2
A RandomLabelsToImage.__init__() 0 21 1
A RandomLabelsToImage.parse_gaussian_parameter() 0 18 4
B RandomLabelsToImage.parse_gaussian_parameters() 0 29 7
C RandomLabelsToImage.apply_transform() 0 49 9
A RandomLabelsToImage.parse_pv_label_maps() 0 20 4
A RandomLabelsToImage.generate_tissue() 0 10 1
C RandomLabelsToImage.parse_keys() 0 35 11
1
2
from typing import Union, Tuple, Optional, Dict, Sequence
3
import torch
4
import numpy as np
5
from ....torchio import DATA, TypeData, TypeRangeFloat, TypeNumber, AFFINE
6
from ....data.subject import Subject
7
from ....data.image import ScalarImage
8
from .. import RandomTransform
9
10
11
# Per-label statistics: maps 'mean' and/or 'std' to a value or (min, max) range.
_TypeGaussianStats = Dict[str, TypeRangeFloat]
# Optional mapping from a label (integer value or partial-volume key) to its
# Gaussian statistics.
TypeGaussian = Optional[Dict[Union[str, TypeNumber], _TypeGaussianStats]]
12
13
14
class RandomLabelsToImage(RandomTransform):
    r"""Generate an image from a segmentation.

    Based on the work by `Billot et al., A Learning Strategy for
    Contrast-agnostic MRI Segmentation <https://arxiv.org/abs/2003.01995>`_.

    Args:
        label_key: String designating the label map in the sample
            that will be used to generate the new image.
            Cannot be set at the same time as :py:attr:`pv_label_keys`.
        pv_label_keys: Sequence of strings designating the partial-volume
            label maps in the sample that will be used to generate the new
            image. Cannot be set at the same time as :py:attr:`label_key`.
        image_key: String designating the key to which the new volume will be
            saved. If this key corresponds to an already existing volume,
            background voxels (label value :math:`\leq 0` when using
            :py:attr:`label_key`, or 0 in all the partial-volume label maps)
            will be filled with the corresponding values in the original
            volume.
        gaussian_parameters: Dictionary mapping each label to a dictionary
            with the keys ``'mean'`` and/or ``'std'``. Labels are the integer
            values of the label map when using :py:attr:`label_key`, or the
            keys in :py:attr:`pv_label_keys` otherwise. For each value
            :math:`v`, if a tuple :math:`(a, b)` is provided then
            :math:`v \sim \mathcal{U}(a, b)`.
            If no value is given for a label, :py:attr:`default_mean` and
            :py:attr:`default_std` ranges will be used.
        default_mean: Default mean range.
        default_std: Default standard deviation range.
        binarize: If ``True``, partial-volume label maps will be binarized.
            Does not have any effects if not using partial-volume label maps.
            Binarization is done taking the class of the highest value per voxel
            in the different partial-volume label maps.
        p: Probability that this transform will be applied.
        seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.

    .. note:: It is recommended to blur the new images to make the result more
        realistic. See
        :py:class:`~torchio.transforms.augmentation.intensity.random_blur.RandomBlur`.

    Example:
        >>> import torchio
        >>> from torchio import RandomLabelsToImage, DATA, RescaleIntensity, Compose
        >>> from torchio.datasets import Colin27
        >>> colin = Colin27(2008)
        >>> # Using the default gaussian_parameters
        >>> transform = RandomLabelsToImage(label_key='cls')
        >>> # Using custom gaussian_parameters
        >>> label_values = colin['cls'][DATA].unique().round().long()
        >>> gaussian_parameters = {
        ...     label: {
        ...         'mean': i / len(label_values),
        ...         'std': 0.01
        ...     }
        ...     for i, label in enumerate(label_values)
        ... }
        >>> transform = RandomLabelsToImage(label_key='cls', gaussian_parameters=gaussian_parameters)
        >>> transformed = transform(colin)  # colin has a new key 'image' with the simulated image
        >>> # Filling holes of the simulated image with the original T1 image
        >>> rescale_transform = RescaleIntensity((0, 1), (1, 99))   # Rescale intensity before filling holes
        >>> simulation_transform = RandomLabelsToImage(
        ...     label_key='cls',
        ...     image_key='t1',
        ...     gaussian_parameters={0: {'mean': 0, 'std': 0}}
        ... )
        >>> transform = Compose([rescale_transform, simulation_transform])
        >>> transformed = transform(colin)  # colin's key 't1' has been replaced with the simulated image
    """
    def __init__(
            self,
            label_key: Optional[str] = None,
            pv_label_keys: Optional[Sequence[str]] = None,
            image_key: str = 'image',
            gaussian_parameters: TypeGaussian = None,
            default_mean: TypeRangeFloat = (0.1, 0.9),
            default_std: TypeRangeFloat = (0.01, 0.1),
            binarize: bool = False,
            p: float = 1,
            seed: Optional[int] = None,
            ):
        super().__init__(p=p, seed=seed)
        # Order matters: parse_gaussian_parameters reads self.pv_label_keys
        # and falls back to self.default_mean / self.default_std.
        self.label_key, self.pv_label_keys = self.parse_keys(
            label_key, pv_label_keys)
        self.default_mean = self.parse_gaussian_parameter(default_mean, 'mean')
        self.default_std = self.parse_gaussian_parameter(default_std, 'std')
        self.gaussian_parameters = self.parse_gaussian_parameters(
            gaussian_parameters)
        self.image_key = image_key
        self.binarize = binarize

    @staticmethod
    def parse_keys(
            label_key: Optional[str],
            pv_label_keys: Optional[Sequence[str]],
            ) -> Tuple[Optional[str], Optional[Sequence[str]]]:
        """Validate the label-map keys given to the constructor.

        Exactly one of ``label_key`` and ``pv_label_keys`` must be set.

        Args:
            label_key: Key of a single label map, or ``None``.
            pv_label_keys: Iterable of keys of partial-volume label maps,
                or ``None``.

        Returns:
            ``(label_key, pv_label_keys)``, where ``pv_label_keys`` is
            converted to a list when it is not ``None``.

        Raises:
            ValueError: If both or neither keys are set, or if
                ``pv_label_keys`` is empty.
            TypeError: If ``label_key`` is not a string, or if
                ``pv_label_keys`` is not an iterable of strings.
        """
        if label_key is not None and pv_label_keys is not None:
            message = (
                '"label_key" and "pv_label_keys" cannot be set at the same time'
            )
            raise ValueError(message)
        if label_key is None and pv_label_keys is None:
            message = 'One of "label_key" and "pv_label_keys" must be set'
            raise ValueError(message)
        if label_key is not None and not isinstance(label_key, str):
            message = f'"label_key" must be a string, not {type(label_key)}'
            raise TypeError(message)
        if pv_label_keys is not None:
            # A bare string is iterable and each character is a str, so it
            # would otherwise be silently split into one key per character
            if isinstance(pv_label_keys, str):
                message = (
                    '"pv_label_keys" must be a sequence of strings, '
                    f'not a single string ({pv_label_keys!r})'
                )
                raise TypeError(message)
            # Materialize first: validating by iterating a generator would
            # consume it and leave an empty list behind
            try:
                pv_label_keys = list(pv_label_keys)
            except TypeError as error:
                message = (
                    '"pv_label_keys" must be a sequence of strings, '
                    f'not {pv_label_keys}'
                )
                raise TypeError(message) from error
            if not pv_label_keys:
                message = '"pv_label_keys" must contain at least one key'
                raise ValueError(message)
            for key in pv_label_keys:
                if not isinstance(key, str):
                    message = (
                        f'Every key of "pv_label_keys" must be a string, '
                        f'found {type(key)}'
                    )
                    raise TypeError(message)

        return label_key, pv_label_keys

    def apply_transform(self, sample: Subject) -> Subject:
        """Simulate a new image from the label map(s) in ``sample``.

        For every label, Gaussian noise with a mean and standard deviation
        drawn by :py:meth:`get_params` is generated and weighted by that
        label's mask. The sum over labels becomes a new
        :py:class:`ScalarImage` stored under ``self.image_key``. If an image
        already exists under that key, its background voxels are copied into
        the simulated image.

        Args:
            sample: Subject containing the label map(s).

        Returns:
            The same ``sample``, with the simulated image added and the
            sampled means/stds recorded through ``add_transform``.
        """
        random_parameters_images_dict = {}
        original_image = sample.get(self.image_key)
        using_pv_maps = self.pv_label_keys is not None
        # Soft PV maps are used directly as per-label weights
        soft_pv_maps = using_pv_maps and not self.binarize

        if using_pv_maps:
            label_map, affine = self.parse_pv_label_maps(
                self.pv_label_keys, sample)
            n_labels, *image_shape = label_map.shape
            labels = self.pv_label_keys
            values = list(range(n_labels))

            if self.binarize:
                # Replace the stack of maps by the index of the map with the
                # highest value in each voxel (torch.max returns values, indices)
                max_values, label_map = label_map.max(dim=0)
                # Voxels where all maps are 0 are marked as background (-1)
                label_map[max_values == 0] = -1
        else:
            label_map = sample[self.label_key][DATA][0]
            affine = sample[self.label_key][AFFINE]
            image_shape = label_map.shape
            values = label_map.unique()
            labels = [int(key) for key in values.round()]

        tissues = torch.zeros(image_shape)

        for i, label in enumerate(labels):
            mean, std = self.get_params(label)
            if soft_pv_maps:
                mask = label_map[i]
            else:
                mask = label_map == values[i]
            tissues += self.generate_tissue(mask, mean, std)

            random_parameters_images_dict[label] = {
                'mean': mean,
                'std': std
            }

        final_image = ScalarImage(affine=affine, tensor=tissues)

        if original_image is not None:
            if soft_pv_maps:
                bg_mask = label_map.sum(dim=0) <= 0
            elif using_pv_maps:
                # Binarized maps: background is -1 and index 0 is a real
                # label (the first PV map), so it must not be overwritten
                bg_mask = label_map < 0
            else:
                bg_mask = label_map <= 0
            bg_mask = bg_mask.unsqueeze(0)
            final_image[DATA][bg_mask] = original_image[DATA][bg_mask]

        sample.add_image(final_image, self.image_key)
        sample.add_transform(self, random_parameters_images_dict)
        return sample

    def parse_gaussian_parameters(
            self,
            parameters: TypeGaussian,
            ) -> TypeGaussian:
        """Validate per-label Gaussian parameters and fill in defaults.

        Args:
            parameters: Mapping from label to a dictionary with optional
                ``'mean'`` and ``'std'`` entries, or ``None``.

        Returns:
            A new dictionary mapping each given label to
            ``{'mean': (min, max), 'std': (min, max)}``. Missing entries use
            ``self.default_mean`` / ``self.default_std``.

        Raises:
            KeyError: If partial-volume maps are used and a label is not one
                of ``self.pv_label_keys``.
            ValueError: If a range is malformed (see
                :py:meth:`parse_gaussian_parameter`).
        """
        if parameters is None:
            parameters = {}

        # 0-d tensor keys (e.g. from ``tensor.unique()``) hash by identity and
        # would never match the integer labels looked up in get_params
        parameters = {
            (key.item() if isinstance(key, torch.Tensor) else key): dictionary
            for key, dictionary in parameters.items()
        }

        if self.pv_label_keys is not None:
            if not set(self.pv_label_keys).issuperset(parameters.keys()):
                message = (
                    f'Found keys in gaussian parameters {parameters.keys()} '
                    f'not in pv_label_keys {self.pv_label_keys}'
                )
                raise KeyError(message)

        parsed_parameters = {}

        for label_key, dictionary in parameters.items():
            if 'mean' in dictionary:
                mean = self.parse_gaussian_parameter(dictionary['mean'], 'mean')
            else:
                mean = self.default_mean
            if 'std' in dictionary:
                std = self.parse_gaussian_parameter(dictionary['std'], 'std')
            else:
                std = self.default_std
            parsed_parameters[label_key] = {'mean': mean, 'std': std}

        return parsed_parameters

    @staticmethod
    def parse_gaussian_parameter(
            nums_range: TypeRangeFloat,
            name: str,
            ) -> Tuple[float, float]:
        """Convert a number or a ``(min, max)`` pair into a range tuple.

        Args:
            nums_range: A single ``int``/``float`` (turned into ``(v, v)``) or
                a length-2 sequence ``(min, max)``.
            name: Parameter name, used in error messages.

        Returns:
            ``(min_value, max_value)``.

        Raises:
            ValueError: If the sequence length is not 2 or ``min > max``.
        """
        if isinstance(nums_range, (int, float)):
            return nums_range, nums_range

        if len(nums_range) != 2:
            raise ValueError(
                f'If {name} is a sequence,'
                f' it must be of len 2, not {nums_range}')
        min_value, max_value = nums_range
        if min_value > max_value:
            raise ValueError(
                f'If {name} is a sequence, the second value must be'
                f' equal or greater than the first, not {nums_range}')
        return min_value, max_value

    @staticmethod
    def parse_pv_label_maps(
            pv_label_keys: Sequence[str],
            sample: dict,
            ) -> Tuple[TypeData, TypeData]:
        """Stack the partial-volume label maps and check their affines.

        Args:
            pv_label_keys: Non-empty sequence of keys into ``sample``.
            sample: Mapping from key to image, each indexable by ``DATA`` and
                ``AFFINE``.

        Returns:
            ``(label_map, affine)``: the maps concatenated along dimension 0,
            and the affine of the first map.

        Raises:
            RuntimeError: If the maps cannot be concatenated (shape mismatch).
            RuntimeWarning: If the affines differ. Note that it is raised as
                an exception, not emitted through ``warnings``.
        """
        try:
            label_map = torch.cat(
                [sample[key][DATA] for key in pv_label_keys],
                dim=0)
        except RuntimeError as error:
            message = 'Partial-volume label maps have different shapes'
            raise RuntimeError(message) from error
        affine = sample[pv_label_keys[0]][AFFINE]
        for key in pv_label_keys[1:]:
            if not np.array_equal(affine, sample[key][AFFINE]):
                message = (
                    'Partial-volume label maps have different affine matrices'
                )
                raise RuntimeWarning(message)
        return label_map, affine

    def get_params(
            self,
            label: Union[str, TypeNumber]
            ) -> Tuple[float, float]:
        """Sample a mean and a standard deviation for ``label``.

        Both are drawn uniformly (with torch's RNG) from the label's ranges in
        ``self.gaussian_parameters``. If the label is absent, the default
        ranges are used.

        Returns:
            ``(mean, std)`` as Python floats.
        """
        if label in self.gaussian_parameters:
            mean_range = self.gaussian_parameters[label]['mean']
            std_range = self.gaussian_parameters[label]['std']
        else:
            mean_range, std_range = self.default_mean, self.default_std

        mean = torch.FloatTensor(1).uniform_(*mean_range).item()
        std = torch.FloatTensor(1).uniform_(*std_range).item()
        return mean, std

    @staticmethod
    def generate_tissue(
            data: TypeData,
            mean: float,
            std: float,
            ) -> TypeData:
        """Return Gaussian noise shaped like ``data``, weighted by ``data``.

        Args:
            data: Mask or partial-volume weights for one label.
            mean: Mean of the Gaussian.
            std: Standard deviation of the Gaussian.

        Returns:
            ``(randn(data.shape) * std + mean) * data``.
        """
        gaussian = torch.randn(data.shape) * std + mean
        return gaussian * data
279