Passed
Pull Request — master (#353)
by Fernando
01:11
created

torchio.transforms.augmentation.intensity.random_labels_to_image (rating: B)

Complexity

    Total Complexity    51

Size/Duplication

    Total Lines         416
    Duplicated Lines    0 %

Importance

    Changes             0

Metric    Value
eloc      219
dl        0
loc       416
rs        7.92
c         0
b         0
f         0
wmc       51

9 Methods

Rating   Name                                               Duplication   Size   Complexity
A        RandomLabelsToImage.__init__()                     0             22     1
A        RandomLabelsToImage.parse_mean_and_std()           0             16     5
A        RandomLabelsToImage.parse_gaussian_parameter()     0             18     4
A        RandomLabelsToImage.parse_gaussian_parameters()    0             17     2
D        LabelsToImage.apply_transform()                    0             59     12
A        RandomLabelsToImage.get_params()                   0             12     3
A        LabelsToImage.__init__()                           0             23     1
C        RandomLabelsToImage.apply_transform()              0             52     10
A        LabelsToImage.generate_tissue()                    0             10     1

3 Functions

Rating   Name                            Duplication   Size   Complexity
A        _parse_label_key()              0             5      3
A        _check_mean_and_std_length()    0             22     5
A        _parse_used_labels()            0             12     4

How to fix: Complexity

Complex classes like torchio.transforms.augmentation.intensity.random_labels_to_image often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields and methods that share the same prefixes or suffixes, as in the sketch below.

Once you have determined the fields and methods that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often the faster option.
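For illustration, here is a minimal sketch of that refactoring applied to this module: the parse_* helpers of RandomLabelsToImage are pulled into a dedicated parser class that the transform would delegate to. The class name GaussianParameterParser and the simplified TypeRangeFloat alias are hypothetical and not part of the torchio API.

    # Hypothetical Extract Class sketch: the parse_* helpers become their own
    # small component, and RandomLabelsToImage would delegate to an instance.
    from typing import List, Optional, Sequence, Tuple, Union

    TypeRangeFloat = Union[float, Tuple[float, float]]  # simplified stand-in


    class GaussianParameterParser:
        """Validates the mean/std ranges passed to RandomLabelsToImage."""

        def __init__(self, used_labels: Optional[Sequence[int]] = None):
            self.used_labels = used_labels

        def parse_mean_and_std(self, mean, std):
            # Same checks as RandomLabelsToImage.parse_mean_and_std, isolated here
            if mean is not None:
                mean = self.parse_params(mean, 'mean')
            if std is not None:
                std = self.parse_params(std, 'std')
            if mean is not None and std is not None and len(mean) != len(std):
                raise ValueError('"mean" and "std" must have the same length')
            return mean, std

        def parse_params(self, params, name) -> List[Tuple[float, float]]:
            parsed = [
                self.parse_param(p, f'{name}[{i}]') for i, p in enumerate(params)
            ]
            if self.used_labels is not None and len(parsed) != len(self.used_labels):
                raise ValueError(
                    f'"{name}" and "used_labels" must have the same length')
            return parsed

        @staticmethod
        def parse_param(nums_range, name) -> Tuple[float, float]:
            # A single number is treated as a degenerate (min == max) range
            if isinstance(nums_range, (int, float)):
                return nums_range, nums_range
            min_value, max_value = nums_range
            if min_value > max_value:
                raise ValueError(
                    f'{name} must be an increasing range, not {nums_range}')
            return min_value, max_value

With such a component, RandomLabelsToImage.__init__ would only keep something like self._parser = GaussianParameterParser(used_labels) and forward the parse_* calls to it, moving the three parse_* methods (complexity 5, 4 and 2 in the table above) out of the transform itself.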

from typing import Tuple, Optional, Sequence, List

import torch

from ....torchio import DATA, TypeData, TypeRangeFloat
from ....utils import check_sequence
from ....data.subject import Subject
from ....data.image import ScalarImage, LabelMap
from ... import IntensityTransform
from .. import RandomTransform


class RandomLabelsToImage(RandomTransform, IntensityTransform):
    r"""Randomly generate an image from a segmentation.

    Based on the works by Billot et al.: `A Learning Strategy for
    Contrast-agnostic MRI Segmentation <https://arxiv.org/abs/2003.01995>`_
    and `Partial Volume Segmentation of Brain MRI Scans of any Resolution and
    Contrast <https://arxiv.org/abs/2004.10221>`_.

    Args:
        label_key: String designating the label map in the subject
            that will be used to generate the new image.
        used_labels: Sequence of integers designating the labels used
            to generate the new image. If categorical encoding is used,
            :py:attr:`used_labels` refers to the values of the
            categorical encoding. If one-hot encoding or partial-volume
            label maps are used, :py:attr:`used_labels` refers to the
            channels of the label maps.
            Default uses all labels. Missing voxels will be filled with zero
            or with voxels from an already existing volume,
            see :py:attr:`image_key`.
        image_key: String designating the key to which the new volume will be
            saved. If this key corresponds to an already existing volume,
            missing voxels will be filled with the corresponding values
            in the original volume.
        mean: Sequence of means for each label.
            For each value :math:`v`, if a tuple :math:`(a, b)` is
            provided then :math:`v \sim \mathcal{U}(a, b)`.
            If ``None``, :py:attr:`default_mean` range will be used for every
            label.
            If not ``None`` and :py:attr:`used_labels` is not ``None``,
            :py:attr:`mean` and :py:attr:`used_labels` must have the
            same length.
        std: Sequence of standard deviations for each label.
            For each value :math:`v`, if a tuple :math:`(a, b)` is
            provided then :math:`v \sim \mathcal{U}(a, b)`.
            If ``None``, :py:attr:`default_std` range will be used for every
            label.
            If not ``None`` and :py:attr:`used_labels` is not ``None``,
            :py:attr:`std` and :py:attr:`used_labels` must have the
            same length.
        default_mean: Default mean range.
        default_std: Default standard deviation range.
        discretize: If ``True``, partial-volume label maps will be discretized.
            Does not have any effect if not using partial-volume label maps.
            Discretization is done taking the class of the highest value per
            voxel in the different partial-volume label maps using
            :py:func:`torch.argmax()` on the channel dimension (i.e. 0).
        p: Probability that this transform will be applied.
        keys: See :py:class:`~torchio.transforms.Transform`.

    .. note:: It is recommended to blur the new images to make the result more
        realistic. See
        :py:class:`~torchio.transforms.augmentation.RandomBlur`.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.ICBM2009CNonlinearSymmetric()
        >>> # Using the default parameters
        >>> transform = tio.RandomLabelsToImage(label_key='tissues')
        >>> # Using custom mean and std
        >>> transform = tio.RandomLabelsToImage(
        ...     label_key='tissues', mean=[0.33, 0.66, 1.], std=[0, 0, 0]
        ... )
        >>> # Discretizing the partial volume maps and blurring the result
        >>> simulation_transform = tio.RandomLabelsToImage(
        ...     label_key='tissues', mean=[0.33, 0.66, 1.], std=[0, 0, 0], discretize=True
        ... )
        >>> blurring_transform = tio.RandomBlur(std=0.3)
        >>> transform = tio.Compose([simulation_transform, blurring_transform])
        >>> transformed = transform(subject)  # subject has a new key 'image_from_labels' with the simulated image
        >>> # Filling holes of the simulated image with the original T1 image
        >>> rescale_transform = tio.RescaleIntensity((0, 1), (1, 99))   # Rescale intensity before filling holes
        >>> simulation_transform = tio.RandomLabelsToImage(
        ...     label_key='tissues',
        ...     image_key='t1',
        ...     used_labels=[0, 1]
        ... )
        >>> transform = tio.Compose([rescale_transform, simulation_transform])
        >>> transformed = transform(subject)  # subject's key 't1' has been replaced with the simulated image
    """
    def __init__(
            self,
            label_key: Optional[str] = None,
            used_labels: Optional[Sequence[int]] = None,
            image_key: str = 'image_from_labels',
            mean: Optional[Sequence[TypeRangeFloat]] = None,
            std: Optional[Sequence[TypeRangeFloat]] = None,
            default_mean: TypeRangeFloat = (0.1, 0.9),
            default_std: TypeRangeFloat = (0.01, 0.1),
            discretize: bool = False,
            p: float = 1,
            keys: Optional[Sequence[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        self.label_key = _parse_label_key(label_key)
        self.used_labels = _parse_used_labels(used_labels)
        self.mean, self.std = self.parse_mean_and_std(mean, std)
        self.default_mean = self.parse_gaussian_parameter(
            default_mean, 'default_mean')
        self.default_std = self.parse_gaussian_parameter(
            default_std, 'default_std')
        self.image_key = image_key
        self.discretize = discretize

    def parse_mean_and_std(
            self,
            mean: Sequence[TypeRangeFloat],
            std: Sequence[TypeRangeFloat]
            ) -> Tuple[List[TypeRangeFloat], List[TypeRangeFloat]]:
        if mean is not None:
            mean = self.parse_gaussian_parameters(mean, 'mean')
        if std is not None:
            std = self.parse_gaussian_parameters(std, 'std')
        if mean is not None and std is not None:
            message = (
                'If both "mean" and "std" are defined they must have the'
                ' same length'
            )
            assert len(mean) == len(std), message
        return mean, std

    def parse_gaussian_parameters(
            self,
            params: Sequence[TypeRangeFloat],
            name: str
            ) -> List[TypeRangeFloat]:
        check_sequence(params, name)
        params = [
            self.parse_gaussian_parameter(p, f'{name}[{i}]')
            for i, p in enumerate(params)
        ]
        if self.used_labels is not None:
            message = (
                f'If both "{name}" and "used_labels" are defined, '
                f'they must have the same length'
            )
            assert len(params) == len(self.used_labels), message
        return params

    @staticmethod
    def parse_gaussian_parameter(
            nums_range: TypeRangeFloat,
            name: str,
            ) -> Tuple[float, float]:
        if isinstance(nums_range, (int, float)):
            return nums_range, nums_range

        if len(nums_range) != 2:
            raise ValueError(
                f'If {name} is a sequence,'
                f' it must be of length 2, not {nums_range}')
        min_value, max_value = nums_range
        if min_value > max_value:
            raise ValueError(
                f'If {name} is a sequence, the second value must be'
                f' greater than or equal to the first, not {nums_range}')
        return min_value, max_value

    def apply_transform(self, subject: Subject) -> Subject:
        if self.label_key is None:
            iterable = subject.get_images_dict(intensity_only=False).items()
            for name, image in iterable:
                if isinstance(image, LabelMap):
                    self.label_key = name
                    break
            else:
                raise RuntimeError(f'No label maps found in subject: {subject}')

        arguments = dict(
            label_key=self.label_key,
            mean=[],
            std=[],
            image_key=self.image_key,
            used_labels=self.used_labels,
            discretize=self.discretize,
        )

        label_map = subject[self.label_key][DATA]

        # Find out if we face a partial-volume image or a label map.
        # A one-hot-encoded label map is considered a partial-volume image
        all_discrete = label_map.eq(label_map.round()).all()
        same_num_dims = label_map.squeeze().dim() < label_map.dim()
        is_discretized = all_discrete and same_num_dims

        if not is_discretized and self.discretize:
            # Take label with highest value in voxel
            max_label, label_map = label_map.max(dim=0, keepdim=True)
            # Remove values where all labels are 0 (i.e. missing labels)
            label_map[max_label == 0] = -1
            is_discretized = True

        if is_discretized:
            labels = label_map.unique().long().tolist()
            if -1 in labels:
                labels.remove(-1)
        else:
            labels = range(label_map.shape[0])

        # Raise error if mean and std are not defined for every label
        _check_mean_and_std_length(labels, self.mean, self.std)

        for label in labels:
            mean, std = self.get_params(label)
            arguments['mean'].append(mean)
            arguments['std'].append(std)

        transform = LabelsToImage(**arguments)
        transformed = transform(subject)
        return transformed

    def get_params(self, label: int) -> Tuple[float, float]:
        if self.mean is None:
            mean_range = self.default_mean
        else:
            mean_range = self.mean[label]
        if self.std is None:
            std_range = self.default_std
        else:
            std_range = self.std[label]
        mean = self.sample_uniform(*mean_range).item()
        std = self.sample_uniform(*std_range).item()
        return mean, std


class LabelsToImage(IntensityTransform):
    r"""Generate an image from a segmentation.

    Args:
        label_key: String designating the label map in the subject
            that will be used to generate the new image.
        used_labels: Sequence of integers designating the labels used
            to generate the new image. If categorical encoding is used,
            :py:attr:`used_labels` refers to the values of the
            categorical encoding. If one-hot encoding or partial-volume
            label maps are used, :py:attr:`used_labels` refers to the
            channels of the label maps.
            Default uses all labels. Missing voxels will be filled with zero
            or with voxels from an already existing volume,
            see :py:attr:`image_key`.
        image_key: String designating the key to which the new volume will be
            saved. If this key corresponds to an already existing volume,
            missing voxels will be filled with the corresponding values
            in the original volume.
        mean: Sequence of means for each label.
            If not ``None`` and :py:attr:`used_labels` is not ``None``,
            :py:attr:`mean` and :py:attr:`used_labels` must have the
            same length.
        std: Sequence of standard deviations for each label.
            If not ``None`` and :py:attr:`used_labels` is not ``None``,
            :py:attr:`std` and :py:attr:`used_labels` must have the
            same length.
        discretize: If ``True``, partial-volume label maps will be discretized.
            Does not have any effect if not using partial-volume label maps.
            Discretization is done taking the class of the highest value per
            voxel in the different partial-volume label maps using
            :py:func:`torch.argmax()` on the channel dimension (i.e. 0).
        keys: See :py:class:`~torchio.transforms.Transform`.

    .. note:: It is recommended to blur the new images to make the result more
        realistic. See
        :py:class:`~torchio.transforms.augmentation.RandomBlur`.
    """
    def __init__(
            self,
            label_key: str,
            mean: Optional[Sequence[float]],
            std: Optional[Sequence[float]],
            image_key: str = 'image_from_labels',
            used_labels: Optional[Sequence[int]] = None,
            discretize: bool = False,
            keys: Optional[List[str]] = None,
            ):
        super().__init__(keys=keys)
        self.label_key = _parse_label_key(label_key)
        self.used_labels = _parse_used_labels(used_labels)
        self.mean, self.std = mean, std
        self.image_key = image_key
        self.discretize = discretize
        self.args_names = (
            'label_key',
            'mean',
            'std',
            'image_key',
            'used_labels',
            'discretize',
        )

    def apply_transform(self, subject: Subject) -> Subject:
        original_image = subject.get(self.image_key)

        label_map_image = subject[self.label_key]
        label_map = label_map_image.data
        affine = label_map_image.affine

        # Find out if we face a partial-volume image or a label map.
        # A one-hot-encoded label map is considered a partial-volume image
        all_discrete = label_map.eq(label_map.round()).all()
        same_num_dims = label_map.squeeze().dim() < label_map.dim()
        is_discretized = all_discrete and same_num_dims

        if not is_discretized and self.discretize:
            # Take label with highest value in voxel
            max_label, label_map = label_map.max(dim=0, keepdim=True)
            # Remove values where all labels are 0 (i.e. missing labels)
            label_map[max_label == 0] = -1
            is_discretized = True

        tissues = torch.zeros(1, *label_map_image.spatial_shape).float()
        if is_discretized:
            labels = label_map.unique().long().tolist()
            if -1 in labels:
                labels.remove(-1)
        else:
            labels = range(label_map.shape[0])

        # Raise error if mean and std are not defined for every label
        _check_mean_and_std_length(labels, self.mean, self.std)

        for i, label in enumerate(labels):
            if self.used_labels is None or label in self.used_labels:
                mean = self.mean[i]
                std = self.std[i]
                if is_discretized:
                    mask = label_map == label
                else:
                    mask = label_map[label]
                tissues += self.generate_tissue(mask, mean, std)

            else:
                # Modify label map to easily compute background mask
                if is_discretized:
                    label_map[label_map == label] = -1
                else:
                    label_map[label] = 0

        final_image = ScalarImage(affine=affine, tensor=tissues)

        if original_image is not None:
            if is_discretized:
                bg_mask = label_map == -1
            else:
                bg_mask = label_map.sum(dim=0, keepdim=True) < 0.5
            final_image[DATA][bg_mask] = original_image[DATA][bg_mask]

        subject.add_image(final_image, self.image_key)
        return subject

    @staticmethod
    def generate_tissue(
            data: TypeData,
            mean: float,
            std: float,
            ) -> TypeData:
        # Create the simulated tissue using a gaussian random variable
        data_shape = data.shape
        gaussian = torch.randn(data_shape) * std + mean
        return gaussian * data


def _parse_label_key(label_key: Optional[str]) -> Optional[str]:
    if label_key is not None and not isinstance(label_key, str):
        message = f'"label_key" must be a string or None, not {type(label_key)}'
        raise TypeError(message)
    return label_key


def _parse_used_labels(used_labels: Sequence[int]) -> Sequence[int]:
    if used_labels is None:
        return None
    check_sequence(used_labels, 'used_labels')
    for e in used_labels:
        if not isinstance(e, int):
            message = (
                'Items in "used_labels" must be integers,'
                f' but some are not: {used_labels}'
            )
            raise ValueError(message)
    return used_labels


def _check_mean_and_std_length(
        labels: Sequence[int],
        means: Optional[Sequence[TypeRangeFloat]],
        stds: Optional[Sequence[TypeRangeFloat]],
        ) -> None:
    num_labels = len(labels)
    if means is not None:
        num_means = len(means)
        message = (
            '"mean" must define a value for each label but length of "mean"'
            f' is {num_means} while {num_labels} labels were found'
        )
        if num_means != num_labels:
            raise RuntimeError(message)
    if stds is not None:
        num_stds = len(stds)
        message = (
            '"std" must define a value for each label but length of "std"'
            f' is {num_stds} while {num_labels} labels were found'
        )
        if num_stds != num_labels:
            raise RuntimeError(message)