1
|
|
|
from typing import Tuple, Optional, Sequence, List |
2
|
|
|
|
3
|
|
|
import torch |
4
|
|
|
|
5
|
|
|
from ....utils import check_sequence |
6
|
|
|
from ....data.subject import Subject |
7
|
|
|
from ....typing import TypeData, TypeRangeFloat |
8
|
|
|
from ....data.image import ScalarImage, LabelMap |
9
|
|
|
from ... import IntensityTransform |
10
|
|
|
from .. import RandomTransform |
11
|
|
|
|
12
|
|
|
|
13
|
|
|
class RandomLabelsToImage(RandomTransform, IntensityTransform):
    r"""Randomly generate an image from a segmentation.

    Based on the works by Billot et al.: `A Learning Strategy for Contrast-agnostic MRI Segmentation`_
    and `Partial Volume Segmentation of Brain MRI Scans of any Resolution and Contrast <https://link.springer.com/chapter/10.1007/978-3-030-59728-3_18>`__.

    .. _A Learning Strategy for Contrast-agnostic MRI Segmentation: http://proceedings.mlr.press/v121/billot20a.html

    Args:
        label_key: String designating the label map in the subject
            that will be used to generate the new image. If ``None``,
            the first :class:`LabelMap` found in each subject is used.
        used_labels: Sequence of integers designating the labels used
            to generate the new image. If categorical encoding is used,
            :attr:`used_labels` refers to the values of the
            categorical encoding. If one hot encoding or partial-volume
            label maps are used, :attr:`used_labels` refers to the
            channels of the label maps.
            Default uses all labels. Missing voxels will be filled with zero
            or with voxels from an already existing volume,
            see :attr:`image_key`.
        image_key: String designating the key to which the new volume will be
            saved. If this key corresponds to an already existing volume,
            missing voxels will be filled with the corresponding values
            in the original volume.
        mean: Sequence of means for each label.
            For each value :math:`v`, if a tuple :math:`(a, b)` is
            provided then :math:`v \sim \mathcal{U}(a, b)`.
            If ``None``, :attr:`default_mean` range will be used for every
            label.
            If not ``None`` and :attr:`used_labels` is not ``None``,
            :attr:`mean` and :attr:`used_labels` must have the
            same length.
        std: Sequence of standard deviations for each label.
            For each value :math:`v`, if a tuple :math:`(a, b)` is
            provided then :math:`v \sim \mathcal{U}(a, b)`.
            If ``None``, :attr:`default_std` range will be used for every
            label.
            If not ``None`` and :attr:`used_labels` is not ``None``,
            :attr:`std` and :attr:`used_labels` must have the
            same length.
        default_mean: Default mean range.
        default_std: Default standard deviation range.
        discretize: If ``True``, partial-volume label maps will be discretized.
            Does not have any effects if not using partial-volume label maps.
            Discretization is done taking the class of the highest value per
            voxel in the different partial-volume label maps using
            :func:`torch.argmax()` on the channel dimension (i.e. 0).
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    .. tip:: It is recommended to blur the new images to make the result more
        realistic. See
        :class:`~torchio.transforms.augmentation.RandomBlur`.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.ICBM2009CNonlinearSymmetric()
        >>> # Using the default parameters
        >>> transform = tio.RandomLabelsToImage(label_key='tissues')
        >>> # Using custom mean and std
        >>> transform = tio.RandomLabelsToImage(
        ...     label_key='tissues', mean=[0.33, 0.66, 1.], std=[0, 0, 0]
        ... )
        >>> # Discretizing the partial volume maps and blurring the result
        >>> simulation_transform = tio.RandomLabelsToImage(
        ...     label_key='tissues', mean=[0.33, 0.66, 1.], std=[0, 0, 0], discretize=True
        ... )
        >>> blurring_transform = tio.RandomBlur(std=0.3)
        >>> transform = tio.Compose([simulation_transform, blurring_transform])
        >>> transformed = transform(subject)  # subject has a new key 'image_from_labels' with the simulated image
        >>> # Filling holes of the simulated image with the original T1 image
        >>> rescale_transform = tio.RescaleIntensity(
        ...     out_min_max=(0, 1), percentiles=(1, 99))   # Rescale intensity before filling holes
        >>> simulation_transform = tio.RandomLabelsToImage(
        ...     label_key='tissues',
        ...     image_key='t1',
        ...     used_labels=[0, 1]
        ... )
        >>> transform = tio.Compose([rescale_transform, simulation_transform])
        >>> transformed = transform(subject)  # subject's key 't1' has been replaced with the simulated image
    """  # noqa: E501
    def __init__(
            self,
            label_key: Optional[str] = None,
            used_labels: Optional[Sequence[int]] = None,
            image_key: str = 'image_from_labels',
            mean: Optional[Sequence[TypeRangeFloat]] = None,
            std: Optional[Sequence[TypeRangeFloat]] = None,
            default_mean: TypeRangeFloat = (0.1, 0.9),
            default_std: TypeRangeFloat = (0.01, 0.1),
            discretize: bool = False,
            **kwargs
            ):
        super().__init__(**kwargs)
        self.label_key = _parse_label_key(label_key)
        # Must be set before parsing mean/std: their lengths are validated
        # against used_labels in parse_gaussian_parameters
        self.used_labels = _parse_used_labels(used_labels)
        self.mean, self.std = self.parse_mean_and_std(mean, std)
        self.default_mean = self.parse_gaussian_parameter(
            default_mean, 'default_mean')
        self.default_std = self.parse_gaussian_parameter(
            default_std,
            'default_std',
        )
        self.image_key = image_key
        self.discretize = discretize

    def parse_mean_and_std(
            self,
            mean: Optional[Sequence[TypeRangeFloat]],
            std: Optional[Sequence[TypeRangeFloat]],
            ) -> Tuple[
                Optional[List[TypeRangeFloat]],
                Optional[List[TypeRangeFloat]],
            ]:
        """Validate the per-label mean and std ranges.

        Args:
            mean: Per-label means (numbers or ``(min, max)`` ranges),
                or ``None``.
            std: Per-label standard deviations (numbers or ``(min, max)``
                ranges), or ``None``.

        Returns:
            Tuple ``(mean, std)`` where each element is ``None`` if the
            corresponding input was ``None``, otherwise a list of
            ``(min, max)`` tuples.

        Raises:
            AssertionError: If both are given with different lengths.
        """
        if mean is not None:
            mean = self.parse_gaussian_parameters(mean, 'mean')
        if std is not None:
            std = self.parse_gaussian_parameters(std, 'std')
        if mean is not None and std is not None:
            message = (
                'If both "mean" and "std" are defined they must have the same'
                ' length'
            )
            assert len(mean) == len(std), message
        return mean, std

    def parse_gaussian_parameters(
            self,
            params: Sequence[TypeRangeFloat],
            name: str
            ) -> List[TypeRangeFloat]:
        """Normalize every entry of ``params`` to a ``(min, max)`` tuple.

        Args:
            params: Sequence of numbers or ``(min, max)`` ranges.
            name: Name of the argument, used in error messages.

        Returns:
            List with one ``(min, max)`` tuple per entry of ``params``.

        Raises:
            AssertionError: If :attr:`used_labels` is set and its length
                differs from the length of ``params``.
        """
        check_sequence(params, name)
        params = [
            self.parse_gaussian_parameter(p, f'{name}[{i}]')
            for i, p in enumerate(params)
        ]
        if self.used_labels is not None:
            message = (
                f'If both "{name}" and "used_labels" are defined, '
                f'they must have the same length'
            )
            assert len(params) == len(self.used_labels), message
        return params

    @staticmethod
    def parse_gaussian_parameter(
            nums_range: TypeRangeFloat,
            name: str,
            ) -> Tuple[float, float]:
        """Convert a number or a pair of numbers into a ``(min, max)`` tuple.

        Args:
            nums_range: A single number, or a sequence of two numbers.
            name: Name of the argument, used in error messages.

        Returns:
            ``(value, value)`` for a single number, otherwise
            ``(min_value, max_value)``.

        Raises:
            ValueError: If the sequence does not have exactly two elements
                or the first is greater than the second.
        """
        if isinstance(nums_range, (int, float)):
            return nums_range, nums_range

        if len(nums_range) != 2:
            raise ValueError(
                f'If {name} is a sequence,'
                f' it must be of len 2, not {nums_range}')
        min_value, max_value = nums_range
        if min_value > max_value:
            raise ValueError(
                f'If {name} is a sequence, the second value must be'
                f' equal or greater than the first, not {nums_range}')
        return min_value, max_value

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample a mean and std per label and delegate to LabelsToImage.

        Args:
            subject: Subject containing the label map to simulate from.

        Returns:
            The result of applying :class:`LabelsToImage`, configured with
            the sampled parameters, to ``subject``.

        Raises:
            RuntimeError: If :attr:`label_key` is ``None`` and the subject
                contains no :class:`LabelMap`, or if :attr:`mean` /
                :attr:`std` do not define one entry per label found.
        """
        # Resolve the key in a local variable: storing the auto-detected key
        # on self would pin the first subject's key for all later subjects
        label_key = self.label_key
        if label_key is None:
            iterable = subject.get_images_dict(intensity_only=False).items()
            for name, image in iterable:
                if isinstance(image, LabelMap):
                    label_key = name
                    break
            else:
                message = f'No label maps found in subject: {subject}'
                raise RuntimeError(message)

        arguments = {
            'label_key': label_key,
            'mean': [],
            'std': [],
            'image_key': self.image_key,
            'used_labels': self.used_labels,
            'discretize': self.discretize,
        }

        label_map = subject[label_key].data

        # Find out if we face a partial-volume image or a label map.
        # One-hot-encoded label map is considered as a partial-volume image
        all_discrete = label_map.eq(label_map.float().round()).all()
        # True when squeezing drops at least one singleton dimension
        same_num_dims = label_map.squeeze().dim() < label_map.dim()
        is_discretized = all_discrete and same_num_dims

        if not is_discretized and self.discretize:
            # Take label with highest value in voxel. max() over a dimension
            # returns (values, indices); the indices are the winning channels
            max_label, label_map = label_map.max(dim=0, keepdim=True)
            # Remove values where all labels are 0 (i.e. missing labels)
            label_map[max_label == 0] = -1
            is_discretized = True

        if is_discretized:
            labels = label_map.unique().long().tolist()
            if -1 in labels:
                labels.remove(-1)
        else:
            labels = range(label_map.shape[0])

        # Raise error if mean and std are not defined for every label
        # NOTE(review): together with the used_labels length assertion this
        # requires len(used_labels) to equal the number of labels found
        # whenever mean or std is given -- confirm this is intended
        _check_mean_and_std_length(labels, self.mean, self.std)

        # Index mean/std by position, not by label value: the sequences have
        # one entry per label found, and LabelsToImage reads them by position
        # too. Label values such as [0, 2] would otherwise overrun them.
        for position, _label in enumerate(labels):
            mean, std = self.get_params(position)
            arguments['mean'].append(mean)
            arguments['std'].append(std)

        transform = LabelsToImage(**self.add_include_exclude(arguments))
        transformed = transform(subject)
        return transformed

    def get_params(self, label: int) -> Tuple[float, float]:
        """Sample the mean and standard deviation for one label.

        Args:
            label: Index into :attr:`mean` and :attr:`std`. Ignored for a
                parameter whose sequence is ``None``, in which case the
                default range is used.

        Returns:
            Tuple ``(mean, std)`` sampled with ``self.sample_uniform`` from
            the selected ranges.
        """
        if self.mean is None:
            mean_range = self.default_mean
        else:
            mean_range = self.mean[label]
        if self.std is None:
            std_range = self.default_std
        else:
            std_range = self.std[label]
        mean = self.sample_uniform(*mean_range).item()
        std = self.sample_uniform(*std_range).item()
        return mean, std
239
|
|
|
|
240
|
|
|
|
241
|
|
|
class LabelsToImage(IntensityTransform):
    r"""Generate an image from a segmentation.

    Args:
        label_key: String designating the label map in the subject
            that will be used to generate the new image.
        mean: Sequence of means for each label found in the label map,
            read by position.
        std: Sequence of standard deviations for each label found in the
            label map, read by position.
        image_key: String designating the key to which the new volume will be
            saved. If this key corresponds to an already existing volume,
            missing voxels will be filled with the corresponding values
            in the original volume.
        used_labels: Sequence of integers designating the labels used
            to generate the new image. If categorical encoding is used,
            :attr:`used_labels` refers to the values of the
            categorical encoding. If one hot encoding or partial-volume
            label maps are used, :attr:`used_labels` refers to the
            channels of the label maps.
            Default uses all labels. Missing voxels will be filled with zero
            or with voxels from an already existing volume,
            see :attr:`image_key`.
        discretize: If ``True``, partial-volume label maps will be discretized.
            Does not have any effects if not using partial-volume label maps.
            Discretization is done taking the class of the highest value per
            voxel in the different partial-volume label maps using
            :func:`torch.argmax()` on the channel dimension (i.e. 0).
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    .. note:: It is recommended to blur the new images to make the result more
        realistic. See
        :class:`~torchio.transforms.augmentation.RandomBlur`.
    """
    def __init__(
            self,
            label_key: str,
            mean: Optional[Sequence[float]],
            std: Optional[Sequence[float]],
            image_key: str = 'image_from_labels',
            used_labels: Optional[Sequence[int]] = None,
            discretize: bool = False,
            **kwargs
            ):
        super().__init__(**kwargs)
        self.label_key = _parse_label_key(label_key)
        self.used_labels = _parse_used_labels(used_labels)
        self.mean, self.std = mean, std
        self.image_key = image_key
        self.discretize = discretize
        self.args_names = (
            'label_key',
            'mean',
            'std',
            'image_key',
            'used_labels',
            'discretize',
        )

    def apply_transform(self, subject: Subject) -> Subject:
        """Simulate an image from the label map and add it to the subject.

        Every label found in the label map that is selected by
        :attr:`used_labels` is painted with Gaussian noise using the mean
        and std at the same position in :attr:`mean` and :attr:`std`.
        Voxels not covered by a used label are zero, or are copied from the
        image already stored under :attr:`image_key` if there is one.

        Args:
            subject: Subject containing the label map under
                :attr:`label_key`.

        Returns:
            The same subject with the simulated image added under
            :attr:`image_key`.

        Raises:
            RuntimeError: If :attr:`mean` or :attr:`std` does not define
                one entry per label found.
        """
        original_image = subject.get(self.image_key)

        label_map_image = subject[self.label_key]
        label_map = label_map_image.data
        affine = label_map_image.affine

        # Find out if we face a partial-volume image or a label map.
        # One-hot-encoded label map is considered as a partial-volume image
        all_discrete = label_map.eq(label_map.float().round()).all()
        # True when squeezing drops at least one singleton dimension
        same_num_dims = label_map.squeeze().dim() < label_map.dim()
        is_discretized = all_discrete and same_num_dims

        if not is_discretized and self.discretize:
            # Take label with highest value in voxel. max() over a dimension
            # returns (values, indices) as new tensors; the indices are the
            # winning channels
            max_label, label_map = label_map.max(dim=0, keepdim=True)
            # Remove values where all labels are 0 (i.e. missing labels)
            label_map[max_label == 0] = -1
            is_discretized = True
        elif self.used_labels is not None:
            # Unused labels are erased from label_map below to build the
            # background mask. label_map may be the very tensor stored in
            # the subject, so work on a copy to keep the label map intact.
            label_map = label_map.clone()

        tissues = torch.zeros(1, *label_map_image.spatial_shape).float()
        if is_discretized:
            labels = label_map.unique().long().tolist()
            if -1 in labels:
                labels.remove(-1)
        else:
            labels = range(label_map.shape[0])

        # Raise error if mean and std are not defined for every label
        _check_mean_and_std_length(labels, self.mean, self.std)

        for i, label in enumerate(labels):
            if self.used_labels is None or label in self.used_labels:
                mean = self.mean[i]
                std = self.std[i]
                if is_discretized:
                    mask = label_map == label
                else:
                    mask = label_map[label]
                tissues += self.generate_tissue(mask, mean, std)

            else:
                # Modify label map to easily compute background mask
                if is_discretized:
                    label_map[label_map == label] = -1
                else:
                    label_map[label] = 0

        final_image = ScalarImage(affine=affine, tensor=tissues)

        if original_image is not None:
            if is_discretized:
                bg_mask = label_map == -1
            else:
                bg_mask = label_map.sum(dim=0, keepdim=True) < 0.5
            final_image.data[bg_mask] = original_image.data[bg_mask].float()

        subject.add_image(final_image, self.image_key)
        return subject

    @staticmethod
    def generate_tissue(
            data: TypeData,
            mean: float,
            std: float,
            ) -> TypeData:
        """Fill a mask with Gaussian noise.

        Args:
            data: Mask or partial-volume weights of one label.
            mean: Mean of the Gaussian noise.
            std: Standard deviation of the Gaussian noise.

        Returns:
            Noise with the shape of ``data``, multiplied elementwise by
            ``data``.
        """
        # Create the simulated tissue using a gaussian random variable
        gaussian = torch.randn(data.shape) * std + mean
        return gaussian * data
375
|
|
|
|
376
|
|
|
|
377
|
|
|
def _parse_label_key(label_key: Optional[str]) -> Optional[str]: |
378
|
|
|
if label_key is not None and not isinstance(label_key, str): |
379
|
|
|
message = ( |
380
|
|
|
f'"label_key" must be a string or None, not {type(label_key)}') |
381
|
|
|
raise TypeError(message) |
382
|
|
|
return label_key |
383
|
|
|
|
384
|
|
|
|
385
|
|
|
def _parse_used_labels(used_labels: Sequence[int]) -> Sequence[int]: |
386
|
|
|
if used_labels is None: |
387
|
|
|
return None |
388
|
|
|
check_sequence(used_labels, 'used_labels') |
389
|
|
|
for e in used_labels: |
390
|
|
|
if not isinstance(e, int): |
391
|
|
|
message = ( |
392
|
|
|
'Items in "used_labels" must be integers,' |
393
|
|
|
f' but some are not: {used_labels}' |
394
|
|
|
) |
395
|
|
|
raise ValueError(message) |
396
|
|
|
return used_labels |
397
|
|
|
|
398
|
|
|
|
399
|
|
|
def _check_mean_and_std_length( |
400
|
|
|
labels: Sequence[int], |
401
|
|
|
means: Optional[Sequence[TypeRangeFloat]], |
402
|
|
|
stds: Optional[Sequence[TypeRangeFloat]], |
403
|
|
|
) -> None: |
404
|
|
|
num_labels = len(labels) |
405
|
|
|
if means is not None: |
406
|
|
|
num_means = len(means) |
407
|
|
|
message = ( |
408
|
|
|
'"mean" must define a value for each label but length of "mean"' |
409
|
|
|
f' is {num_means} while {num_labels} labels were found' |
410
|
|
|
) |
411
|
|
|
if num_means != num_labels: |
412
|
|
|
raise RuntimeError(message) |
413
|
|
|
if stds is not None: |
414
|
|
|
num_stds = len(stds) |
415
|
|
|
message = ( |
416
|
|
|
'"std" must define a value for each label but length of "std"' |
417
|
|
|
f' is {num_stds} while {num_labels} labels were found' |
418
|
|
|
) |
419
|
|
|
if num_stds != num_labels: |
420
|
|
|
raise RuntimeError(message) |
421
|
|
|
|