|
1
|
|
|
from __future__ import annotations |
|
2
|
|
|
|
|
3
|
|
|
from collections.abc import Sequence |
|
4
|
|
|
|
|
5
|
|
|
import torch |
|
6
|
|
|
|
|
7
|
|
|
from ....data.image import LabelMap |
|
8
|
|
|
from ....data.image import ScalarImage |
|
9
|
|
|
from ....data.subject import Subject |
|
10
|
|
|
from ....types import TypeData |
|
11
|
|
|
from ....types import TypeRangeFloat |
|
12
|
|
|
from ....utils import check_sequence |
|
13
|
|
|
from ...intensity_transform import IntensityTransform |
|
14
|
|
|
from .. import RandomTransform |
|
15
|
|
|
|
|
16
|
|
|
|
|
17
|
|
|
class RandomLabelsToImage(RandomTransform, IntensityTransform):
    r"""Randomly generate an image from a segmentation.

    Based on the work by Billot et al.: `A Learning Strategy for Contrast-agnostic MRI Segmentation`_
    and `Partial Volume Segmentation of Brain MRI Scans of any Resolution and Contrast`_.

    .. _A Learning Strategy for Contrast-agnostic MRI Segmentation: http://proceedings.mlr.press/v121/billot20a.html

    .. _Partial Volume Segmentation of Brain MRI Scans of any Resolution and Contrast: https://link.springer.com/chapter/10.1007/978-3-030-59728-3_18

    .. plot::

        import torch
        import torchio as tio
        torch.manual_seed(42)
        colin = tio.datasets.Colin27(2008)
        label_map = colin.cls
        colin.remove_image('t1')
        colin.remove_image('t2')
        colin.remove_image('pd')
        downsample = tio.Resample(1)
        blurring_transform = tio.RandomBlur(std=0.6)
        create_synthetic_image = tio.RandomLabelsToImage(
            image_key='synthetic',
            ignore_background=True,
        )
        transform = tio.Compose((
            downsample,
            create_synthetic_image,
            blurring_transform,
        ))
        colin_synth = transform(colin)
        colin_synth.plot()

    Args:
        label_key: String designating the label map in the subject
            that will be used to generate the new image. If ``None``, the
            first label map found in the first transformed subject is used,
            and that key is remembered for subsequent calls.
        used_labels: Sequence of integers designating the labels used
            to generate the new image. If categorical encoding is used,
            :attr:`used_labels` refers to the values of the
            categorical encoding. If one hot encoding or partial-volume
            label maps are used, :attr:`used_labels` refers to the
            channels of the label maps.
            Default uses all labels. Missing voxels will be filled with zero
            or with voxels from an already existing volume,
            see :attr:`image_key`.
        image_key: String designating the key to which the new volume will be
            saved. If this key corresponds to an already existing volume,
            missing voxels will be filled with the corresponding values
            in the original volume.
        mean: Sequence of means for each label.
            For each value :math:`v`, if a tuple :math:`(a, b)` is
            provided then :math:`v \sim \mathcal{U}(a, b)`.
            If ``None``, :attr:`default_mean` range will be used for every
            label.
            If not ``None`` and :attr:`used_labels` is not ``None``,
            :attr:`mean` and :attr:`used_labels` must have the
            same length. The :math:`i`-th value is used for the
            :math:`i`-th label found in the label map (sorted label values
            for categorical maps, channel index for partial-volume maps).
        std: Sequence of standard deviations for each label.
            For each value :math:`v`, if a tuple :math:`(a, b)` is
            provided then :math:`v \sim \mathcal{U}(a, b)`.
            If ``None``, :attr:`default_std` range will be used for every
            label.
            If not ``None`` and :attr:`used_labels` is not ``None``,
            :attr:`std` and :attr:`used_labels` must have the
            same length. Values are matched to labels by position, as for
            :attr:`mean`.
        default_mean: Default mean range.
        default_std: Default standard deviation range.
        discretize: If ``True``, partial-volume label maps will be discretized.
            Does not have any effects if not using partial-volume label maps.
            Discretization is done taking the class of the highest value per
            voxel in the different partial-volume label maps using
            :func:`torch.argmax()` on the channel dimension (i.e. 0).
        ignore_background: If ``True``, input voxels labeled as ``0`` will not
            be modified.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    .. tip:: It is recommended to blur the new images in order to simulate
        partial volume effects at the borders of the synthetic structures. See
        :class:`~torchio.transforms.augmentation.intensity.random_blur.RandomBlur`.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.ICBM2009CNonlinearSymmetric()
        >>> # Using the default parameters
        >>> transform = tio.RandomLabelsToImage(label_key='tissues')
        >>> # Using custom mean and std
        >>> transform = tio.RandomLabelsToImage(
        ...     label_key='tissues', mean=[0.33, 0.66, 1.], std=[0, 0, 0]
        ... )
        >>> # Discretizing the partial volume maps and blurring the result
        >>> simulation_transform = tio.RandomLabelsToImage(
        ...     label_key='tissues', mean=[0.33, 0.66, 1.], std=[0, 0, 0], discretize=True
        ... )
        >>> blurring_transform = tio.RandomBlur(std=0.3)
        >>> transform = tio.Compose([simulation_transform, blurring_transform])
        >>> transformed = transform(subject)  # subject has a new key 'image_from_labels' with the simulated image
        >>> # Filling holes of the simulated image with the original T1 image
        >>> rescale_transform = tio.RescaleIntensity(
        ...     out_min_max=(0, 1), percentiles=(1, 99))   # Rescale intensity before filling holes
        >>> simulation_transform = tio.RandomLabelsToImage(
        ...     label_key='tissues',
        ...     image_key='t1',
        ...     used_labels=[0, 1]
        ... )
        >>> transform = tio.Compose([rescale_transform, simulation_transform])
        >>> transformed = transform(subject)  # subject's key 't1' has been replaced with the simulated image

    .. seealso:: :class:`~torchio.transforms.preprocessing.label.remap_labels.RemapLabels`.
    """

    def __init__(
        self,
        label_key: str | None = None,
        used_labels: Sequence[int] | None = None,
        image_key: str = 'image_from_labels',
        mean: Sequence[TypeRangeFloat] | None = None,
        std: Sequence[TypeRangeFloat] | None = None,
        default_mean: TypeRangeFloat = (0.1, 0.9),
        default_std: TypeRangeFloat = (0.01, 0.1),
        discretize: bool = False,
        ignore_background: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.label_key = _parse_label_key(label_key)
        # Must be set before parsing mean/std: their lengths are checked
        # against it in parse_gaussian_parameters
        self.used_labels = _parse_used_labels(used_labels)  # type: ignore[arg-type]
        self.mean, self.std = self.parse_mean_and_std(mean, std)  # type: ignore[arg-type,assignment]
        self.default_mean = self.parse_gaussian_parameter(
            default_mean,
            'default_mean',
        )
        self.default_std = self.parse_gaussian_parameter(
            default_std,
            'default_std',
        )
        self.image_key = image_key
        self.discretize = discretize
        self.ignore_background = ignore_background

    def parse_mean_and_std(
        self,
        mean: Sequence[TypeRangeFloat],
        std: Sequence[TypeRangeFloat],
    ) -> tuple[list[TypeRangeFloat], list[TypeRangeFloat]]:
        """Parse the per-label ``mean`` and ``std`` arguments.

        Each non-``None`` argument is converted into a list of
        ``(min, max)`` ranges. ``None`` arguments are returned unchanged.

        Raises:
            AssertionError: If both are given with different lengths.
        """
        if mean is not None:
            mean = self.parse_gaussian_parameters(mean, 'mean')
        if std is not None:
            std = self.parse_gaussian_parameters(std, 'std')
        if mean is not None and std is not None:
            message = (
                'If both "mean" and "std" are defined they must have the same length'
            )
            assert len(mean) == len(std), message
        return mean, std

    def parse_gaussian_parameters(
        self,
        params: Sequence[TypeRangeFloat],
        name: str,
    ) -> list[TypeRangeFloat]:
        """Convert each element of ``params`` into a ``(min, max)`` range.

        Raises:
            AssertionError: If :attr:`used_labels` is set and its length
                differs from the length of ``params``.
        """
        check_sequence(params, name)
        params = [
            self.parse_gaussian_parameter(p, f'{name}[{i}]')
            for i, p in enumerate(params)
        ]
        if self.used_labels is not None:
            message = (
                f'If both "{name}" and "used_labels" are defined, '
                'they must have the same length'
            )
            assert len(params) == len(self.used_labels), message
        return params

    @staticmethod
    def parse_gaussian_parameter(
        nums_range: TypeRangeFloat,
        name: str,
    ) -> tuple[float, float]:
        """Return ``nums_range`` as a ``(min, max)`` tuple.

        A single number ``v`` becomes ``(v, v)``.

        Raises:
            ValueError: If a sequence is not of length 2, or if its first
                value is greater than its second.
        """
        if isinstance(nums_range, (int, float)):
            return nums_range, nums_range

        if len(nums_range) != 2:
            raise ValueError(
                f'If {name} is a sequence, it must be of len 2, not {nums_range}',
            )
        min_value, max_value = nums_range
        if min_value > max_value:
            raise ValueError(
                f'If {name} is a sequence, the second value must be'
                f' equal or greater than the first, not {nums_range}',
            )
        return min_value, max_value

    def _guess_label_key(self, subject: Subject) -> None:
        """Set :attr:`label_key` to the first label map in ``subject`` if unset.

        Raises:
            RuntimeError: If the key is unset and the subject has no label map.
        """
        if self.label_key is None:
            iterable = subject.get_images_dict(intensity_only=False).items()
            for name, image in iterable:
                if isinstance(image, LabelMap):
                    self.label_key = name
                    break
            else:
                message = f'No label maps found in subject: {subject}'
                raise RuntimeError(message)

    def apply_transform(self, subject: Subject) -> Subject:
        self._guess_label_key(subject)

        label_map = subject[self.label_key].data

        # Find out if we face a partial-volume image or a label map.
        # One-hot-encoded (multi-channel) label maps are considered
        # partial-volume images. The channel dimension is checked directly:
        # squeezing would also drop singleton spatial dimensions (e.g. in 2D
        # images) and misclassify multi-channel maps as categorical ones.
        all_discrete = bool(label_map.eq(label_map.float().round()).all())
        single_channel = label_map.shape[0] == 1
        is_discretized = all_discrete and single_channel

        if not is_discretized and self.discretize:
            # Take label with highest value in voxel
            max_label, label_map = label_map.max(dim=0, keepdim=True)
            # Remove values where all labels are 0 (i.e. missing labels)
            label_map[max_label == 0] = -1
            is_discretized = True

        if is_discretized:
            labels = label_map.unique().long().tolist()
            if -1 in labels:
                labels.remove(-1)
        else:
            labels = range(label_map.shape[0])

        # Raise error if mean and std are not defined for every label
        _check_mean_and_std_length(labels, self.mean, self.std)  # type: ignore[arg-type]

        # LabelsToImage reads mean[i] and std[i] by the position i of each
        # label in ``labels``, so sample the parameters by position too.
        # Indexing by label value would fail for non-contiguous labels.
        means: list[float] = []
        stds: list[float] = []
        for position in range(len(labels)):
            mean, std = self.get_params(position)
            means.append(mean)
            stds.append(std)

        arguments = {
            'label_key': self.label_key,
            'mean': means,
            'std': stds,
            'image_key': self.image_key,
            'used_labels': self.used_labels,
            'discretize': self.discretize,
            'ignore_background': self.ignore_background,
        }

        transform = LabelsToImage(**self.add_base_args(arguments))
        transformed = transform(subject)
        assert isinstance(transformed, Subject)
        return transformed

    def get_params(self, label: int) -> tuple[float, float]:
        """Sample a ``(mean, std)`` pair for one label.

        Args:
            label: Position of the label among the labels found in the label
                map. It is used to index :attr:`mean` and :attr:`std` when
                they are defined; otherwise the default ranges are used.
        """
        if self.mean is None:
            mean_range = self.default_mean
        else:
            assert isinstance(self.mean, Sequence)
            mean_range = self.mean[label]
        if self.std is None:
            std_range = self.default_std
        else:
            std_range = self.std[label]
        mean = self.sample_uniform(*mean_range)  # type: ignore[misc]
        std = self.sample_uniform(*std_range)  # type: ignore[misc]
        return mean, std
|
288
|
|
|
|
|
289
|
|
|
|
|
290
|
|
|
class LabelsToImage(IntensityTransform):
    r"""Generate an image from a segmentation.

    Args:
        label_key: String designating the label map in the subject
            that will be used to generate the new image.
        used_labels: Sequence of integers designating the labels used
            to generate the new image. If categorical encoding is used,
            :attr:`used_labels` refers to the values of the
            categorical encoding. If one hot encoding or partial-volume
            label maps are used, :attr:`used_labels` refers to the
            channels of the label maps.
            Default uses all labels. Missing voxels will be filled with zero
            or with voxels from an already existing volume,
            see :attr:`image_key`.
        image_key: String designating the key to which the new volume will be
            saved. If this key corresponds to an already existing volume,
            missing voxels will be filled with the corresponding values
            in the original volume.
        mean: Sequence of means, one per label found in the label map.
            The :math:`i`-th value is used for the :math:`i`-th label
            (sorted label values for categorical maps, channel index for
            partial-volume maps).
        std: Sequence of standard deviations, one per label found in the
            label map, matched to labels by position as for :attr:`mean`.
        discretize: If ``True``, partial-volume label maps will be discretized.
            Does not have any effects if not using partial-volume label maps.
            Discretization is done taking the class of the highest value per
            voxel in the different partial-volume label maps using
            :func:`torch.argmax()` on the channel dimension (i.e. 0).
        ignore_background: If ``True``, input voxels labeled as ``0`` will not
            be modified.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    .. note:: It is recommended to blur the new images to make the result more
        realistic. See
        :class:`~torchio.transforms.augmentation.RandomBlur`.
    """

    def __init__(
        self,
        label_key: str,
        mean: Sequence[float] | None,
        std: Sequence[float] | None,
        image_key: str = 'image_from_labels',
        used_labels: Sequence[int] | None = None,
        ignore_background: bool = False,
        discretize: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.label_key = _parse_label_key(label_key)
        self.used_labels = _parse_used_labels(used_labels)
        self.mean, self.std = mean, std  # type: ignore[assignment]
        self.image_key = image_key
        self.ignore_background = ignore_background
        self.discretize = discretize
        self.args_names = [
            'label_key',
            'mean',
            'std',
            'image_key',
            'used_labels',
            'ignore_background',
            'discretize',
        ]

    def apply_transform(self, subject: Subject) -> Subject:
        original_image = subject.get(self.image_key)

        label_map_image = subject[self.label_key]
        # Work on a copy: the loop below overwrites label-map entries to build
        # the background mask, which must not alter the subject's segmentation
        label_map = label_map_image.data.clone()
        affine = label_map_image.affine

        # Find out if we face a partial-volume image or a label map.
        # One-hot-encoded (multi-channel) label maps are considered
        # partial-volume images. The channel dimension is checked directly:
        # squeezing would also drop singleton spatial dimensions (e.g. in 2D
        # images) and misclassify multi-channel maps as categorical ones.
        all_discrete = bool(label_map.eq(label_map.float().round()).all())
        single_channel = label_map.shape[0] == 1
        is_discretized = all_discrete and single_channel

        if not is_discretized and self.discretize:
            # Take label with highest value in voxel
            max_label, label_map = label_map.max(dim=0, keepdim=True)
            # Remove values where all labels are 0 (i.e. missing labels)
            label_map[max_label == 0] = -1
            is_discretized = True

        tissues = torch.zeros(1, *label_map_image.spatial_shape).float()
        if is_discretized:
            labels_in_image = label_map.unique().long().tolist()
            if -1 in labels_in_image:
                labels_in_image.remove(-1)
        else:
            labels_in_image = range(label_map.shape[0])

        # Raise error if mean and std are not defined for every label
        _check_mean_and_std_length(
            labels_in_image,
            self.mean,  # type: ignore[arg-type]
            self.std,
        )

        for i, label in enumerate(labels_in_image):
            if label == 0 and self.ignore_background:
                continue
            if self.used_labels is None or label in self.used_labels:
                assert isinstance(self.mean, Sequence)
                assert isinstance(self.std, Sequence)
                # Parameters are matched to labels by position
                mean = self.mean[i]
                std = self.std[i]
                if is_discretized:
                    mask = label_map == label
                else:
                    mask = label_map[label]
                tissues += self.generate_tissue(mask, mean, std)

            else:
                # Modify label map to easily compute background mask
                if is_discretized:
                    label_map[label_map == label] = -1
                else:
                    label_map[label] = 0

        final_image = ScalarImage(affine=affine, tensor=tissues)

        if original_image is not None:
            # Voxels not covered by any used label keep their original values
            if is_discretized:
                bg_mask = label_map == -1
            else:
                bg_mask = label_map.sum(dim=0, keepdim=True) < 0.5
            final_image.data[bg_mask] = original_image.data[bg_mask].float()

        subject.add_image(final_image, self.image_key)
        return subject

    @staticmethod
    def generate_tissue(
        data: TypeData,
        mean: float,
        std: float,
    ) -> TypeData:
        """Return ``data`` weighted by Gaussian noise ``N(mean, std**2)``.

        The noise tensor has the same shape as ``data``. Voxels where ``data``
        is zero therefore stay zero.
        """
        # Create the simulated tissue using a gaussian random variable
        gaussian = torch.randn(data.shape) * std + mean
        return gaussian * data
|
437
|
|
|
|
|
438
|
|
|
|
|
439
|
|
|
def _parse_label_key(label_key: str | None) -> str | None: |
|
440
|
|
|
if label_key is not None and not isinstance(label_key, str): |
|
441
|
|
|
message = f'"label_key" must be a string or None, not {type(label_key)}' |
|
442
|
|
|
raise TypeError(message) |
|
443
|
|
|
return label_key |
|
444
|
|
|
|
|
445
|
|
|
|
|
446
|
|
|
def _parse_used_labels( |
|
447
|
|
|
used_labels: Sequence[int] | None, |
|
448
|
|
|
) -> Sequence[int] | None: |
|
449
|
|
|
if used_labels is None: |
|
450
|
|
|
return None |
|
451
|
|
|
check_sequence(used_labels, 'used_labels') |
|
452
|
|
|
for e in used_labels: |
|
453
|
|
|
if not isinstance(e, int): |
|
454
|
|
|
message = ( |
|
455
|
|
|
'Items in "used_labels" must be integers,' |
|
456
|
|
|
f' but some are not: {used_labels}' |
|
457
|
|
|
) |
|
458
|
|
|
raise ValueError(message) |
|
459
|
|
|
return used_labels |
|
460
|
|
|
|
|
461
|
|
|
|
|
462
|
|
|
def _check_mean_and_std_length( |
|
463
|
|
|
labels: Sequence[int], |
|
464
|
|
|
means: Sequence[TypeRangeFloat] | None, |
|
465
|
|
|
stds: Sequence[TypeRangeFloat] | None, |
|
466
|
|
|
) -> None: |
|
467
|
|
|
num_labels = len(labels) |
|
468
|
|
|
if means is not None: |
|
469
|
|
|
num_means = len(means) |
|
470
|
|
|
message = ( |
|
471
|
|
|
'"mean" must define a value for each label but length of "mean"' |
|
472
|
|
|
f' is {num_means} while {num_labels} labels were found' |
|
473
|
|
|
) |
|
474
|
|
|
if num_means != num_labels: |
|
475
|
|
|
raise RuntimeError(message) |
|
476
|
|
|
if stds is not None: |
|
477
|
|
|
num_stds = len(stds) |
|
478
|
|
|
message = ( |
|
479
|
|
|
'"std" must define a value for each label but length of "std"' |
|
480
|
|
|
f' is {num_stds} while {num_labels} labels were found' |
|
481
|
|
|
) |
|
482
|
|
|
if num_stds != num_labels: |
|
483
|
|
|
raise RuntimeError(message) |
|
484
|
|
|
|