from __future__ import annotations

from collections.abc import Sequence

import torch

from ....data.image import LabelMap
from ....data.image import ScalarImage
from ....data.subject import Subject
from ....types import TypeData
from ....types import TypeRangeFloat
from ....utils import check_sequence
from ...intensity_transform import IntensityTransform
from .. import RandomTransform


class RandomLabelsToImage(RandomTransform, IntensityTransform):
    r"""Randomly generate an image from a segmentation.

    Based on the work by Billot et al.: `A Learning Strategy for Contrast-agnostic MRI Segmentation`_
    and `Partial Volume Segmentation of Brain MRI Scans of any Resolution and Contrast`_.

    .. _A Learning Strategy for Contrast-agnostic MRI Segmentation: http://proceedings.mlr.press/v121/billot20a.html

    .. _Partial Volume Segmentation of Brain MRI Scans of any Resolution and Contrast: https://link.springer.com/chapter/10.1007/978-3-030-59728-3_18

    .. plot::

        import torch
        import torchio as tio
        torch.manual_seed(42)
        colin = tio.datasets.Colin27(2008)
        label_map = colin.cls
        colin.remove_image('t1')
        colin.remove_image('t2')
        colin.remove_image('pd')
        downsample = tio.Resample(1)
        blurring_transform = tio.RandomBlur(std=0.6)
        create_synthetic_image = tio.RandomLabelsToImage(
            image_key='synthetic',
            ignore_background=True,
        )
        transform = tio.Compose((
            downsample,
            create_synthetic_image,
            blurring_transform,
        ))
        colin_synth = transform(colin)
        colin_synth.plot()

    Args:
        label_key: String designating the label map in the subject
            that will be used to generate the new image.
        used_labels: Sequence of integers designating the labels used
            to generate the new image. If categorical encoding is used,
            :attr:`used_labels` refers to the values of the
            categorical encoding. If one-hot encoding or partial-volume
            label maps are used, :attr:`used_labels` refers to the
            channels of the label maps.
            Default uses all labels. Missing voxels will be filled with zero
            or with voxels from an already existing volume,
            see :attr:`image_key`.
        image_key: String designating the key to which the new volume will be
            saved. If this key corresponds to an already existing volume,
            missing voxels will be filled with the corresponding values
            in the original volume.
        mean: Sequence of means for each label.
            For each value :math:`v`, if a tuple :math:`(a, b)` is
            provided then :math:`v \sim \mathcal{U}(a, b)`.
            If ``None``, the :attr:`default_mean` range will be used for every
            label.
            If not ``None`` and :attr:`used_labels` is not ``None``,
            :attr:`mean` and :attr:`used_labels` must have the
            same length.
        std: Sequence of standard deviations for each label.
            For each value :math:`v`, if a tuple :math:`(a, b)` is
            provided then :math:`v \sim \mathcal{U}(a, b)`.
            If ``None``, the :attr:`default_std` range will be used for every
            label.
            If not ``None`` and :attr:`used_labels` is not ``None``,
            :attr:`std` and :attr:`used_labels` must have the
            same length.
        default_mean: Default mean range.
        default_std: Default standard deviation range.
        discretize: If ``True``, partial-volume label maps will be discretized.
            Has no effect if partial-volume label maps are not used.
            Discretization is done by taking the class with the highest value
            per voxel in the different partial-volume label maps, using
            :func:`torch.argmax()` on the channel dimension (i.e. 0).
        ignore_background: If ``True``, input voxels labeled as ``0`` will not
            be modified.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    .. tip:: It is recommended to blur the new images in order to simulate
        partial volume effects at the borders of the synthetic structures. See
        :class:`~torchio.transforms.augmentation.intensity.random_blur.RandomBlur`.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.ICBM2009CNonlinearSymmetric()
        >>> # Using the default parameters
        >>> transform = tio.RandomLabelsToImage(label_key='tissues')
        >>> # Using custom mean and std
        >>> transform = tio.RandomLabelsToImage(
        ...     label_key='tissues', mean=[0.33, 0.66, 1.], std=[0, 0, 0]
        ... )
        >>> # Discretizing the partial volume maps and blurring the result
        >>> simulation_transform = tio.RandomLabelsToImage(
        ...     label_key='tissues', mean=[0.33, 0.66, 1.], std=[0, 0, 0], discretize=True
        ... )
        >>> blurring_transform = tio.RandomBlur(std=0.3)
        >>> transform = tio.Compose([simulation_transform, blurring_transform])
        >>> transformed = transform(subject)  # subject has a new key 'image_from_labels' with the simulated image
        >>> # Filling holes of the simulated image with the original T1 image
        >>> rescale_transform = tio.RescaleIntensity(
        ...     out_min_max=(0, 1), percentiles=(1, 99))  # Rescale intensity before filling holes
        >>> simulation_transform = tio.RandomLabelsToImage(
        ...     label_key='tissues',
        ...     image_key='t1',
        ...     used_labels=[0, 1]
        ... )
        >>> transform = tio.Compose([rescale_transform, simulation_transform])
        >>> transformed = transform(subject)  # subject's key 't1' has been replaced with the simulated image

    .. seealso:: :class:`~torchio.transforms.preprocessing.label.remap_labels.RemapLabels`.
    """

    def __init__(
        self,
        label_key: str | None = None,
        used_labels: Sequence[int] | None = None,
        image_key: str = 'image_from_labels',
        mean: Sequence[TypeRangeFloat] | None = None,
        std: Sequence[TypeRangeFloat] | None = None,
        default_mean: TypeRangeFloat = (0.1, 0.9),
        default_std: TypeRangeFloat = (0.01, 0.1),
        discretize: bool = False,
        ignore_background: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.label_key = _parse_label_key(label_key)
        self.used_labels = _parse_used_labels(used_labels)  # type: ignore[arg-type]
        self.mean, self.std = self.parse_mean_and_std(mean, std)  # type: ignore[arg-type,assignment]
        self.default_mean = self.parse_gaussian_parameter(
            default_mean,
            'default_mean',
        )
        self.default_std = self.parse_gaussian_parameter(
            default_std,
            'default_std',
        )
        self.image_key = image_key
        self.discretize = discretize
        self.ignore_background = ignore_background

    def parse_mean_and_std(
        self,
        mean: Sequence[TypeRangeFloat],
        std: Sequence[TypeRangeFloat],
    ) -> tuple[list[TypeRangeFloat], list[TypeRangeFloat]]:
        if mean is not None:
            mean = self.parse_gaussian_parameters(mean, 'mean')
        if std is not None:
            std = self.parse_gaussian_parameters(std, 'std')
        if mean is not None and std is not None:
            message = (
                'If both "mean" and "std" are defined, they must have the same length'
            )
            assert len(mean) == len(std), message
        return mean, std

    def parse_gaussian_parameters(
        self,
        params: Sequence[TypeRangeFloat],
        name: str,
    ) -> list[TypeRangeFloat]:
        check_sequence(params, name)
        params = [
            self.parse_gaussian_parameter(p, f'{name}[{i}]')
            for i, p in enumerate(params)
        ]
        if self.used_labels is not None:
            message = (
                f'If both "{name}" and "used_labels" are defined, '
                'they must have the same length'
            )
            assert len(params) == len(self.used_labels), message
        return params

    @staticmethod
    def parse_gaussian_parameter(
        nums_range: TypeRangeFloat,
        name: str,
    ) -> tuple[float, float]:
        if isinstance(nums_range, (int, float)):
            return nums_range, nums_range

        if len(nums_range) != 2:
            raise ValueError(
                f'If {name} is a sequence, it must be of len 2, not {nums_range}',
            )
        min_value, max_value = nums_range
        if min_value > max_value:
            raise ValueError(
                f'If {name} is a sequence, the second value must be'
                f' greater than or equal to the first, not {nums_range}',
            )
        return min_value, max_value

    def _guess_label_key(self, subject: Subject) -> None:
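        # If no label key was given, fall back to the first label map found in the subject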
        if self.label_key is None:
            iterable = subject.get_images_dict(intensity_only=False).items()
            for name, image in iterable:
                if isinstance(image, LabelMap):
                    self.label_key = name
                    break
            else:
                message = f'No label maps found in subject: {subject}'
                raise RuntimeError(message)

    def apply_transform(self, subject: Subject) -> Subject:
        self._guess_label_key(subject)

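        # Arguments forwarded to the deterministic LabelsToImage transform below;
        # the sampled mean and std for each label are appended to these lists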
        arguments = {
            'label_key': self.label_key,
            'mean': [],
            'std': [],
            'image_key': self.image_key,
            'used_labels': self.used_labels,
            'discretize': self.discretize,
            'ignore_background': self.ignore_background,
        }

        label_map = subject[self.label_key].data

        # Find out if we face a partial-volume image or a label map.
        # One-hot-encoded label map is considered as a partial-volume image
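        # squeeze() drops the singleton channel of a categorical label map, so the
        # number of dimensions decreases; one-hot or partial-volume maps keep theirs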
        all_discrete = label_map.eq(label_map.float().round()).all()
        same_num_dims = label_map.squeeze().dim() < label_map.dim()
        is_discretized = all_discrete and same_num_dims

        if not is_discretized and self.discretize:
            # Take label with highest value in voxel
            max_label, label_map = label_map.max(dim=0, keepdim=True)
            # Remove values where all labels are 0 (i.e. missing labels)
            label_map[max_label == 0] = -1
            is_discretized = True

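        # For a categorical map, the labels are its unique values; for one-hot or
        # partial-volume maps, each channel corresponds to one label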
        if is_discretized:
            labels = label_map.unique().long().tolist()
            if -1 in labels:
                labels.remove(-1)
        else:
            labels = range(label_map.shape[0])

        # Raise error if mean and std are not defined for every label
        _check_mean_and_std_length(labels, self.mean, self.std)  # type: ignore[arg-type]

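        # Sample one mean and one std per label; LabelsToImage uses them in this order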
        for label in labels:
            mean, std = self.get_params(label)
            means = arguments['mean']
            stds = arguments['std']
            assert isinstance(means, list)
            assert isinstance(stds, list)
            means.append(mean)
            stds.append(std)

        transform = LabelsToImage(**self.add_base_args(arguments))
        transformed = transform(subject)
        assert isinstance(transformed, Subject)
        return transformed

    def get_params(self, label: int) -> tuple[float, float]:
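        # Draw the Gaussian mean and std uniformly from the per-label ranges,
        # or from the default ranges when no per-label values were given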
        if self.mean is None:
            mean_range = self.default_mean
        else:
            assert isinstance(self.mean, Sequence)
            mean_range = self.mean[label]
        if self.std is None:
            std_range = self.default_std
        else:
            std_range = self.std[label]
        mean = self.sample_uniform(*mean_range)  # type: ignore[misc]
        std = self.sample_uniform(*std_range)  # type: ignore[misc]
        return mean, std


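# Deterministic counterpart of RandomLabelsToImage: the means and standard
# deviations are taken as given instead of being sampled.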
class LabelsToImage(IntensityTransform):
    r"""Generate an image from a segmentation.

    Args:
        label_key: String designating the label map in the subject
            that will be used to generate the new image.
        used_labels: Sequence of integers designating the labels used
            to generate the new image. If categorical encoding is used,
            :attr:`used_labels` refers to the values of the
            categorical encoding. If one-hot encoding or partial-volume
            label maps are used, :attr:`used_labels` refers to the
            channels of the label maps.
            Default uses all labels. Missing voxels will be filled with zero
            or with voxels from an already existing volume,
            see :attr:`image_key`.
        image_key: String designating the key to which the new volume will be
            saved. If this key corresponds to an already existing volume,
            missing voxels will be filled with the corresponding values
            in the original volume.
        mean: Sequence of means for each label.
            If not ``None`` and :attr:`used_labels` is not ``None``,
            :attr:`mean` and :attr:`used_labels` must have the
            same length.
        std: Sequence of standard deviations for each label.
            If not ``None`` and :attr:`used_labels` is not ``None``,
            :attr:`std` and :attr:`used_labels` must have the
            same length.
        discretize: If ``True``, partial-volume label maps will be discretized.
            Has no effect if partial-volume label maps are not used.
            Discretization is done by taking the class with the highest value
            per voxel in the different partial-volume label maps, using
            :func:`torch.argmax()` on the channel dimension (i.e. 0).
        ignore_background: If ``True``, input voxels labeled as ``0`` will not
            be modified.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    .. note:: It is recommended to blur the new images to make the result more
        realistic. See
        :class:`~torchio.transforms.augmentation.RandomBlur`.
    """

    def __init__(
        self,
        label_key: str,
        mean: Sequence[float] | None,
        std: Sequence[float] | None,
        image_key: str = 'image_from_labels',
        used_labels: Sequence[int] | None = None,
        ignore_background: bool = False,
        discretize: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.label_key = _parse_label_key(label_key)
        self.used_labels = _parse_used_labels(used_labels)
        self.mean, self.std = mean, std  # type: ignore[assignment]
        self.image_key = image_key
        self.ignore_background = ignore_background
        self.discretize = discretize
        self.args_names = [
            'label_key',
            'mean',
            'std',
            'image_key',
            'used_labels',
            'ignore_background',
            'discretize',
        ]

    def apply_transform(self, subject: Subject) -> Subject:
        original_image = subject.get(self.image_key)

        label_map_image = subject[self.label_key]
        label_map = label_map_image.data
        affine = label_map_image.affine

        # Find out if we face a partial-volume image or a label map.
        # One-hot-encoded label map is considered as a partial-volume image
        all_discrete = label_map.eq(label_map.float().round()).all()
        same_num_dims = label_map.squeeze().dim() < label_map.dim()
        is_discretized = all_discrete and same_num_dims

        if not is_discretized and self.discretize:
            # Take label with highest value in voxel
            max_label, label_map = label_map.max(dim=0, keepdim=True)
            # Remove values where all labels are 0 (i.e. missing labels)
            label_map[max_label == 0] = -1
            is_discretized = True

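        # Accumulate the simulated tissue intensities in a single-channel image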
        tissues = torch.zeros(1, *label_map_image.spatial_shape).float()
        if is_discretized:
            labels_in_image = label_map.unique().long().tolist()
            if -1 in labels_in_image:
                labels_in_image.remove(-1)
        else:
            labels_in_image = range(label_map.shape[0])

        # Raise error if mean and std are not defined for every label
        _check_mean_and_std_length(
            labels_in_image,
            self.mean,  # type: ignore[arg-type]
            self.std,
        )

        for i, label in enumerate(labels_in_image):
            if label == 0 and self.ignore_background:
                continue
            if self.used_labels is None or label in self.used_labels:
                assert isinstance(self.mean, Sequence)
                assert isinstance(self.std, Sequence)
                mean = self.mean[i]
                std = self.std[i]
                if is_discretized:
                    mask = label_map == label
                else:
                    mask = label_map[label]
                tissues += self.generate_tissue(mask, mean, std)

            else:
                # Modify label map to easily compute background mask
                if is_discretized:
                    label_map[label_map == label] = -1
                else:
                    label_map[label] = 0

        final_image = ScalarImage(affine=affine, tensor=tissues)

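        # If a volume already existed under image_key, keep its values wherever
        # no tissue was simulated (the background mask)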
        if original_image is not None:
            if is_discretized:
                bg_mask = label_map == -1
            else:
                bg_mask = label_map.sum(dim=0, keepdim=True) < 0.5
            final_image.data[bg_mask] = original_image.data[bg_mask].float()

        subject.add_image(final_image, self.image_key)
        return subject

    @staticmethod
    def generate_tissue(
        data: TypeData,
        mean: float,
        std: float,
    ) -> TypeData:
        # Create the simulated tissue using a gaussian random variable
        gaussian = torch.randn(data.shape) * std + mean
        return gaussian * data


def _parse_label_key(label_key: str | None) -> str | None:
    if label_key is not None and not isinstance(label_key, str):
        message = f'"label_key" must be a string or None, not {type(label_key)}'
        raise TypeError(message)
    return label_key


def _parse_used_labels(
    used_labels: Sequence[int] | None,
) -> Sequence[int] | None:
    if used_labels is None:
        return None
    check_sequence(used_labels, 'used_labels')
    for e in used_labels:
        if not isinstance(e, int):
            message = (
                'Items in "used_labels" must be integers,'
                f' but some are not: {used_labels}'
            )
            raise ValueError(message)
    return used_labels


def _check_mean_and_std_length(
    labels: Sequence[int],
    means: Sequence[TypeRangeFloat] | None,
    stds: Sequence[TypeRangeFloat] | None,
) -> None:
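    # Each label needs exactly one mean and one std; fail early if the lengths differ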
    num_labels = len(labels)
    if means is not None:
        num_means = len(means)
        message = (
            '"mean" must define a value for each label but length of "mean"'
            f' is {num_means} while {num_labels} labels were found'
        )
        if num_means != num_labels:
            raise RuntimeError(message)
    if stds is not None:
        num_stds = len(stds)
        message = (
            '"std" must define a value for each label but length of "std"'
            f' is {num_stds} while {num_labels} labels were found'
        )
        if num_stds != num_labels:
            raise RuntimeError(message)