torchio.transforms.augmentation.intensity.random_labels_to_image   C

Complexity

Total Complexity 54

Size/Duplication

Total Lines 484
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 54
eloc 246
dl 0
loc 484
rs 6.4799
c 0
b 0
f 0

10 Methods

Rating   Name   Duplication   Size   Complexity  
F LabelsToImage.apply_transform() 0 67 14
A RandomLabelsToImage.parse_mean_and_std() 0 15 5
A RandomLabelsToImage.get_params() 0 13 3
A LabelsToImage.__init__() 0 26 1
A RandomLabelsToImage._guess_label_key() 0 10 5
A RandomLabelsToImage.__init__() 0 28 1
A RandomLabelsToImage.parse_gaussian_parameter() 0 19 4
A RandomLabelsToImage.parse_gaussian_parameters() 0 17 2
B RandomLabelsToImage.apply_transform() 0 51 6
A LabelsToImage.generate_tissue() 0 9 1

3 Functions

Rating   Name   Duplication   Size   Complexity  
A _parse_label_key() 0 5 3
A _check_mean_and_std_length() 0 22 5
A _parse_used_labels() 0 14 4

How to fix

Complexity

Complex classes like those in torchio.transforms.augmentation.intensity.random_labels_to_image often do many different things. To break such a class down, we need to identify a cohesive component within it. A common way to find such a component is to look for fields or methods that share the same prefix or suffix.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
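
In this module, the methods parse_mean_and_std(), parse_gaussian_parameter() and parse_gaussian_parameters() share the parse_ prefix, so they are one plausible candidate for Extract Class. The sketch below is only an illustration of that refactoring under stated assumptions: GaussianParameterParser, parse_one and parse_many are hypothetical names that do not exist in TorchIO, and the length check raises ValueError where the original code uses an assertion.

```python
from __future__ import annotations

from collections.abc import Sequence


class GaussianParameterParser:
    """Hypothetical class grouping the parse_* helpers (illustration only)."""

    def __init__(self, used_labels: Sequence[int] | None = None):
        self.used_labels = used_labels

    @staticmethod
    def parse_one(nums_range, name: str) -> tuple[float, float]:
        # A single number n is treated as the degenerate range (n, n)
        if isinstance(nums_range, (int, float)):
            return nums_range, nums_range
        if len(nums_range) != 2:
            message = f'If {name} is a sequence, it must be of len 2, not {nums_range}'
            raise ValueError(message)
        min_value, max_value = nums_range
        if min_value > max_value:
            message = (
                f'If {name} is a sequence, the second value must be'
                f' equal or greater than the first, not {nums_range}'
            )
            raise ValueError(message)
        return min_value, max_value

    def parse_many(self, params, name: str) -> list[tuple[float, float]]:
        # Parse every entry, then check the length against used_labels
        parsed = [
            self.parse_one(p, f'{name}[{i}]')
            for i, p in enumerate(params)
        ]
        if self.used_labels is not None and len(parsed) != len(self.used_labels):
            message = (
                f'If both "{name}" and "used_labels" are defined, '
                'they must have the same length'
            )
            raise ValueError(message)
        return parsed


parser = GaussianParameterParser(used_labels=[0, 1, 2])
means = parser.parse_many([0.33, (0.5, 0.7), 1.0], 'mean')
print(means)
```

With such a class in place, the transform would keep a single parser attribute and delegate to it, which would remove three methods from RandomLabelsToImage without changing its public arguments.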

from __future__ import annotations

from collections.abc import Sequence

import torch

from ....data.image import LabelMap
from ....data.image import ScalarImage
from ....data.subject import Subject
from ....types import TypeData
from ....types import TypeRangeFloat
from ....utils import check_sequence
from ...intensity_transform import IntensityTransform
from .. import RandomTransform


class RandomLabelsToImage(RandomTransform, IntensityTransform):
    r"""Randomly generate an image from a segmentation.

    Based on the work by Billot et al.: `A Learning Strategy for Contrast-agnostic MRI Segmentation`_
    and `Partial Volume Segmentation of Brain MRI Scans of any Resolution and Contrast`_.

    .. _A Learning Strategy for Contrast-agnostic MRI Segmentation: http://proceedings.mlr.press/v121/billot20a.html

    .. _Partial Volume Segmentation of Brain MRI Scans of any Resolution and Contrast: https://link.springer.com/chapter/10.1007/978-3-030-59728-3_18

    .. plot::

        import torch
        import torchio as tio
        torch.manual_seed(42)
        colin = tio.datasets.Colin27(2008)
        label_map = colin.cls
        colin.remove_image('t1')
        colin.remove_image('t2')
        colin.remove_image('pd')
        downsample = tio.Resample(1)
        blurring_transform = tio.RandomBlur(std=0.6)
        create_synthetic_image = tio.RandomLabelsToImage(
            image_key='synthetic',
            ignore_background=True,
        )
        transform = tio.Compose((
            downsample,
            create_synthetic_image,
            blurring_transform,
        ))
        colin_synth = transform(colin)
        colin_synth.plot()

    Args:
        label_key: String designating the label map in the subject
            that will be used to generate the new image.
        used_labels: Sequence of integers designating the labels used
            to generate the new image. If categorical encoding is used,
            :attr:`used_labels` refers to the values of the
            categorical encoding. If one-hot encoding or partial-volume
            label maps are used, :attr:`used_labels` refers to the
            channels of the label maps.
            Default uses all labels. Missing voxels will be filled with zero
            or with voxels from an already existing volume,
            see :attr:`image_key`.
        image_key: String designating the key to which the new volume will be
            saved. If this key corresponds to an already existing volume,
            missing voxels will be filled with the corresponding values
            in the original volume.
        mean: Sequence of means for each label.
            For each value :math:`v`, if a tuple :math:`(a, b)` is
            provided then :math:`v \sim \mathcal{U}(a, b)`.
            If ``None``, :attr:`default_mean` range will be used for every
            label.
            If not ``None`` and :attr:`used_labels` is not ``None``,
            :attr:`mean` and :attr:`used_labels` must have the
            same length.
        std: Sequence of standard deviations for each label.
            For each value :math:`v`, if a tuple :math:`(a, b)` is
            provided then :math:`v \sim \mathcal{U}(a, b)`.
            If ``None``, :attr:`default_std` range will be used for every
            label.
            If not ``None`` and :attr:`used_labels` is not ``None``,
            :attr:`std` and :attr:`used_labels` must have the
            same length.
        default_mean: Default mean range.
        default_std: Default standard deviation range.
        discretize: If ``True``, partial-volume label maps will be discretized.
            Has no effect if the label map is not a partial-volume map.
            Discretization is done taking the class of the highest value per
            voxel in the different partial-volume label maps using
            :func:`torch.argmax()` on the channel dimension (i.e. 0).
        ignore_background: If ``True``, input voxels labeled as ``0`` will not
            be modified.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    .. tip:: It is recommended to blur the new images in order to simulate
        partial volume effects at the borders of the synthetic structures. See
        :class:`~torchio.transforms.augmentation.intensity.random_blur.RandomBlur`.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.ICBM2009CNonlinearSymmetric()
        >>> # Using the default parameters
        >>> transform = tio.RandomLabelsToImage(label_key='tissues')
        >>> # Using custom mean and std
        >>> transform = tio.RandomLabelsToImage(
        ...     label_key='tissues', mean=[0.33, 0.66, 1.], std=[0, 0, 0]
        ... )
        >>> # Discretizing the partial volume maps and blurring the result
        >>> simulation_transform = tio.RandomLabelsToImage(
        ...     label_key='tissues', mean=[0.33, 0.66, 1.], std=[0, 0, 0], discretize=True
        ... )
        >>> blurring_transform = tio.RandomBlur(std=0.3)
        >>> transform = tio.Compose([simulation_transform, blurring_transform])
        >>> transformed = transform(subject)  # subject has a new key 'image_from_labels' with the simulated image
        >>> # Filling holes of the simulated image with the original T1 image
        >>> rescale_transform = tio.RescaleIntensity(
        ...     out_min_max=(0, 1), percentiles=(1, 99))   # Rescale intensity before filling holes
        >>> simulation_transform = tio.RandomLabelsToImage(
        ...     label_key='tissues',
        ...     image_key='t1',
        ...     used_labels=[0, 1]
        ... )
        >>> transform = tio.Compose([rescale_transform, simulation_transform])
        >>> transformed = transform(subject)  # subject's key 't1' has been replaced with the simulated image

    .. seealso:: :class:`~torchio.transforms.preprocessing.label.remap_labels.RemapLabels`.
    """

    def __init__(
        self,
        label_key: str | None = None,
        used_labels: Sequence[int] | None = None,
        image_key: str = 'image_from_labels',
        mean: Sequence[TypeRangeFloat] | None = None,
        std: Sequence[TypeRangeFloat] | None = None,
        default_mean: TypeRangeFloat = (0.1, 0.9),
        default_std: TypeRangeFloat = (0.01, 0.1),
        discretize: bool = False,
        ignore_background: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.label_key = _parse_label_key(label_key)
        self.used_labels = _parse_used_labels(used_labels)  # type: ignore[arg-type]
        self.mean, self.std = self.parse_mean_and_std(mean, std)  # type: ignore[arg-type,assignment]
        self.default_mean = self.parse_gaussian_parameter(
            default_mean,
            'default_mean',
        )
        self.default_std = self.parse_gaussian_parameter(
            default_std,
            'default_std',
        )
        self.image_key = image_key
        self.discretize = discretize
        self.ignore_background = ignore_background

    def parse_mean_and_std(
        self,
        mean: Sequence[TypeRangeFloat],
        std: Sequence[TypeRangeFloat],
    ) -> tuple[list[TypeRangeFloat], list[TypeRangeFloat]]:
        if mean is not None:
            mean = self.parse_gaussian_parameters(mean, 'mean')
        if std is not None:
            std = self.parse_gaussian_parameters(std, 'std')
        if mean is not None and std is not None:
            message = (
                'If both "mean" and "std" are defined they must have the same length'
            )
            assert len(mean) == len(std), message
        return mean, std

    def parse_gaussian_parameters(
        self,
        params: Sequence[TypeRangeFloat],
        name: str,
    ) -> list[TypeRangeFloat]:
        check_sequence(params, name)
        params = [
            self.parse_gaussian_parameter(p, f'{name}[{i}]')
            for i, p in enumerate(params)
        ]
        if self.used_labels is not None:
            message = (
                f'If both "{name}" and "used_labels" are defined, '
                'they must have the same length'
            )
            assert len(params) == len(self.used_labels), message
        return params

    @staticmethod
    def parse_gaussian_parameter(
        nums_range: TypeRangeFloat,
        name: str,
    ) -> tuple[float, float]:
        if isinstance(nums_range, (int, float)):
            return nums_range, nums_range

        if len(nums_range) != 2:
            raise ValueError(
                f'If {name} is a sequence, it must be of len 2, not {nums_range}',
            )
        min_value, max_value = nums_range
        if min_value > max_value:
            raise ValueError(
                f'If {name} is a sequence, the second value must be'
                f' equal or greater than the first, not {nums_range}',
            )
        return min_value, max_value

    def _guess_label_key(self, subject: Subject) -> None:
        if self.label_key is None:
            iterable = subject.get_images_dict(intensity_only=False).items()
            for name, image in iterable:
                if isinstance(image, LabelMap):
                    self.label_key = name
                    break
            else:
                message = f'No label maps found in subject: {subject}'
                raise RuntimeError(message)

    def apply_transform(self, subject: Subject) -> Subject:
        self._guess_label_key(subject)

        arguments = {
            'label_key': self.label_key,
            'mean': [],
            'std': [],
            'image_key': self.image_key,
            'used_labels': self.used_labels,
            'discretize': self.discretize,
            'ignore_background': self.ignore_background,
        }

        label_map = subject[self.label_key].data

        # Find out if we face a partial-volume image or a label map.
        # One-hot-encoded label map is considered as a partial-volume image
        all_discrete = label_map.eq(label_map.float().round()).all()
        # Despite its name, this is True when squeezing removes a dimension,
        # i.e. when the tensor has a singleton axis (e.g. a single channel)
        same_num_dims = label_map.squeeze().dim() < label_map.dim()
        is_discretized = all_discrete and same_num_dims

        if not is_discretized and self.discretize:
            # Take label with highest value in voxel
            max_label, label_map = label_map.max(dim=0, keepdim=True)
            # Remove values where all labels are 0 (i.e. missing labels)
            label_map[max_label == 0] = -1
            is_discretized = True

        if is_discretized:
            labels = label_map.unique().long().tolist()
            if -1 in labels:
                labels.remove(-1)
        else:
            labels = range(label_map.shape[0])

        # Raise error if mean and std are not defined for every label
        _check_mean_and_std_length(labels, self.mean, self.std)  # type: ignore[arg-type]

        for label in labels:
            mean, std = self.get_params(label)
            means = arguments['mean']
            stds = arguments['std']
            assert isinstance(means, list)
            assert isinstance(stds, list)
            means.append(mean)
            stds.append(std)

        transform = LabelsToImage(**self.add_base_args(arguments))
        transformed = transform(subject)
        assert isinstance(transformed, Subject)
        return transformed

    def get_params(self, label: int) -> tuple[float, float]:
        if self.mean is None:
            mean_range = self.default_mean
        else:
            assert isinstance(self.mean, Sequence)
            mean_range = self.mean[label]
        if self.std is None:
            std_range = self.default_std
        else:
            std_range = self.std[label]
        mean = self.sample_uniform(*mean_range)  # type: ignore[misc]
        std = self.sample_uniform(*std_range)  # type: ignore[misc]
        return mean, std


class LabelsToImage(IntensityTransform):
    r"""Generate an image from a segmentation.

    Args:
        label_key: String designating the label map in the subject
            that will be used to generate the new image.
        used_labels: Sequence of integers designating the labels used
            to generate the new image. If categorical encoding is used,
            :attr:`used_labels` refers to the values of the
            categorical encoding. If one-hot encoding or partial-volume
            label maps are used, :attr:`used_labels` refers to the
            channels of the label maps.
            Default uses all labels. Missing voxels will be filled with zero
            or with voxels from an already existing volume,
            see :attr:`image_key`.
        image_key: String designating the key to which the new volume will be
            saved. If this key corresponds to an already existing volume,
            missing voxels will be filled with the corresponding values
            in the original volume.
        mean: Sequence of means for each label.
            If not ``None`` and :attr:`used_labels` is not ``None``,
            :attr:`mean` and :attr:`used_labels` must have the
            same length.
        std: Sequence of standard deviations for each label.
            If not ``None`` and :attr:`used_labels` is not ``None``,
            :attr:`std` and :attr:`used_labels` must have the
            same length.
        discretize: If ``True``, partial-volume label maps will be discretized.
            Has no effect if the label map is not a partial-volume map.
            Discretization is done taking the class of the highest value per
            voxel in the different partial-volume label maps using
            :func:`torch.argmax()` on the channel dimension (i.e. 0).
        ignore_background: If ``True``, input voxels labeled as ``0`` will not
            be modified.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    .. note:: It is recommended to blur the new images to make the result more
        realistic. See
        :class:`~torchio.transforms.augmentation.RandomBlur`.
    """

    def __init__(
        self,
        label_key: str,
        mean: Sequence[float] | None,
        std: Sequence[float] | None,
        image_key: str = 'image_from_labels',
        used_labels: Sequence[int] | None = None,
        ignore_background: bool = False,
        discretize: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.label_key = _parse_label_key(label_key)
        self.used_labels = _parse_used_labels(used_labels)
        self.mean, self.std = mean, std  # type: ignore[assignment]
        self.image_key = image_key
        self.ignore_background = ignore_background
        self.discretize = discretize
        self.args_names = [
            'label_key',
            'mean',
            'std',
            'image_key',
            'used_labels',
            'ignore_background',
            'discretize',
        ]

    def apply_transform(self, subject: Subject) -> Subject:
        original_image = subject.get(self.image_key)

        label_map_image = subject[self.label_key]
        label_map = label_map_image.data
        affine = label_map_image.affine

        # Find out if we face a partial-volume image or a label map.
        # One-hot-encoded label map is considered as a partial-volume image
        all_discrete = label_map.eq(label_map.float().round()).all()
        # Despite its name, this is True when squeezing removes a dimension,
        # i.e. when the tensor has a singleton axis (e.g. a single channel)
        same_num_dims = label_map.squeeze().dim() < label_map.dim()
        is_discretized = all_discrete and same_num_dims

        if not is_discretized and self.discretize:
            # Take label with highest value in voxel
            max_label, label_map = label_map.max(dim=0, keepdim=True)
            # Remove values where all labels are 0 (i.e. missing labels)
            label_map[max_label == 0] = -1
            is_discretized = True

        tissues = torch.zeros(1, *label_map_image.spatial_shape).float()
        if is_discretized:
            labels_in_image = label_map.unique().long().tolist()
            if -1 in labels_in_image:
                labels_in_image.remove(-1)
        else:
            labels_in_image = range(label_map.shape[0])

        # Raise error if mean and std are not defined for every label
        _check_mean_and_std_length(
            labels_in_image,
            self.mean,  # type: ignore[arg-type]
            self.std,
        )

        for i, label in enumerate(labels_in_image):
            if label == 0 and self.ignore_background:
                continue
            if self.used_labels is None or label in self.used_labels:
                assert isinstance(self.mean, Sequence)
                assert isinstance(self.std, Sequence)
                mean = self.mean[i]
                std = self.std[i]
                if is_discretized:
                    mask = label_map == label
                else:
                    mask = label_map[label]
                tissues += self.generate_tissue(mask, mean, std)

            else:
                # Modify label map to easily compute background mask
                if is_discretized:
                    label_map[label_map == label] = -1
                else:
                    label_map[label] = 0

        final_image = ScalarImage(affine=affine, tensor=tissues)

        if original_image is not None:
            if is_discretized:
                bg_mask = label_map == -1
            else:
                bg_mask = label_map.sum(dim=0, keepdim=True) < 0.5
            final_image.data[bg_mask] = original_image.data[bg_mask].float()

        subject.add_image(final_image, self.image_key)
        return subject

    @staticmethod
    def generate_tissue(
        data: TypeData,
        mean: float,
        std: float,
    ) -> TypeData:
        # Create the simulated tissue using a gaussian random variable
        gaussian = torch.randn(data.shape) * std + mean
        return gaussian * data


def _parse_label_key(label_key: str | None) -> str | None:
    if label_key is not None and not isinstance(label_key, str):
        message = f'"label_key" must be a string or None, not {type(label_key)}'
        raise TypeError(message)
    return label_key


def _parse_used_labels(
    used_labels: Sequence[int] | None,
) -> Sequence[int] | None:
    if used_labels is None:
        return None
    check_sequence(used_labels, 'used_labels')
    for e in used_labels:
        if not isinstance(e, int):
            message = (
                'Items in "used_labels" must be integers,'
                f' but some are not: {used_labels}'
            )
            raise ValueError(message)
    return used_labels


def _check_mean_and_std_length(
    labels: Sequence[int],
    means: Sequence[TypeRangeFloat] | None,
    stds: Sequence[TypeRangeFloat] | None,
) -> None:
    num_labels = len(labels)
    if means is not None:
        num_means = len(means)
        message = (
            '"mean" must define a value for each label but length of "mean"'
            f' is {num_means} while {num_labels} labels were found'
        )
        if num_means != num_labels:
            raise RuntimeError(message)
    if stds is not None:
        num_stds = len(stds)
        message = (
            '"std" must define a value for each label but length of "std"'
            f' is {num_stds} while {num_labels} labels were found'
        )
        if num_stds != num_labels:
            raise RuntimeError(message)
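
The sampling scheme described in the RandomLabelsToImage docstring (a mean and a standard deviation drawn per label from U(a, b), then Gaussian intensities for the voxels of that label) can be sketched without torch or torchio. The following is a simplified pure-Python illustration on a flat list of labels, not the library's implementation; synthesize, mean_ranges, std_ranges and seed are hypothetical names introduced here.

```python
import random


def synthesize(labels, mean_ranges, std_ranges, seed=0):
    """Toy version: sample (mean, std) once per label, then one Gaussian value per voxel."""
    rng = random.Random(seed)
    params = {}
    for label in sorted(set(labels)):
        # v ~ U(a, b) for both the mean and the standard deviation
        mean = rng.uniform(*mean_ranges[label])
        std = rng.uniform(*std_ranges[label])
        params[label] = (mean, std)
    # Each voxel gets an intensity drawn from its label's Gaussian
    return [rng.gauss(*params[label]) for label in labels]


labels = [0, 0, 1, 2, 2, 1]
mean_ranges = {0: (0.33, 0.33), 1: (0.66, 0.66), 2: (1.0, 1.0)}
std_ranges = {0: (0.0, 0.0), 1: (0.0, 0.0), 2: (0.0, 0.0)}
image = synthesize(labels, mean_ranges, std_ranges)
print(image)
```

With degenerate ranges and zero standard deviation, as in the mean=[0.33, 0.66, 1.], std=[0, 0, 0] docstring example, every voxel simply takes the mean of its label; widening the ranges makes each call produce a different contrast.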