Passed
Pull Request — master (#353)
by Fernando
01:10
created

torchio.transforms.augmentation.intensity.random_labels_to_image   C

Complexity

Total Complexity 54

Size/Duplication

Total Lines 416
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 218
dl 0
loc 416
rs 6.4799
c 0
b 0
f 0
wmc 54

11 Methods

Rating   Name   Duplication   Size   Complexity  
A RandomLabelsToImage.__init__() 0 22 1
D LabelsToImage.apply_transform() 0 56 12
A RandomLabelsToImage.parse_mean_and_std() 0 16 5
A RandomLabelsToImage.get_params() 0 13 3
A LabelsToImage.__init__() 0 23 1
A RandomLabelsToImage.parse_gaussian_parameter() 0 18 4
A LabelsToImage.get_params() 0 10 3
A RandomLabelsToImage.parse_gaussian_parameters() 0 17 2
D RandomLabelsToImage.apply_transform() 0 53 12
A LabelsToImage.generate_tissue() 0 10 1
A RandomLabelsToImage.check_mean_and_std_length() 0 13 3

2 Functions

Rating   Name   Duplication   Size   Complexity  
A _parse_label_key() 0 5 3
A _parse_used_labels() 0 12 4

How to fix: Complexity

Complex classes like those in torchio.transforms.augmentation.intensity.random_labels_to_image often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common way to find such a component is to look for fields or methods that share a prefix or suffix. Here, for example, parse_mean_and_std(), parse_gaussian_parameter() and parse_gaussian_parameters() all share the parse_ prefix.
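The prefix heuristic can be tried mechanically on the RandomLabelsToImage method names listed in the table above. The following is only a rough sketch: the helper name `group_by_prefix` and the rule "text before the first underscore" are assumptions made for illustration, not something the analysis tool is known to do.

```python
from collections import defaultdict

# Method names of RandomLabelsToImage, copied from the metrics table.
METHODS = [
    '__init__',
    'apply_transform',
    'parse_mean_and_std',
    'get_params',
    'parse_gaussian_parameter',
    'parse_gaussian_parameters',
    'check_mean_and_std_length',
]


def group_by_prefix(names):
    """Group method names by the text before their first underscore.

    Hypothetical helper: the grouping rule is an assumption.
    """
    groups = defaultdict(list)
    for name in names:
        if name.startswith('__'):
            continue  # skip dunder methods such as __init__
        prefix = name.split('_', 1)[0]
        groups[prefix].append(name)
    return dict(groups)


groups = group_by_prefix(METHODS)
# Prefixes shared by more than one method hint at a cohesive component.
candidates = {p: n for p, n in groups.items() if len(n) > 1}
print(candidates)
```

Only the `parse` group contains more than one method, which suggests the parameter-parsing logic as a first candidate for extraction.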

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
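As an illustration of Extract Class, the three parse_* methods could be moved into a small helper of their own. The sketch below is hypothetical: the class name `GaussianParameterParser`, the method names `parse_one` and `parse_many`, and the `RangeFloat` alias are all assumptions. It also raises `ValueError` where the listing uses `assert`, and it omits the `check_sequence` call so that it runs without torchio.

```python
from typing import List, Optional, Sequence, Tuple, Union

# Stand-in for torchio's TypeRangeFloat (assumed: a number or a pair).
RangeFloat = Union[float, Tuple[float, float]]


class GaussianParameterParser:
    """Hypothetical extracted class holding the parse_* logic."""

    def __init__(self, used_labels: Optional[Sequence[int]] = None):
        self.used_labels = used_labels

    @staticmethod
    def parse_one(nums_range: RangeFloat, name: str) -> Tuple[float, float]:
        # A single number v becomes the degenerate range (v, v).
        if isinstance(nums_range, (int, float)):
            return nums_range, nums_range
        if len(nums_range) != 2:
            raise ValueError(
                f'If {name} is a sequence, it must be of len 2,'
                f' not {nums_range}')
        min_value, max_value = nums_range
        if min_value > max_value:
            raise ValueError(
                f'If {name} is a sequence, the second value must be'
                f' equal or greater than the first, not {nums_range}')
        return min_value, max_value

    def parse_many(
            self,
            params: Sequence[RangeFloat],
            name: str,
            ) -> List[Tuple[float, float]]:
        parsed = [
            self.parse_one(p, f'{name}[{i}]')
            for i, p in enumerate(params)
        ]
        if self.used_labels is not None:
            if len(parsed) != len(self.used_labels):
                raise ValueError(
                    f'If both "{name}" and "used_labels" are defined,'
                    ' they must have the same length')
        return parsed

    def parse_mean_and_std(self, mean, std):
        if mean is not None:
            mean = self.parse_many(mean, 'mean')
        if std is not None:
            std = self.parse_many(std, 'std')
        if mean is not None and std is not None and len(mean) != len(std):
            raise ValueError(
                'If both "mean" and "std" are defined,'
                ' they must have the same length')
        return mean, std


parser = GaussianParameterParser(used_labels=[0, 1, 2])
mean, std = parser.parse_mean_and_std([0.33, (0.5, 0.7), 1.0], [0, 0, 0])
```

With a helper like this, the transform's constructor would only need to build the parser and store its results, which removes three methods from the class being measured.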

from typing import Tuple, Optional, Sequence, List

import torch

from ....torchio import DATA, TypeData, TypeRangeFloat
from ....utils import check_sequence
from ....data.subject import Subject
from ....data.image import ScalarImage, LabelMap
from ... import IntensityTransform
from .. import RandomTransform


class RandomLabelsToImage(RandomTransform, IntensityTransform):
    r"""Randomly generate an image from a segmentation.

    Based on the works by Billot et al.: `A Learning Strategy for
    Contrast-agnostic MRI Segmentation <https://arxiv.org/abs/2003.01995>`_
    and `Partial Volume Segmentation of Brain MRI Scans of any Resolution and
    Contrast <https://arxiv.org/abs/2004.10221>`_.

    Args:
        label_key: String designating the label map in the subject
            that will be used to generate the new image.
        used_labels: Sequence of integers designating the labels used
            to generate the new image. If categorical encoding is used,
            :py:attr:`used_labels` refers to the values of the
            categorical encoding. If one hot encoding or partial-volume
            label maps are used, :py:attr:`used_labels` refers to the
            channels of the label maps.
            Default uses all labels. Missing voxels will be filled with zero
            or with voxels from an already existing volume,
            see :py:attr:`image_key`.
        image_key: String designating the key to which the new volume will be
            saved. If this key corresponds to an already existing volume,
            missing voxels will be filled with the corresponding values
            in the original volume.
        mean: Sequence of means for each label.
            For each value :math:`v`, if a tuple :math:`(a, b)` is
            provided then :math:`v \sim \mathcal{U}(a, b)`.
            If ``None``, :py:attr:`default_mean` range will be used for every
            label.
            If not ``None`` and :py:attr:`used_labels` is not ``None``,
            :py:attr:`mean` and :py:attr:`used_labels` must have the
            same length.
        std: Sequence of standard deviations for each label.
            For each value :math:`v`, if a tuple :math:`(a, b)` is
            provided then :math:`v \sim \mathcal{U}(a, b)`.
            If ``None``, :py:attr:`default_std` range will be used for every
            label.
            If not ``None`` and :py:attr:`used_labels` is not ``None``,
            :py:attr:`std` and :py:attr:`used_labels` must have the
            same length.
        default_mean: Default mean range.
        default_std: Default standard deviation range.
        discretize: If ``True``, partial-volume label maps will be discretized.
            Does not have any effects if not using partial-volume label maps.
            Discretization is done taking the class of the highest value per
            voxel in the different partial-volume label maps using
            :py:func:`torch.argmax()` on the channel dimension (i.e. 0).
        p: Probability that this transform will be applied.
        keys: See :py:class:`~torchio.transforms.Transform`.

    .. note:: It is recommended to blur the new images to make the result more
        realistic. See
        :py:class:`~torchio.transforms.augmentation.RandomBlur`.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.ICBM2009CNonlinearSymmetric()
        >>> # Using the default parameters
        >>> transform = tio.RandomLabelsToImage(label_key='tissues')
        >>> # Using custom mean and std
        >>> transform = tio.RandomLabelsToImage(
        ...     label_key='tissues', mean=[0.33, 0.66, 1.], std=[0, 0, 0]
        ... )
        >>> # Discretizing the partial volume maps and blurring the result
        >>> simulation_transform = tio.RandomLabelsToImage(
        ...     label_key='tissues', mean=[0.33, 0.66, 1.], std=[0, 0, 0], discretize=True
        ... )
        >>> blurring_transform = tio.RandomBlur(std=0.3)
        >>> transform = tio.Compose([simulation_transform, blurring_transform])
        >>> transformed = transform(subject)  # subject has a new key 'image_from_labels' with the simulated image
        >>> # Filling holes of the simulated image with the original T1 image
        >>> rescale_transform = tio.RescaleIntensity((0, 1), (1, 99))   # Rescale intensity before filling holes
        >>> simulation_transform = tio.RandomLabelsToImage(
        ...     label_key='tissues',
        ...     image_key='t1',
        ...     used_labels=[0, 1]
        ... )
        >>> transform = tio.Compose([rescale_transform, simulation_transform])
        >>> transformed = transform(subject)  # subject's key 't1' has been replaced with the simulated image
    """
    def __init__(
            self,
            label_key: Optional[str] = None,
            used_labels: Optional[Sequence[int]] = None,
            image_key: str = 'image_from_labels',
            mean: Optional[Sequence[TypeRangeFloat]] = None,
            std: Optional[Sequence[TypeRangeFloat]] = None,
            default_mean: TypeRangeFloat = (0.1, 0.9),
            default_std: TypeRangeFloat = (0.01, 0.1),
            discretize: bool = False,
            p: float = 1,
            keys: Optional[Sequence[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        self.label_key = _parse_label_key(label_key)
        self.used_labels = _parse_used_labels(used_labels)
        self.mean, self.std = self.parse_mean_and_std(mean, std)
        self.default_mean = self.parse_gaussian_parameter(
            default_mean, 'default_mean')
        self.default_std = self.parse_gaussian_parameter(default_std, 'default_std')
        self.image_key = image_key
        self.discretize = discretize

    def parse_mean_and_std(
            self,
            mean: Sequence[TypeRangeFloat],
            std: Sequence[TypeRangeFloat]
            ) -> Tuple[List[TypeRangeFloat], List[TypeRangeFloat]]:
        if mean is not None:
            mean = self.parse_gaussian_parameters(mean, 'mean')
        if std is not None:
            std = self.parse_gaussian_parameters(std, 'std')
        if mean is not None and std is not None:
            message = (
                'If both "mean" and "std" are defined they must have the same '
                'length'
            )
            assert len(mean) == len(std), message
        return mean, std

    def parse_gaussian_parameters(
            self,
            params: Sequence[TypeRangeFloat],
            name: str
            ) -> List[TypeRangeFloat]:
        check_sequence(params, name)
        params = [
            self.parse_gaussian_parameter(p, f'{name}[{i}]')
            for i, p in enumerate(params)
        ]
        if self.used_labels is not None:
            message = (
                f'If both "{name}" and "used_labels" are defined, '
                'they must have the same length'
            )
            assert len(params) == len(self.used_labels), message
        return params

    @staticmethod
    def parse_gaussian_parameter(
            nums_range: TypeRangeFloat,
            name: str,
            ) -> Tuple[float, float]:
        if isinstance(nums_range, (int, float)):
            return nums_range, nums_range

        if len(nums_range) != 2:
            raise ValueError(
                f'If {name} is a sequence,'
                f' it must be of len 2, not {nums_range}')
        min_value, max_value = nums_range
        if min_value > max_value:
            raise ValueError(
                f'If {name} is a sequence, the second value must be'
                f' equal or greater than the first, not {nums_range}')
        return min_value, max_value

    def apply_transform(self, subject: Subject) -> Subject:
        if self.label_key is None:
            iterable = subject.get_images_dict(intensity_only=False).items()
            for name, image in iterable:
                if isinstance(image, LabelMap):
                    self.label_key = name
                    break
            else:
                raise RuntimeError(f'No label maps found in subject: {subject}')

        arguments = dict(
            label_key=self.label_key,
            mean=[],
            std=[],
            image_key=self.image_key,
            used_labels=self.used_labels,
            discretize=self.discretize,
        )

        label_map = subject[self.label_key][DATA]

        # Find out if we face a partial-volume image or a label map.
        # One-hot-encoded label map is considered as a partial-volume image
        all_discrete = label_map.eq(label_map.round()).all()
        same_num_dims = label_map.squeeze().dim() < label_map.dim()
        is_discretized = all_discrete and same_num_dims

        if not is_discretized and self.discretize:
            # Take label with highest value in voxel
            max_label, label_map = label_map.max(dim=0, keepdim=True)
            # Remove values where all labels are 0 (i.e. missing labels)
            label_map[max_label == 0] = -1
            is_discretized = True

        if is_discretized:
            labels = label_map.unique().long().tolist()
            if -1 in labels:
                labels.remove(-1)
        else:
            labels = range(label_map.shape[0])

        # Raise error if mean and std are not defined for every label
        self.check_mean_and_std_length(labels)

        for label in labels:
            if self.used_labels is None or label in self.used_labels:
                mean, std = self.get_params(label)
                arguments['mean'].append(mean)
                arguments['std'].append(std)

        transform = LabelsToImage(**arguments)
        transformed = transform(subject)
        return transformed

    def check_mean_and_std_length(self, labels: Sequence):
        if self.mean is not None:
            message = (
                '"mean" must define a value for each label but length of "mean"'
                f' is {len(self.mean)} while {len(labels)} labels were found'
            )
            assert len(self.mean) == len(labels), message
        if self.std is not None:
            message = (
                '"std" must define a value for each label but length of "std"'
                f' is {len(self.std)} while {len(labels)} labels were found'
            )
            assert len(self.std) == len(labels), message

    def get_params(self, label: int) -> Tuple[float, float]:
        if self.mean is not None:
            mean_range = self.mean[label]
        else:
            mean_range = self.default_mean
        if self.std is not None:
            std_range = self.std[label]
        else:
            std_range = self.default_std

        mean = self.sample_uniform(*mean_range).item()
        std = self.sample_uniform(*std_range).item()
        return mean, std


class LabelsToImage(IntensityTransform):
    r"""Generate an image from a segmentation.

    Args:
        label_key: String designating the label map in the subject
            that will be used to generate the new image.
        used_labels: Sequence of integers designating the labels used
            to generate the new image. If categorical encoding is used,
            :py:attr:`used_labels` refers to the values of the
            categorical encoding. If one hot encoding or partial-volume
            label maps are used, :py:attr:`used_labels` refers to the
            channels of the label maps.
            Default uses all labels. Missing voxels will be filled with zero
            or with voxels from an already existing volume,
            see :py:attr:`image_key`.
        image_key: String designating the key to which the new volume will be
            saved. If this key corresponds to an already existing volume,
            missing voxels will be filled with the corresponding values
            in the original volume.
        mean: Sequence of means for each label.
            If not ``None`` and :py:attr:`used_labels` is not ``None``,
            :py:attr:`mean` and :py:attr:`used_labels` must have the
            same length.
        std: Sequence of standard deviations for each label.
            If not ``None`` and :py:attr:`used_labels` is not ``None``,
            :py:attr:`std` and :py:attr:`used_labels` must have the
            same length.
        discretize: If ``True``, partial-volume label maps will be discretized.
            Does not have any effects if not using partial-volume label maps.
            Discretization is done taking the class of the highest value per
            voxel in the different partial-volume label maps using
            :py:func:`torch.argmax()` on the channel dimension (i.e. 0).
        seed: Seed for the random number generator.
        keys: See :py:class:`~torchio.transforms.Transform`.

    .. note:: It is recommended to blur the new images to make the result more
        realistic. See
        :py:class:`~torchio.transforms.augmentation.RandomBlur`.
    """
    def __init__(
            self,
            label_key: str,
            mean: Optional[Sequence[float]],
            std: Optional[Sequence[float]],
            image_key: str = 'image_from_labels',
            used_labels: Optional[Sequence[int]] = None,
            discretize: bool = False,
            keys: Optional[List[str]] = None,
            ):
        super().__init__(keys=keys)
        self.label_key = _parse_label_key(label_key)
        self.used_labels = _parse_used_labels(used_labels)
        self.mean, self.std = mean, std
        self.image_key = image_key
        self.discretize = discretize
        self.args_names = (
            'label_key',
            'mean',
            'std',
            'image_key',
            'used_labels',
            'discretize',
        )

    def apply_transform(self, subject: Subject) -> Subject:
        original_image = subject.get(self.image_key)

        label_map = subject[self.label_key].data
        affine = subject[self.label_key].affine

        spatial_shape = label_map.shape[1:]

        # Find out if we face a partial-volume image or a label map.
        # One-hot-encoded label map is considered as a partial-volume image
        all_discrete = label_map.eq(label_map.round()).all()
        same_num_dims = label_map.squeeze().dim() < label_map.dim()
        is_discretized = all_discrete and same_num_dims

        if not is_discretized and self.discretize:
            # Take label with highest value in voxel
            max_label, label_map = label_map.max(dim=0, keepdim=True)
            # Remove values where all labels are 0 (i.e. missing labels)
            label_map[max_label == 0] = -1
            is_discretized = True

        tissues = torch.zeros(1, *spatial_shape).float()
        if is_discretized:
            labels = label_map.unique().long().tolist()
            if -1 in labels:
                labels.remove(-1)
        else:
            labels = range(label_map.shape[0])

        for label in labels:
            if self.used_labels is None or label in self.used_labels:
                mean, std = self.get_params(label)
                if is_discretized:
                    mask = label_map == label
                else:
                    mask = label_map[label]
                tissues += self.generate_tissue(mask, mean, std)

            else:
                # Modify label map to easily compute background mask
                if is_discretized:
                    label_map[label_map == label] = -1
                else:
                    label_map[label] = 0

        final_image = ScalarImage(affine=affine, tensor=tissues)

        if original_image is not None:
            if is_discretized:
                bg_mask = label_map == -1
            else:
                bg_mask = label_map.sum(dim=0, keepdim=True) < 0.5
            final_image[DATA][bg_mask] = original_image[DATA][bg_mask]
            # Review issue (Comprehensibility, Best Practice):
            # The variable DATA does not seem to be defined.

        subject.add_image(final_image, self.image_key)
        return subject

    def get_params(self, label: int) -> Tuple[float, float]:
        if self.mean is not None:
            mean = self.mean[label]
        else:
            mean = self.default_mean
        if self.std is not None:
            std = self.std[label]
        else:
            std = self.default_std
        return mean, std

    @staticmethod
    def generate_tissue(
            data: TypeData,
            mean: float,
            std: float,
            ) -> TypeData:
        # Create the simulated tissue using a gaussian random variable
        data_shape = data.shape
        gaussian = torch.randn(data_shape) * std + mean
        return gaussian * data


def _parse_label_key(label_key: Optional[str]) -> Optional[str]:
    if label_key is not None and not isinstance(label_key, str):
        message = f'"label_key" must be a string or None, not {type(label_key)}'
        raise TypeError(message)
    return label_key


def _parse_used_labels(
        used_labels: Optional[Sequence[int]],
        ) -> Optional[Sequence[int]]:
    if used_labels is None:
        return None
    check_sequence(used_labels, 'used_labels')
    for e in used_labels:
        if not isinstance(e, int):
            message = (
                'Items in "used_labels" must be integers,'
                f' but some are not: {used_labels}'
            )
            raise ValueError(message)
    return used_labels
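The sampling scheme described in the docstrings (each mean and std drawn from a uniform range, then Gaussian intensities weighted by the label mask) can be illustrated without torch. This is a minimal sketch under stated assumptions: it uses the standard-library `random` module instead of `torch.randn` and `sample_uniform`, a flat list of floats instead of a tensor, and the function name `sample_params` is hypothetical.

```python
import random
from typing import List, Tuple


def sample_params(
        mean_range: Tuple[float, float],
        std_range: Tuple[float, float],
        rng: random.Random,
        ) -> Tuple[float, float]:
    """Draw mean and std uniformly from their (min, max) ranges."""
    mean = rng.uniform(*mean_range)
    std = rng.uniform(*std_range)
    return mean, std


def generate_tissue(
        mask: List[float],
        mean: float,
        std: float,
        rng: random.Random,
        ) -> List[float]:
    """Weight one Gaussian sample per voxel by the mask value."""
    return [rng.gauss(mean, std) * weight for weight in mask]


rng = random.Random(0)
# Flattened partial-volume mask: 0 outside the tissue, 1 fully inside.
mask = [0.0, 1.0, 0.5, 0.0]

# Ranges taken from the default_mean and default_std defaults above.
mean, std = sample_params((0.1, 0.9), (0.01, 0.1), rng)
tissue = generate_tissue(mask, mean, std, rng)

# With std == 0 the result is the mask scaled by the mean.
flat = generate_tissue(mask, 0.66, 0.0, rng)
```

Voxels where the mask is zero stay at zero, which is why the transform can later fill them from an existing image when `image_key` points to one.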