Completed
Pull Request — master (#353)
by Fernando
118:39 queued 117:31
created

torchio.transforms.augmentation.intensity.random_labels_to_image   B

Complexity

Total Complexity 51

Size/Duplication

Total Lines 417
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 219
dl 0
loc 417
rs 7.92
c 0
b 0
f 0
wmc 51

9 Methods

Rating   Name   Duplication   Size   Complexity  
D LabelsToImage.apply_transform() 0 59 12
A RandomLabelsToImage.parse_mean_and_std() 0 16 5
A RandomLabelsToImage.get_params() 0 12 3
A LabelsToImage.__init__() 0 23 1
A RandomLabelsToImage.__init__() 0 22 1
A RandomLabelsToImage.parse_gaussian_parameter() 0 18 4
A RandomLabelsToImage.parse_gaussian_parameters() 0 17 2
C RandomLabelsToImage.apply_transform() 0 52 10
A LabelsToImage.generate_tissue() 0 10 1

3 Functions

Rating   Name   Duplication   Size   Complexity  
A _parse_label_key() 0 5 3
A _check_mean_and_std_length() 0 22 5
A _parse_used_labels() 0 12 4

How to fix: Complexity

Complex classes like those in torchio.transforms.augmentation.intensity.random_labels_to_image often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields or methods that share the same prefixes or suffixes.
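As a rough illustration of the prefix heuristic, the sketch below groups the method names of RandomLabelsToImage (copied from the methods table above) by the text before their first underscore. The helper name group_by_prefix is hypothetical and only meant to show the idea, not a tool the review service provides.

```python
from collections import defaultdict

# Method names of RandomLabelsToImage, copied from the methods table.
METHODS = [
    "__init__",
    "parse_mean_and_std",
    "parse_gaussian_parameters",
    "parse_gaussian_parameter",
    "apply_transform",
    "get_params",
]


def group_by_prefix(names):
    """Group names by the text before their first underscore (hypothetical helper)."""
    groups = defaultdict(list)
    for name in names:
        # Strip surrounding underscores so that "__init__" becomes "init".
        prefix = name.strip("_").split("_", 1)[0]
        groups[prefix].append(name)
    # Only a prefix shared by several names hints at a cohesive component.
    return {prefix: found for prefix, found in groups.items() if len(found) > 1}


shared = group_by_prefix(METHODS)
print(shared)
```

Here the only shared prefix is parse, which suggests that the three parsing methods are one candidate component.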

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
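A minimal sketch of what Extract Class could look like for the parse_* methods of RandomLabelsToImage follows. The class name GaussianParameterParser, its method names and the RangeFloat alias are assumptions made for illustration; they are not part of torchio. The sketch also raises ValueError where the original code uses assert, and omits the check_sequence call so that it stays self-contained.

```python
from typing import List, Optional, Sequence, Tuple, Union

# Local stand-in for torchio's TypeRangeFloat (assumed: a number or a pair).
RangeFloat = Union[float, Tuple[float, float]]
ParsedRange = Tuple[float, float]


class GaussianParameterParser:
    """Hypothetical class holding only the parameter-parsing logic."""

    def __init__(self, used_labels: Optional[Sequence[int]] = None) -> None:
        self.used_labels = used_labels

    @staticmethod
    def parse_parameter(nums_range: RangeFloat, name: str) -> ParsedRange:
        # A single number is treated as the degenerate range (v, v).
        if isinstance(nums_range, (int, float)):
            return nums_range, nums_range
        if len(nums_range) != 2:
            raise ValueError(
                f'If {name} is a sequence, it must be of len 2, not {nums_range}')
        min_value, max_value = nums_range
        if min_value > max_value:
            raise ValueError(
                f'If {name} is a sequence, the second value must be'
                f' equal or greater than the first, not {nums_range}')
        return min_value, max_value

    def parse_parameters(
            self,
            params: Sequence[RangeFloat],
            name: str,
            ) -> List[ParsedRange]:
        parsed = [
            self.parse_parameter(p, f'{name}[{i}]')
            for i, p in enumerate(params)
        ]
        # One range per used label, when used labels are given.
        if self.used_labels is not None and len(parsed) != len(self.used_labels):
            raise ValueError(
                f'If both "{name}" and "used_labels" are defined,'
                ' they must have the same length')
        return parsed

    def parse_mean_and_std(self, mean, std):
        if mean is not None:
            mean = self.parse_parameters(mean, 'mean')
        if std is not None:
            std = self.parse_parameters(std, 'std')
        if mean is not None and std is not None and len(mean) != len(std):
            raise ValueError(
                'If both "mean" and "std" are defined'
                ' they must have the same length')
        return mean, std


parser = GaussianParameterParser(used_labels=[0, 1])
mean, std = parser.parse_mean_and_std(mean=[0.5, (0.1, 0.9)], std=None)
print(mean, std)
```

With such a class in place, the transform's constructor could delegate parsing to it, which would move roughly a third of the methods out of RandomLabelsToImage.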

from typing import Tuple, Optional, Sequence, List

import torch

from ....torchio import DATA, TypeData, TypeRangeFloat
from ....utils import check_sequence
from ....data.subject import Subject
from ....data.image import ScalarImage, LabelMap
from ... import IntensityTransform
from .. import RandomTransform


class RandomLabelsToImage(RandomTransform, IntensityTransform):
    r"""Randomly generate an image from a segmentation.

    Based on the works by Billot et al.: `A Learning Strategy for
    Contrast-agnostic MRI
    Segmentation <http://proceedings.mlr.press/v121/billot20a.html>`_
    and `Partial Volume Segmentation of Brain MRI Scans of any Resolution and
    Contrast <https://link.springer.com/chapter/10.1007/978-3-030-59728-3_18>`_.

    Args:
        label_key: String designating the label map in the subject
            that will be used to generate the new image.
        used_labels: Sequence of integers designating the labels used
            to generate the new image. If categorical encoding is used,
            :attr:`used_labels` refers to the values of the
            categorical encoding. If one hot encoding or partial-volume
            label maps are used, :attr:`used_labels` refers to the
            channels of the label maps.
            Default uses all labels. Missing voxels will be filled with zero
            or with voxels from an already existing volume,
            see :attr:`image_key`.
        image_key: String designating the key to which the new volume will be
            saved. If this key corresponds to an already existing volume,
            missing voxels will be filled with the corresponding values
            in the original volume.
        mean: Sequence of means for each label.
            For each value :math:`v`, if a tuple :math:`(a, b)` is
            provided then :math:`v \sim \mathcal{U}(a, b)`.
            If ``None``, :attr:`default_mean` range will be used for every
            label.
            If not ``None`` and :attr:`used_labels` is not ``None``,
            :attr:`mean` and :attr:`used_labels` must have the
            same length.
        std: Sequence of standard deviations for each label.
            For each value :math:`v`, if a tuple :math:`(a, b)` is
            provided then :math:`v \sim \mathcal{U}(a, b)`.
            If ``None``, :attr:`default_std` range will be used for every
            label.
            If not ``None`` and :attr:`used_labels` is not ``None``,
            :attr:`std` and :attr:`used_labels` must have the
            same length.
        default_mean: Default mean range.
        default_std: Default standard deviation range.
        discretize: If ``True``, partial-volume label maps will be discretized.
            Does not have any effects if not using partial-volume label maps.
            Discretization is done taking the class of the highest value per
            voxel in the different partial-volume label maps using
            :func:`torch.argmax()` on the channel dimension (i.e. 0).
        p: Probability that this transform will be applied.
        keys: See :class:`~torchio.transforms.Transform`.

    .. note:: It is recommended to blur the new images to make the result more
        realistic. See
        :class:`~torchio.transforms.augmentation.RandomBlur`.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.ICBM2009CNonlinearSymmetric()
        >>> # Using the default parameters
        >>> transform = tio.RandomLabelsToImage(label_key='tissues')
        >>> # Using custom mean and std
        >>> transform = tio.RandomLabelsToImage(
        ...     label_key='tissues', mean=[0.33, 0.66, 1.], std=[0, 0, 0]
        ... )
        >>> # Discretizing the partial volume maps and blurring the result
        >>> simulation_transform = tio.RandomLabelsToImage(
        ...     label_key='tissues', mean=[0.33, 0.66, 1.], std=[0, 0, 0], discretize=True
        ... )
        >>> blurring_transform = tio.RandomBlur(std=0.3)
        >>> transform = tio.Compose([simulation_transform, blurring_transform])
        >>> transformed = transform(subject)  # subject has a new key 'image_from_labels' with the simulated image
        >>> # Filling holes of the simulated image with the original T1 image
        >>> rescale_transform = tio.RescaleIntensity((0, 1), (1, 99))   # Rescale intensity before filling holes
        >>> simulation_transform = tio.RandomLabelsToImage(
        ...     label_key='tissues',
        ...     image_key='t1',
        ...     used_labels=[0, 1]
        ... )
        >>> transform = tio.Compose([rescale_transform, simulation_transform])
        >>> transformed = transform(subject)  # subject's key 't1' has been replaced with the simulated image
    """
    def __init__(
            self,
            label_key: Optional[str] = None,
            used_labels: Optional[Sequence[int]] = None,
            image_key: str = 'image_from_labels',
            mean: Optional[Sequence[TypeRangeFloat]] = None,
            std: Optional[Sequence[TypeRangeFloat]] = None,
            default_mean: TypeRangeFloat = (0.1, 0.9),
            default_std: TypeRangeFloat = (0.01, 0.1),
            discretize: bool = False,
            p: float = 1,
            keys: Optional[Sequence[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        self.label_key = _parse_label_key(label_key)
        self.used_labels = _parse_used_labels(used_labels)
        self.mean, self.std = self.parse_mean_and_std(mean, std)
        self.default_mean = self.parse_gaussian_parameter(
            default_mean, 'default_mean')
        self.default_std = self.parse_gaussian_parameter(default_std, 'default_std')
        self.image_key = image_key
        self.discretize = discretize

    def parse_mean_and_std(
            self,
            mean: Optional[Sequence[TypeRangeFloat]],
            std: Optional[Sequence[TypeRangeFloat]]
            ) -> Tuple[
                Optional[List[TypeRangeFloat]],
                Optional[List[TypeRangeFloat]],
            ]:
        if mean is not None:
            mean = self.parse_gaussian_parameters(mean, 'mean')
        if std is not None:
            std = self.parse_gaussian_parameters(std, 'std')
        if mean is not None and std is not None:
            message = (
                'If both "mean" and "std" are defined they must have the same '
                'length'
            )
            assert len(mean) == len(std), message
        return mean, std

    def parse_gaussian_parameters(
            self,
            params: Sequence[TypeRangeFloat],
            name: str
            ) -> List[TypeRangeFloat]:
        check_sequence(params, name)
        params = [
            self.parse_gaussian_parameter(p, f'{name}[{i}]')
            for i, p in enumerate(params)
        ]
        if self.used_labels is not None:
            message = (
                f'If both "{name}" and "used_labels" are defined, '
                'they must have the same length'
            )
            assert len(params) == len(self.used_labels), message
        return params

    @staticmethod
    def parse_gaussian_parameter(
            nums_range: TypeRangeFloat,
            name: str,
            ) -> Tuple[float, float]:
        if isinstance(nums_range, (int, float)):
            return nums_range, nums_range

        if len(nums_range) != 2:
            raise ValueError(
                f'If {name} is a sequence,'
                f' it must be of len 2, not {nums_range}')
        min_value, max_value = nums_range
        if min_value > max_value:
            raise ValueError(
                f'If {name} is a sequence, the second value must be'
                f' equal or greater than the first, not {nums_range}')
        return min_value, max_value

    def apply_transform(self, subject: Subject) -> Subject:
        if self.label_key is None:
            iterable = subject.get_images_dict(intensity_only=False).items()
            for name, image in iterable:
                if isinstance(image, LabelMap):
                    self.label_key = name
                    break
            else:
                raise RuntimeError(f'No label maps found in subject: {subject}')

        arguments = dict(
            label_key=self.label_key,
            mean=[],
            std=[],
            image_key=self.image_key,
            used_labels=self.used_labels,
            discretize=self.discretize,
        )

        label_map = subject[self.label_key][DATA]

        # Find out if we face a partial-volume image or a label map.
        # One-hot-encoded label map is considered as a partial-volume image
        all_discrete = label_map.eq(label_map.round()).all()
        # True when squeezing removes a dimension (e.g. a single channel)
        same_num_dims = label_map.squeeze().dim() < label_map.dim()
        is_discretized = all_discrete and same_num_dims

        if not is_discretized and self.discretize:
            # Take label with highest value in voxel
            max_label, label_map = label_map.max(dim=0, keepdim=True)
            # Remove values where all labels are 0 (i.e. missing labels)
            label_map[max_label == 0] = -1
            is_discretized = True

        if is_discretized:
            labels = label_map.unique().long().tolist()
            if -1 in labels:
                labels.remove(-1)
        else:
            labels = range(label_map.shape[0])

        # Raise error if mean and std are not defined for every label
        _check_mean_and_std_length(labels, self.mean, self.std)

        for label in labels:
            mean, std = self.get_params(label)
            arguments['mean'].append(mean)
            arguments['std'].append(std)

        transform = LabelsToImage(**arguments)
        transformed = transform(subject)
        return transformed

    def get_params(self, label: int) -> Tuple[float, float]:
        # Note: self.mean and self.std are indexed by the label value itself
        if self.mean is None:
            mean_range = self.default_mean
        else:
            mean_range = self.mean[label]
        if self.std is None:
            std_range = self.default_std
        else:
            std_range = self.std[label]
        mean = self.sample_uniform(*mean_range).item()
        std = self.sample_uniform(*std_range).item()
        return mean, std


class LabelsToImage(IntensityTransform):
    r"""Generate an image from a segmentation.

    Args:
        label_key: String designating the label map in the subject
            that will be used to generate the new image.
        used_labels: Sequence of integers designating the labels used
            to generate the new image. If categorical encoding is used,
            :attr:`used_labels` refers to the values of the
            categorical encoding. If one hot encoding or partial-volume
            label maps are used, :attr:`used_labels` refers to the
            channels of the label maps.
            Default uses all labels. Missing voxels will be filled with zero
            or with voxels from an already existing volume,
            see :attr:`image_key`.
        image_key: String designating the key to which the new volume will be
            saved. If this key corresponds to an already existing volume,
            missing voxels will be filled with the corresponding values
            in the original volume.
        mean: Sequence of means for each label.
            If not ``None`` and :attr:`used_labels` is not ``None``,
            :attr:`mean` and :attr:`used_labels` must have the
            same length.
        std: Sequence of standard deviations for each label.
            If not ``None`` and :attr:`used_labels` is not ``None``,
            :attr:`std` and :attr:`used_labels` must have the
            same length.
        discretize: If ``True``, partial-volume label maps will be discretized.
            Does not have any effects if not using partial-volume label maps.
            Discretization is done taking the class of the highest value per
            voxel in the different partial-volume label maps using
            :func:`torch.argmax()` on the channel dimension (i.e. 0).
        keys: See :class:`~torchio.transforms.Transform`.

    .. note:: It is recommended to blur the new images to make the result more
        realistic. See
        :class:`~torchio.transforms.augmentation.RandomBlur`.
    """
    def __init__(
            self,
            label_key: str,
            mean: Optional[Sequence[float]],
            std: Optional[Sequence[float]],
            image_key: str = 'image_from_labels',
            used_labels: Optional[Sequence[int]] = None,
            discretize: bool = False,
            keys: Optional[List[str]] = None,
            ):
        super().__init__(keys=keys)
        self.label_key = _parse_label_key(label_key)
        self.used_labels = _parse_used_labels(used_labels)
        self.mean, self.std = mean, std
        self.image_key = image_key
        self.discretize = discretize
        self.args_names = (
            'label_key',
            'mean',
            'std',
            'image_key',
            'used_labels',
            'discretize',
        )

    def apply_transform(self, subject: Subject) -> Subject:
        original_image = subject.get(self.image_key)

        label_map_image = subject[self.label_key]
        label_map = label_map_image.data
        affine = label_map_image.affine

        # Find out if we face a partial-volume image or a label map.
        # One-hot-encoded label map is considered as a partial-volume image
        all_discrete = label_map.eq(label_map.round()).all()
        # True when squeezing removes a dimension (e.g. a single channel)
        same_num_dims = label_map.squeeze().dim() < label_map.dim()
        is_discretized = all_discrete and same_num_dims

        if not is_discretized and self.discretize:
            # Take label with highest value in voxel
            max_label, label_map = label_map.max(dim=0, keepdim=True)
            # Remove values where all labels are 0 (i.e. missing labels)
            label_map[max_label == 0] = -1
            is_discretized = True

        tissues = torch.zeros(1, *label_map_image.spatial_shape).float()
        if is_discretized:
            labels = label_map.unique().long().tolist()
            if -1 in labels:
                labels.remove(-1)
        else:
            labels = range(label_map.shape[0])

        # Raise error if mean and std are not defined for every label
        _check_mean_and_std_length(labels, self.mean, self.std)

        for i, label in enumerate(labels):
            if self.used_labels is None or label in self.used_labels:
                mean = self.mean[i]
                std = self.std[i]
                if is_discretized:
                    mask = label_map == label
                else:
                    mask = label_map[label]
                tissues += self.generate_tissue(mask, mean, std)

            else:
                # Modify label map to easily compute background mask
                if is_discretized:
                    label_map[label_map == label] = -1
                else:
                    label_map[label] = 0

        final_image = ScalarImage(affine=affine, tensor=tissues)

        if original_image is not None:
            if is_discretized:
                bg_mask = label_map == -1
            else:
                bg_mask = label_map.sum(dim=0, keepdim=True) < 0.5
            final_image[DATA][bg_mask] = original_image[DATA][bg_mask]
            # Reported issue (Comprehensibility, Best Practice):
            # The variable DATA does not seem to be defined.

        subject.add_image(final_image, self.image_key)
        return subject

    @staticmethod
    def generate_tissue(
            data: TypeData,
            mean: float,
            std: float,
            ) -> TypeData:
        # Create the simulated tissue using a gaussian random variable
        data_shape = data.shape
        gaussian = torch.randn(data_shape) * std + mean
        return gaussian * data


def _parse_label_key(label_key: Optional[str]) -> Optional[str]:
    if label_key is not None and not isinstance(label_key, str):
        message = f'"label_key" must be a string or None, not {type(label_key)}'
        raise TypeError(message)
    return label_key


def _parse_used_labels(
        used_labels: Optional[Sequence[int]],
        ) -> Optional[Sequence[int]]:
    if used_labels is None:
        return None
    check_sequence(used_labels, 'used_labels')
    for e in used_labels:
        if not isinstance(e, int):
            message = (
                'Items in "used_labels" must be integers,'
                f' but some are not: {used_labels}'
            )
            raise ValueError(message)
    return used_labels


def _check_mean_and_std_length(
        labels: Sequence[int],
        means: Optional[Sequence[TypeRangeFloat]],
        stds: Optional[Sequence[TypeRangeFloat]],
        ) -> None:
    num_labels = len(labels)
    if means is not None:
        num_means = len(means)
        message = (
            '"mean" must define a value for each label but length of "mean"'
            f' is {num_means} while {num_labels} labels were found'
        )
        if num_means != num_labels:
            raise RuntimeError(message)
    if stds is not None:
        num_stds = len(stds)
        message = (
            '"std" must define a value for each label but length of "std"'
            f' is {num_stds} while {num_labels} labels were found'
        )
        if num_stds != num_labels:
            raise RuntimeError(message)