LabelsToImage.generate_tissue()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 9
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 7
dl 0
loc 9
rs 10
c 0
b 0
f 0
cc 1
nop 3
1
from typing import Tuple, Optional, Sequence, List
2
3
import torch
4
5
from ....utils import check_sequence
6
from ....data.subject import Subject
7
from ....typing import TypeData, TypeRangeFloat
8
from ....data.image import ScalarImage, LabelMap
9
from ... import IntensityTransform
10
from .. import RandomTransform
11
12
13
class RandomLabelsToImage(RandomTransform, IntensityTransform):
    r"""Randomly generate an image from a segmentation.

    Based on the works by Billot et al.: `A Learning Strategy for Contrast-agnostic MRI Segmentation`_
    and `Partial Volume Segmentation of Brain MRI Scans of any Resolution and Contrast <https://link.springer.com/chapter/10.1007/978-3-030-59728-3_18>`__.

    .. _A Learning Strategy for Contrast-agnostic MRI Segmentation: http://proceedings.mlr.press/v121/billot20a.html

    Args:
        label_key: String designating the label map in the subject
            that will be used to generate the new image. If ``None``, the
            first :class:`~torchio.data.image.LabelMap` found in each
            transformed subject is used.
        used_labels: Sequence of integers designating the labels used
            to generate the new image. If categorical encoding is used,
            :attr:`used_labels` refers to the values of the
            categorical encoding. If one hot encoding or partial-volume
            label maps are used, :attr:`used_labels` refers to the
            channels of the label maps.
            Default uses all labels. Missing voxels will be filled with zero
            or with voxels from an already existing volume,
            see :attr:`image_key`.
        image_key: String designating the key to which the new volume will be
            saved. If this key corresponds to an already existing volume,
            missing voxels will be filled with the corresponding values
            in the original volume.
        mean: Sequence of means for each label.
            For each value :math:`v`, if a tuple :math:`(a, b)` is
            provided then :math:`v \sim \mathcal{U}(a, b)`.
            If ``None``, :attr:`default_mean` range will be used for every
            label.
            Otherwise it must contain one value per label found in the
            label map: the :math:`i`-th value is used for the :math:`i`-th
            label (ascending label values for categorical label maps,
            channel order for one-hot or partial-volume label maps).
            If :attr:`used_labels` is not ``None``,
            :attr:`mean` and :attr:`used_labels` must have the
            same length.
        std: Sequence of standard deviations for each label.
            For each value :math:`v`, if a tuple :math:`(a, b)` is
            provided then :math:`v \sim \mathcal{U}(a, b)`.
            If ``None``, :attr:`default_std` range will be used for every
            label.
            Otherwise it is matched to the labels by position, exactly
            like :attr:`mean`.
            If :attr:`used_labels` is not ``None``,
            :attr:`std` and :attr:`used_labels` must have the
            same length.
        default_mean: Default mean range.
        default_std: Default standard deviation range.
        discretize: If ``True``, partial-volume label maps will be discretized.
            Does not have any effects if not using partial-volume label maps.
            Discretization is done taking the class of the highest value per
            voxel in the different partial-volume label maps using
            :func:`torch.argmax()` on the channel dimension (i.e. 0).
            Voxels where every channel is zero are treated as background.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    .. tip:: It is recommended to blur the new images to make the result more
        realistic. See
        :class:`~torchio.transforms.augmentation.RandomBlur`.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.ICBM2009CNonlinearSymmetric()
        >>> # Using the default parameters
        >>> transform = tio.RandomLabelsToImage(label_key='tissues')
        >>> # Using custom mean and std
        >>> transform = tio.RandomLabelsToImage(
        ...     label_key='tissues', mean=[0.33, 0.66, 1.], std=[0, 0, 0]
        ... )
        >>> # Discretizing the partial volume maps and blurring the result
        >>> simulation_transform = tio.RandomLabelsToImage(
        ...     label_key='tissues', mean=[0.33, 0.66, 1.], std=[0, 0, 0], discretize=True
        ... )
        >>> blurring_transform = tio.RandomBlur(std=0.3)
        >>> transform = tio.Compose([simulation_transform, blurring_transform])
        >>> transformed = transform(subject)  # subject has a new key 'image_from_labels' with the simulated image
        >>> # Filling holes of the simulated image with the original T1 image
        >>> rescale_transform = tio.RescaleIntensity(
        ...     out_min_max=(0, 1), percentiles=(1, 99))   # Rescale intensity before filling holes
        >>> simulation_transform = tio.RandomLabelsToImage(
        ...     label_key='tissues',
        ...     image_key='t1',
        ...     used_labels=[0, 1]
        ... )
        >>> transform = tio.Compose([rescale_transform, simulation_transform])
        >>> transformed = transform(subject)  # subject's key 't1' has been replaced with the simulated image
    """  # noqa: E501
    def __init__(
            self,
            label_key: Optional[str] = None,
            used_labels: Optional[Sequence[int]] = None,
            image_key: str = 'image_from_labels',
            mean: Optional[Sequence[TypeRangeFloat]] = None,
            std: Optional[Sequence[TypeRangeFloat]] = None,
            default_mean: TypeRangeFloat = (0.1, 0.9),
            default_std: TypeRangeFloat = (0.01, 0.1),
            discretize: bool = False,
            **kwargs
            ):
        super().__init__(**kwargs)
        self.label_key = _parse_label_key(label_key)
        # Must be set before parse_mean_and_std, which reads it.
        self.used_labels = _parse_used_labels(used_labels)
        self.mean, self.std = self.parse_mean_and_std(mean, std)
        self.default_mean = self.parse_gaussian_parameter(
            default_mean, 'default_mean')
        self.default_std = self.parse_gaussian_parameter(
            default_std,
            'default_std',
        )
        self.image_key = image_key
        self.discretize = discretize

    def parse_mean_and_std(
            self,
            mean: Optional[Sequence[TypeRangeFloat]],
            std: Optional[Sequence[TypeRangeFloat]],
            ) -> Tuple[
                Optional[List[TypeRangeFloat]],
                Optional[List[TypeRangeFloat]],
            ]:
        """Parse the optional per-label mean and std sequences.

        Returns:
            ``(mean, std)``, each either ``None`` or a list of
            ``(min, max)`` ranges.

        Raises:
            AssertionError: If both are given with different lengths.
        """
        if mean is not None:
            mean = self.parse_gaussian_parameters(mean, 'mean')
        if std is not None:
            std = self.parse_gaussian_parameters(std, 'std')
        if mean is not None and std is not None and len(mean) != len(std):
            message = (
                'If both "mean" and "std" are defined they must have the same'
                ' length'
            )
            # Raised explicitly (not via ``assert``) so the check survives
            # ``python -O``; the exception type is kept for compatibility.
            raise AssertionError(message)
        return mean, std

    def parse_gaussian_parameters(
            self,
            params: Sequence[TypeRangeFloat],
            name: str
            ) -> List[TypeRangeFloat]:
        """Parse each element of ``params`` into a ``(min, max)`` range.

        Raises:
            AssertionError: If :attr:`used_labels` is set and its length
                differs from the length of ``params``.
        """
        check_sequence(params, name)
        params = [
            self.parse_gaussian_parameter(p, f'{name}[{i}]')
            for i, p in enumerate(params)
        ]
        if self.used_labels is not None and (
                len(params) != len(self.used_labels)):
            message = (
                f'If both "{name}" and "used_labels" are defined, '
                f'they must have the same length'
            )
            # Explicit raise so the check survives ``python -O``.
            raise AssertionError(message)
        return params

    @staticmethod
    def parse_gaussian_parameter(
            nums_range: TypeRangeFloat,
            name: str,
            ) -> Tuple[float, float]:
        """Turn a number or a 2-sequence into a ``(min, max)`` range.

        A single number ``v`` becomes the degenerate range ``(v, v)``.

        Raises:
            ValueError: If the sequence does not have two elements or if its
                first value is greater than its second.
        """
        if isinstance(nums_range, (int, float)):
            return nums_range, nums_range

        if len(nums_range) != 2:
            raise ValueError(
                f'If {name} is a sequence,'
                f' it must be of len 2, not {nums_range}')
        min_value, max_value = nums_range
        if min_value > max_value:
            raise ValueError(
                f'If {name} is a sequence, the second value must be'
                f' equal or greater than the first, not {nums_range}')
        return min_value, max_value

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample one mean/std pair per label and apply LabelsToImage.

        Raises:
            RuntimeError: If no label key was given and the subject has no
                label map, or if ``mean``/``std`` do not define one value
                per label found in the label map.
        """
        # Use a local variable so that the key found for this subject is not
        # stored on ``self`` and silently reused for later subjects.
        label_key = self.label_key
        if label_key is None:
            iterable = subject.get_images_dict(intensity_only=False).items()
            for name, image in iterable:
                if isinstance(image, LabelMap):
                    label_key = name
                    break
            else:
                message = f'No label maps found in subject: {subject}'
                raise RuntimeError(message)

        label_map = subject[label_key].data

        # Find out if we face a partial-volume image or a label map.
        # One-hot-encoded label map is considered as a partial-volume image.
        all_discrete = bool(label_map.eq(label_map.float().round()).all())
        # Check the channel dimension explicitly: ``squeeze()`` would also
        # drop singleton spatial dimensions (e.g. 2D maps shaped
        # (C, W, H, 1)) and misclassify multi-channel maps as categorical.
        single_channel = label_map.shape[0] == 1
        is_discretized = all_discrete and single_channel

        if not is_discretized and self.discretize:
            # Take label with highest value in voxel
            max_value, label_map = label_map.max(dim=0, keepdim=True)
            # Remove values where all labels are 0 (i.e. missing labels)
            label_map[max_value == 0] = -1
            is_discretized = True

        if is_discretized:
            labels = label_map.unique().long().tolist()
            if -1 in labels:
                labels.remove(-1)
        else:
            labels = range(label_map.shape[0])

        # Raise error if mean and std are not defined for every label
        _check_mean_and_std_length(labels, self.mean, self.std)

        # mean/std are matched to labels by position (the length check above
        # guarantees one entry per label), which is also how LabelsToImage
        # reads them. Indexing by label value would fail for non-contiguous
        # labels such as [0, 2, 5].
        means = []
        stds = []
        for index in range(len(labels)):
            mean, std = self.get_params(index)
            means.append(mean)
            stds.append(std)

        arguments = {
            'label_key': label_key,
            'mean': means,
            'std': stds,
            'image_key': self.image_key,
            'used_labels': self.used_labels,
            'discretize': self.discretize,
        }
        transform = LabelsToImage(**self.add_include_exclude(arguments))
        return transform(subject)

    def get_params(self, label: int) -> Tuple[float, float]:
        """Sample a mean and a standard deviation for one label.

        Args:
            label: Position of the label among the labels found in the label
                map; used to index :attr:`mean` and :attr:`std` when they
                are defined.

        Returns:
            ``(mean, std)`` sampled uniformly from the corresponding ranges.
        """
        if self.mean is None:
            mean_range = self.default_mean
        else:
            mean_range = self.mean[label]
        if self.std is None:
            std_range = self.default_std
        else:
            std_range = self.std[label]
        mean = self.sample_uniform(*mean_range).item()
        std = self.sample_uniform(*std_range).item()
        return mean, std
239
240
241
class LabelsToImage(IntensityTransform):
    r"""Generate an image from a segmentation.

    Args:
        label_key: String designating the label map in the subject
            that will be used to generate the new image.
        used_labels: Sequence of integers designating the labels used
            to generate the new image. If categorical encoding is used,
            :attr:`used_labels` refers to the values of the
            categorical encoding. If one hot encoding or partial-volume
            label maps are used, :attr:`used_labels` refers to the
            channels of the label maps.
            Default uses all labels. Missing voxels will be filled with zero
            or with voxels from an already existing volume,
            see :attr:`image_key`.
        image_key: String designating the key to which the new volume will be
            saved. If this key corresponds to an already existing volume,
            missing voxels will be filled with the corresponding values
            in the original volume.
        mean: Sequence of means, one per label found in the label map.
            The :math:`i`-th value is used for the :math:`i`-th label
            (ascending label values for categorical label maps, channel
            order for one-hot or partial-volume label maps). Entries of
            labels not in :attr:`used_labels` are ignored.
        std: Sequence of standard deviations, matched to the labels by
            position exactly like :attr:`mean`.
        discretize: If ``True``, partial-volume label maps will be discretized.
            Does not have any effects if not using partial-volume label maps.
            Discretization is done taking the class of the highest value per
            voxel in the different partial-volume label maps using
            :func:`torch.argmax()` on the channel dimension (i.e. 0).
            Voxels where every channel is zero are treated as background.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    .. note:: It is recommended to blur the new images to make the result more
        realistic. See
        :class:`~torchio.transforms.augmentation.RandomBlur`.
    """
    def __init__(
            self,
            label_key: str,
            mean: Optional[Sequence[float]],
            std: Optional[Sequence[float]],
            image_key: str = 'image_from_labels',
            used_labels: Optional[Sequence[int]] = None,
            discretize: bool = False,
            **kwargs
            ):
        super().__init__(**kwargs)
        self.label_key = _parse_label_key(label_key)
        self.used_labels = _parse_used_labels(used_labels)
        self.mean, self.std = mean, std
        self.image_key = image_key
        self.discretize = discretize
        self.args_names = (
            'label_key',
            'mean',
            'std',
            'image_key',
            'used_labels',
            'discretize',
        )

    def apply_transform(self, subject: Subject) -> Subject:
        """Write the simulated image into ``subject[self.image_key]``.

        The label map itself is never modified.

        Raises:
            RuntimeError: If ``mean``/``std`` do not define one value per
                label found in the label map.
        """
        original_image = subject.get(self.image_key)

        label_map_image = subject[self.label_key]
        label_map = label_map_image.data
        affine = label_map_image.affine
        spatial_shape = label_map_image.spatial_shape

        # Find out if we face a partial-volume image or a label map.
        # One-hot-encoded label map is considered as a partial-volume image.
        all_discrete = bool(label_map.eq(label_map.float().round()).all())
        # Check the channel dimension explicitly: ``squeeze()`` would also
        # drop singleton spatial dimensions (e.g. 2D maps shaped
        # (C, W, H, 1)) and misclassify multi-channel maps as categorical.
        single_channel = label_map.shape[0] == 1
        is_discretized = all_discrete and single_channel

        if not is_discretized and self.discretize:
            # Take label with highest value in voxel
            max_value, label_map = label_map.max(dim=0, keepdim=True)
            # Remove values where all labels are 0 (i.e. missing labels)
            label_map[max_value == 0] = -1
            is_discretized = True

        if is_discretized:
            labels = label_map.unique().long().tolist()
            if -1 in labels:
                labels.remove(-1)
        else:
            labels = range(label_map.shape[0])

        # Raise error if mean and std are not defined for every label
        _check_mean_and_std_length(labels, self.mean, self.std)

        tissues = torch.zeros(1, *spatial_shape).float()
        # Coverage of the used labels. The background mask is derived from
        # it instead of writing -1 / 0 into ``label_map``, which (unless it
        # was just recomputed by discretization) is the subject's own label
        # map tensor and would be corrupted in place.
        foreground = torch.zeros(1, *spatial_shape).float()
        for i, label in enumerate(labels):
            if self.used_labels is not None and label not in self.used_labels:
                continue
            if is_discretized:
                mask = label_map == label
            else:
                mask = label_map[label]
            tissues += self.generate_tissue(mask, self.mean[i], self.std[i])
            foreground += mask

        final_image = ScalarImage(affine=affine, tensor=tissues)

        if original_image is not None:
            # Voxels not covered by any used label keep their original value
            bg_mask = foreground < 0.5
            final_image.data[bg_mask] = original_image.data[bg_mask].float()

        subject.add_image(final_image, self.image_key)
        return subject

    @staticmethod
    def generate_tissue(
            data: TypeData,
            mean: float,
            std: float,
            ) -> TypeData:
        """Sample Gaussian intensities and weight them by ``data``.

        Args:
            data: Mask or partial-volume map of a single label. The sampled
                values are multiplied by it, so voxels where it is zero
                contribute nothing.
            mean: Mean of the Gaussian intensity distribution.
            std: Standard deviation of the Gaussian intensity distribution.

        Returns:
            Tensor with the same shape as ``data``.
        """
        # Create the simulated tissue using a gaussian random variable
        gaussian = torch.randn(data.shape) * std + mean
        return gaussian * data
375
376
377
def _parse_label_key(label_key: Optional[str]) -> Optional[str]:
378
    if label_key is not None and not isinstance(label_key, str):
379
        message = (
380
            f'"label_key" must be a string or None, not {type(label_key)}')
381
        raise TypeError(message)
382
    return label_key
383
384
385
def _parse_used_labels(used_labels: Sequence[int]) -> Sequence[int]:
386
    if used_labels is None:
387
        return None
388
    check_sequence(used_labels, 'used_labels')
389
    for e in used_labels:
390
        if not isinstance(e, int):
391
            message = (
392
                'Items in "used_labels" must be integers,'
393
                f' but some are not: {used_labels}'
394
            )
395
            raise ValueError(message)
396
    return used_labels
397
398
399
def _check_mean_and_std_length(
        labels: Sequence[int],
        means: Optional[Sequence[TypeRangeFloat]],
        stds: Optional[Sequence[TypeRangeFloat]],
        ) -> None:
    """Ensure ``means`` and ``stds`` define exactly one value per label.

    A ``None`` sequence is not checked. ``means`` is checked before ``stds``.

    Raises:
        RuntimeError: If a non-``None`` sequence has a length different from
            the number of labels.
    """
    expected = len(labels)
    if means is not None and len(means) != expected:
        raise RuntimeError(
            '"mean" must define a value for each label but length of "mean"'
            f' is {len(means)} while {expected} labels were found'
        )
    if stds is not None and len(stds) != expected:
        raise RuntimeError(
            '"std" must define a value for each label but length of "std"'
            f' is {len(stds)} while {expected} labels were found'
        )
421