torchio.transforms.transform (rating: F)

Complexity
    Total Complexity    109

Size/Duplication
    Total Lines         599
    Duplicated Lines    0 %

Importance
    Changes             0

Metric  Value
wmc     109
eloc    389
dl      0
loc     599
rs      2
c       0
b       0
f       0

28 Methods

Rating  Name                                          Duplication  Size  Complexity
A       Transform._use_seed()                         0            8     1
A       Transform.add_base_args()                     0            11    4
A       Transform.ones()                              0            3     1
A       Transform.get_sitk_interpolator()             0            3     1
A       Transform.validate_keys_sequence()            0            11    4
A       Transform.sitk_to_nib()                       0            3     1
A       Transform._get_reproducing_arguments()        0            11    1
A       Transform.mean()                              0            4     1
A       Transform.parse_include_and_exclude_keys()    0            12    3
A       Transform.to_range()                          0            6     2
A       Transform.get_base_args()                     0            18    1
A       Transform.add_transform_to_subject_history()  0            20    2
A       Transform.parse_probability()                 0            7     3
A       Transform.parse_interpolation()               0            16    4
A       Transform.__init__()                          0            32    2
C       Transform.parse_params()                      0            19    9
C       Transform.parse_bounds()                      0            31    9
A       Transform.__repr__()                          0            10    4
C       Transform.__call__()                          0            49    11
A       Transform.nib_to_sitk()                       0            3     1
A       Transform.get_mask_from_bounds()              0            14    1
B       Transform.get_mask_from_anatomical_label()    0            28    8
A       Transform.inverse()                           0            6     2
F       Transform._parse_range()                      0            106   21
A       Transform.name()                              0            3     1
A       Transform.is_invertible()                     0            2     1
A       Transform.apply_transform()                   0            3     1
C       Transform.get_mask_from_masking_method()      0            41    9
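In the table, _parse_range is the only method rated F (size 106, complexity 21). Much of that branching comes from running the same min, max and type constraint checks twice: once for a single number and once for a (min, max) pair. A possible reduction, sketched here as a hypothetical standalone function rather than TorchIO's code, is to normalize a single number into a pair first and then validate once. The names parse_range and _check are illustrative; the sketch assumes shorter error messages and always returning a tuple (the original returns a pair input unchanged) are acceptable.

```python
import numbers
from typing import Optional, Tuple, Union

Number = Union[int, float]


def _check(condition: bool, message: str) -> None:
    """Raise ValueError with the given message if the condition is False."""
    if not condition:
        raise ValueError(message)


def parse_range(
    nums_range,
    name: str,
    min_constraint: Optional[Number] = None,
    max_constraint: Optional[Number] = None,
    type_constraint: Optional[type] = None,
) -> Tuple[Number, Number]:
    """Hypothetical lower-complexity variant of Transform._parse_range."""
    # Step 1: normalize the input into a (min, max) pair.
    if isinstance(nums_range, numbers.Number):
        _check(
            nums_range >= 0,
            f'If {name} is a single number, it must be non-negative, not {nums_range}',
        )
        # Same rule as the original: n -> (-n, n), or (n, n) if min_constraint is set.
        low = -nums_range if min_constraint is None else nums_range
        pair = (low, nums_range)
    else:
        message = f'{name} must be a number or a sequence of len 2, not {nums_range}'
        try:
            pair = tuple(nums_range)
        except TypeError as err:
            raise ValueError(message) from err
        _check(len(pair) == 2, message)
    min_value, max_value = pair

    # Step 2: one set of checks shared by both input forms.
    _check(
        all(isinstance(v, numbers.Number) for v in pair),
        f'{name} values must be numbers, not {nums_range}',
    )
    _check(min_value <= max_value, f'{name}: {max_value} must be >= {min_value}')
    if min_constraint is not None:
        _check(min_value >= min_constraint, f'{name}: {min_value} is below {min_constraint}')
    if max_constraint is not None:
        _check(max_value <= max_constraint, f'{name}: {max_value} is above {max_constraint}')
    if type_constraint is not None:
        _check(
            all(isinstance(v, type_constraint) for v in pair),
            f'{name} values must be of type {type_constraint}, not {nums_range}',
        )
    return min_value, max_value


print(parse_range(5, 'degrees'))
print(parse_range((0.5, 2), 'scales', min_constraint=0))
```

The helper keeps each rule on one line, so the remaining branching is mostly the optional constraint checks rather than two parallel copies of them.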


Complexity

Complex classes like torchio.transforms.transform often do many different things. To break such a class down, you need to identify a cohesive component within it. A common way to find such a component is to look for fields or methods that share a prefix or suffix.

Once you have determined which fields belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster.
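In this module, the parse_bounds and get_mask_from_bounds methods share the "bounds" suffix and operate on the same six-margin data. The following is a minimal sketch of an Extract Class step, not TorchIO's code: a hypothetical Bounds class takes over the expansion rules of parse_bounds and the slicing arithmetic of get_mask_from_bounds. It is plain Python and returns slices instead of a torch mask so that it runs without torch; unlike parse_bounds, it accepts only ints, tuples and lists (not arbitrary iterables or NumPy integers).

```python
from dataclasses import dataclass
from typing import Sequence, Tuple, Union

BoundsInput = Union[int, Sequence[int]]


@dataclass(frozen=True)
class Bounds:
    """Hypothetical class extracted from Transform's bounds helpers.

    Holds six non-negative margins, (low, high) for each of the three
    spatial axes, in the same order that Transform.parse_bounds returns.
    """

    values: Tuple[int, ...]

    @classmethod
    def parse(cls, bounds: BoundsInput) -> 'Bounds':
        # Accept an int, 3 ints (one margin per axis) or 6 ints (low/high
        # per axis), following the expansion rules of Transform.parse_bounds.
        items = tuple(bounds) if isinstance(bounds, (tuple, list)) else (bounds,)
        for n in items:
            if not isinstance(n, int) or n < 0:
                raise ValueError(f'Bounds values must be non-negative integers, not {bounds!r}')
        if len(items) == 1:
            items = items * 6  # n -> (n, n, n, n, n, n)
        elif len(items) == 3:
            items = tuple(n for n in items for _ in range(2))  # (a, b, c) -> (a, a, b, b, c, c)
        elif len(items) != 6:
            raise ValueError(f'Expected an int or 3 or 6 ints, not {bounds!r}')
        return cls(items)

    def slices(self, shape: Sequence[int]) -> Tuple[slice, ...]:
        # Same arithmetic as get_mask_from_bounds: keep [low, size - high)
        # along each spatial axis.
        low, high = self.values[::2], self.values[1::2]
        return tuple(slice(lo, size - hi) for lo, hi, size in zip(low, high, shape))


print(Bounds.parse((1, 2, 3)).values)
print(Bounds.parse(2).slices((10, 10, 10)))
```

With such a class, get_mask_from_bounds could reduce to creating a zero mask and setting mask[(slice(None),) + Bounds.parse(bounds_parameters).slices(tensor.shape[1:])] = True, and parse_bounds would no longer need to live on Transform.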

1
from __future__ import annotations
2
3
import copy
4
import numbers
5
import warnings
6
from abc import ABC
7
from abc import abstractmethod
8
from collections.abc import Sequence
9
from contextlib import contextmanager
10
from typing import TypeVar
11
from typing import Union
12
13
import numpy as np
14
import SimpleITK as sitk
15
import torch
16
17
from ..data.image import LabelMap
18
from ..data.io import nib_to_sitk
19
from ..data.io import sitk_to_nib
20
from ..data.subject import Subject
21
from ..types import TypeCallable
22
from ..types import TypeData
23
from ..types import TypeDataAffine
24
from ..types import TypeKeys
25
from ..types import TypeNumber
26
from ..types import TypeTripletInt
27
from ..utils import is_iterable
28
from ..utils import to_tuple
29
from .data_parser import DataParser
30
from .data_parser import TypeTransformInput
31
from .interpolation import Interpolation
32
from .interpolation import get_sitk_interpolator
33
34
TypeSixBounds = tuple[int, int, int, int, int, int]
35
TypeBounds = Union[int, TypeTripletInt, TypeSixBounds, None]
36
TypeMaskingMethod = Union[str, TypeCallable, TypeBounds, None]
37
ANATOMICAL_AXES = (
38
    'Left',
39
    'Right',
40
    'Posterior',
41
    'Anterior',
42
    'Inferior',
43
    'Superior',
44
)
45
46
InputType = TypeVar('InputType', bound=TypeTransformInput)
47
48
49
class Transform(ABC):
50
    """Abstract class for all TorchIO transforms.
51
52
    When called, the input can be an instance of
53
    :class:`torchio.Subject`,
54
    :class:`torchio.Image`,
55
    :class:`numpy.ndarray`,
56
    :class:`torch.Tensor`,
57
    :class:`SimpleITK.Image`,
58
    or :class:`dict` containing 4D tensors as values.
59
60
    All subclasses must override
61
    :meth:`~torchio.transforms.Transform.apply_transform`,
62
    which takes an instance of :class:`~torchio.Subject`,
63
    modifies it and returns the result.
64
65
    Args:
66
        p: Probability that this transform will be applied.
67
        copy: Make a deep copy of the input before applying the transform.
68
        include: Sequence of strings with the names of the only images to which
69
            the transform will be applied.
70
            Mandatory if the input is a :class:`dict`.
71
        exclude: Sequence of strings with the names of the images to which the
72
            transform will not be applied, apart from the ones that are
73
            excluded because of the transform type.
74
            For example, if a subject includes an MRI, a CT and a label map,
75
            and the CT is added to the list of exclusions of an intensity
76
            transform such as :class:`~torchio.transforms.RandomBlur`,
77
            the transform will only be applied to the MRI, as the label map is
78
            excluded by default by intensity transforms.
79
        keep: Dictionary with the names of the input images that will be kept
80
            in the output and their new names. For example:
81
            ``{'t1': 't1_original'}``. This might be useful for autoencoders
82
            or registration tasks.
83
        parse_input: If ``True``, the input will be converted to an instance of
84
            :class:`~torchio.Subject`. This is used internally by some special
85
            transforms like
86
            :class:`~torchio.transforms.augmentation.composition.Compose`.
87
        label_keys: If the input is a dictionary, names of images that
88
            correspond to label maps.
89
    """
90
91
    def __init__(
92
        self,
93
        p: float = 1,
94
        copy: bool = True,
95
        include: TypeKeys = None,
96
        exclude: TypeKeys = None,
97
        keys: TypeKeys = None,
98
        keep: dict[str, str] | None = None,
99
        parse_input: bool = True,
100
        label_keys: TypeKeys = None,
101
    ):
102
        self.probability = self.parse_probability(p)
103
        self.copy = copy
104
        if keys is not None:
105
            message = (
106
                'The "keys" argument is deprecated and will be removed in the'
107
                ' future. Use "include" instead'
108
            )
109
            warnings.warn(message, FutureWarning, stacklevel=2)
110
            include = keys
111
        self.include, self.exclude = self.parse_include_and_exclude_keys(
112
            include,
113
            exclude,
114
            label_keys,
115
        )
116
        self.keep = keep
117
        self.parse_input = parse_input
118
        self.label_keys = label_keys
119
        # args_names is the sequence of parameters from self that need to be
120
        # passed to a non-random version of a random transform. They are also
121
        # used to invert invertible transforms
122
        self.args_names: list[str] = []
123
124
    def __call__(self, data: InputType) -> InputType:
125
        """Transform data and return a result of the same type.
126
127
        Args:
128
            data: Instance of :class:`torchio.Subject`, 4D
129
                :class:`torch.Tensor` or :class:`numpy.ndarray` with dimensions
130
                :math:`(C, W, H, D)`, where :math:`C` is the number of channels
131
                and :math:`W, H, D` are the spatial dimensions. If the input is
132
                a tensor, the affine matrix will be set to identity. Other
133
                valid input types are a SimpleITK image, a
134
                :class:`torchio.Image`, a NiBabel Nifti1 image or a
135
                :class:`dict`. The output type is the same as the input type.
136
        """
137
        if torch.rand(1).item() > self.probability:
138
            return data
139
140
        # Some transforms such as Compose should not modify the input data
141
        if self.parse_input:
142
            data_parser = DataParser(
143
                data,
144
                keys=self.include,
145
                label_keys=self.label_keys,
146
            )
147
            subject = data_parser.get_subject()
148
        else:
149
            subject = data
150
151
        if self.keep is not None:
152
            images_to_keep = {}
153
            for name, new_name in self.keep.items():
154
                images_to_keep[new_name] = copy.deepcopy(subject[name])
155
        if self.copy:
156
            subject = copy.deepcopy(subject)
157
        with np.errstate(all='raise', under='ignore'):
158
            transformed = self.apply_transform(subject)
159
        if self.keep is not None:
160
            for name, image in images_to_keep.items():
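Issue: The variable images_to_keep does not seem to be defined if the condition self.keep is not None on line 151 is False. Are you sure this can never be the case? (Its only use, on line 160, is guarded by the same condition on line 159, so this path can only occur if self.keep is reassigned during apply_transform.)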
161
                transformed.add_image(image, name)
162
163
        if self.parse_input:
164
            self.add_transform_to_subject_history(transformed)
165
            for image in transformed.get_images(intensity_only=False):
166
                ndim = image.data.ndim
167
                assert ndim == 4, f'Output of {self.name} is {ndim}D'
168
            output = data_parser.get_output(transformed)
Issue: The variable data_parser does not seem to be defined for all execution paths. (It is assigned on line 142 under the self.parse_input check on line 141, and used here under the same check on line 163.)
169
        else:
170
            output = transformed
171
172
        return output
173
174
    def __repr__(self):
175
        if hasattr(self, 'args_names'):
176
            names = self.args_names
177
            args_strings = [f'{arg}={getattr(self, arg)}' for arg in names]
178
            if hasattr(self, 'invert_transform') and self.invert_transform:
179
                args_strings.append('invert=True')
180
            args_string = ', '.join(args_strings)
181
            return f'{self.name}({args_string})'
182
        else:
183
            return super().__repr__()
184
185
    def get_base_args(self) -> dict:
186
        r"""Provides easy access to the arguments used to instantiate the base class
187
        (:class:`~torchio.transforms.transform.Transform`) of any transform.
188
189
        This method is particularly useful when a new transform can be represented as a variant
190
        of an existing transform (e.g. all random transforms), allowing for seamless instantiation
191
        of the existing transform with the same arguments as the new transform during `apply_transform`.
192
193
        Note: The `p` argument (probability of applying the transform) is excluded to avoid
194
        multiplying the probabilities of the existing and the new transforms.
195
        """
196
        return {
197
            'copy': self.copy,
198
            'include': self.include,
199
            'exclude': self.exclude,
200
            'keep': self.keep,
201
            'parse_input': self.parse_input,
202
            'label_keys': self.label_keys,
203
        }
204
205
    def add_base_args(
206
        self,
207
        arguments,
208
        overwrite_on_existing: bool = False,
209
    ):
210
        """Add the init args to existing arguments"""
211
        for key, value in self.get_base_args().items():
212
            if key in arguments and not overwrite_on_existing:
213
                continue
214
            arguments[key] = value
215
        return arguments
216
217
    @property
218
    def name(self):
219
        return self.__class__.__name__
220
221
    @abstractmethod
222
    def apply_transform(self, subject: Subject) -> Subject:
223
        raise NotImplementedError
224
225
    def add_transform_to_subject_history(self, subject):
226
        from . import Compose
227
        from . import CropOrPad
228
        from . import EnsureShapeMultiple
229
        from . import OneOf
230
        from .augmentation import RandomTransform
231
        from .preprocessing import Resize
232
        from .preprocessing import SequentialLabels
233
234
        call_others = (
235
            RandomTransform,
236
            Compose,
237
            OneOf,
238
            CropOrPad,
239
            EnsureShapeMultiple,
240
            SequentialLabels,
241
            Resize,
242
        )
243
        if not isinstance(self, call_others):
244
            subject.add_transform(self, self._get_reproducing_arguments())
245
246
    @staticmethod
247
    def to_range(n, around):
248
        if around is None:
249
            return 0, n
250
        else:
251
            return around - n, around + n
252
253
    def parse_params(self, params, around, name, make_ranges=True, **kwargs):
254
        params = to_tuple(params)
255
        # d or (a, b)
256
        if len(params) == 1 or (len(params) == 2 and make_ranges):
257
            params *= 3  # (d, d, d) or (a, b, a, b, a, b)
258
        if len(params) == 3 and make_ranges:  # (a, b, c)
259
            items = [self.to_range(n, around) for n in params]
260
            # (-a, a, -b, b, -c, c) or (1-a, 1+a, 1-b, 1+b, 1-c, 1+c)
261
            params = [n for prange in items for n in prange]
262
        if make_ranges:
263
            if len(params) != 6:
264
                message = (
265
                    f'If "{name}" is a sequence, it must have length 2, 3 or'
266
                    f' 6, not {len(params)}'
267
                )
268
                raise ValueError(message)
269
            for param_range in zip(params[::2], params[1::2]):
270
                self._parse_range(param_range, name, **kwargs)
271
        return tuple(params)
272
273
    @staticmethod
274
    def _parse_range(
275
        nums_range: TypeNumber | tuple[TypeNumber, TypeNumber],
276
        name: str,
277
        min_constraint: TypeNumber | None = None,
278
        max_constraint: TypeNumber | None = None,
279
        type_constraint: type | None = None,
280
    ) -> tuple[TypeNumber, TypeNumber]:
281
        r"""Adapted from :class:`torchvision.transforms.RandomRotation`.
282
283
        Args:
284
            nums_range: Tuple of two numbers :math:`(n_{min}, n_{max})`,
285
                where :math:`n_{min} \leq n_{max}`.
286
                If a single non-negative number :math:`n` is provided,
287
                :math:`n_{min} = -n` (or :math:`n` if :attr:`min_constraint` is set) and :math:`n_{max} = n`.
288
            name: Name of the parameter, so that an informative error message
289
                can be printed.
290
            min_constraint: Minimal value that :math:`n_{min}` can take,
291
                default is None, i.e. there is no minimal value.
292
            max_constraint: Maximal value that :math:`n_{max}` can take,
293
                default is None, i.e. there is no maximal value.
294
            type_constraint: Precise type that :math:`n_{max}` and
295
                :math:`n_{min}` must take.
296
297
        Returns:
298
            A tuple of two numbers :math:`(n_{min}, n_{max})`.
299
300
        Raises:
301
            ValueError: if :attr:`nums_range` is negative
302
            ValueError: if :math:`n_{max}` or :math:`n_{min}` is not a number
303
            ValueError: if :math:`n_{max} \lt n_{min}`
304
            ValueError: if :attr:`min_constraint` is not None and
305
                :math:`n_{min}` is smaller than :attr:`min_constraint`
306
            ValueError: if :attr:`max_constraint` is not None and
307
                :math:`n_{max}` is greater than :attr:`max_constraint`
308
            ValueError: if :attr:`type_constraint` is not None and
309
                :math:`n_{min}` or :math:`n_{max}` is not of type
310
                :attr:`type_constraint`.
311
        """
312
        if isinstance(nums_range, numbers.Number):  # single number given
313
            if nums_range < 0:
314
                raise ValueError(
315
                    f'If {name} is a single number,'
316
                    f' it must be positive, not {nums_range}',
317
                )
318
            if min_constraint is not None and nums_range < min_constraint:
319
                raise ValueError(
320
                    f'If {name} is a single number, it must be greater'
321
                    f' than {min_constraint}, not {nums_range}',
322
                )
323
            if max_constraint is not None and nums_range > max_constraint:
324
                raise ValueError(
325
                    f'If {name} is a single number, it must be smaller'
326
                    f' than {max_constraint}, not {nums_range}',
327
                )
328
            if type_constraint is not None:
329
                if not isinstance(nums_range, type_constraint):
330
                    raise ValueError(
331
                        f'If {name} is a single number, it must be of'
332
                        f' type {type_constraint}, not {nums_range}',
333
                    )
334
            min_range = -nums_range if min_constraint is None else nums_range
335
            return (min_range, nums_range)
336
337
        try:
338
            min_value, max_value = nums_range  # type: ignore[misc]
339
        except (TypeError, ValueError) as err:
340
            message = (
341
                f'If {name} is not a single number, it must be'
342
                f' a sequence of len 2, not {nums_range}'
343
            )
344
            raise ValueError(message) from err
345
346
        min_is_number = isinstance(min_value, numbers.Number)
347
        max_is_number = isinstance(max_value, numbers.Number)
348
        if not min_is_number or not max_is_number:
349
            message = f'{name} values must be numbers, not {nums_range}'
350
            raise ValueError(message)
351
352
        if min_value > max_value:
353
            raise ValueError(
354
                f'If {name} is a sequence, the second value must be'
355
                f' equal or greater than the first, but it is {nums_range}',
356
            )
357
358
        if min_constraint is not None and min_value < min_constraint:
359
            raise ValueError(
360
                f'If {name} is a sequence, the first value must be greater'
361
                f' than {min_constraint}, but it is {min_value}',
362
            )
363
364
        if max_constraint is not None and max_value > max_constraint:
365
            raise ValueError(
366
                f'If {name} is a sequence, the second value must be'
367
                f' smaller than {max_constraint}, but it is {max_value}',
368
            )
369
370
        if type_constraint is not None:
371
            min_type_ok = isinstance(min_value, type_constraint)
372
            max_type_ok = isinstance(max_value, type_constraint)
373
            if not min_type_ok or not max_type_ok:
374
                raise ValueError(
375
                    f'If "{name}" is a sequence, its values must be of'
376
                    f' type "{type_constraint}", not "{type(nums_range)}"',
377
                )
378
        return nums_range  # type: ignore[return-value]
379
380
    @staticmethod
381
    def parse_interpolation(interpolation: str) -> str:
382
        if not isinstance(interpolation, str):
383
            itype = type(interpolation)
384
            raise TypeError(f'Interpolation must be a string, not {itype}')
385
        interpolation = interpolation.lower()
386
        is_string = isinstance(interpolation, str)
387
        supported_values = [key.name.lower() for key in Interpolation]
388
        is_supported = interpolation.lower() in supported_values
389
        if is_string and is_supported:
390
            return interpolation
391
        message = (
392
            f'Interpolation "{interpolation}" of type {type(interpolation)}'
393
            f' must be a string among the supported values: {supported_values}'
394
        )
395
        raise ValueError(message)
396
397
    @staticmethod
398
    def parse_probability(probability: float) -> float:
399
        is_number = isinstance(probability, numbers.Number)
400
        if not (is_number and 0 <= probability <= 1):
401
            message = f'Probability must be a number in [0, 1], not {probability}'
402
            raise ValueError(message)
403
        return probability
404
405
    @staticmethod
406
    def parse_include_and_exclude_keys(
407
        include: TypeKeys,
408
        exclude: TypeKeys,
409
        label_keys: TypeKeys,
410
    ) -> tuple[TypeKeys, TypeKeys]:
411
        if include is not None and exclude is not None:
412
            raise ValueError('Include and exclude cannot both be specified')
413
        Transform.validate_keys_sequence(include, 'include')
414
        Transform.validate_keys_sequence(exclude, 'exclude')
415
        Transform.validate_keys_sequence(label_keys, 'label_keys')
416
        return include, exclude
417
418
    @staticmethod
419
    def validate_keys_sequence(keys: TypeKeys, name: str) -> None:
420
        """Ensure that the input is not a string but a sequence of strings."""
421
        if keys is None:
422
            return
423
        if isinstance(keys, str):
424
            message = f'"{name}" must be a sequence of strings, not a string "{keys}"'
425
            raise ValueError(message)
426
        if not is_iterable(keys):
427
            message = f'"{name}" must be a sequence of strings, not {type(keys)}'
428
            raise ValueError(message)
429
430
    @staticmethod
431
    def nib_to_sitk(data: TypeData, affine: TypeData) -> sitk.Image:
432
        return nib_to_sitk(data, affine)
433
434
    @staticmethod
435
    def sitk_to_nib(image: sitk.Image) -> TypeDataAffine:
436
        return sitk_to_nib(image)  # type: ignore[return-value]
437
438
    def _get_reproducing_arguments(self):
439
        """Return a dictionary with the arguments that would be necessary to
440
        reproduce the transform exactly."""
441
        reproducing_arguments = {
442
            'include': self.include,
443
            'exclude': self.exclude,
444
            'copy': self.copy,
445
        }
446
        args_names = {name: getattr(self, name) for name in self.args_names}
447
        reproducing_arguments.update(args_names)
448
        return reproducing_arguments
449
450
    def is_invertible(self):
451
        return hasattr(self, 'invert_transform')
452
453
    def inverse(self):
454
        if not self.is_invertible():
455
            raise RuntimeError(f'{self.name} is not invertible')
456
        new = copy.deepcopy(self)
457
        new.invert_transform = not self.invert_transform
458
        return new
459
460
    @staticmethod
461
    @contextmanager
462
    def _use_seed(seed):
463
        """Perform an operation using a specific seed for the PyTorch RNG."""
464
        torch_rng_state = torch.random.get_rng_state()
465
        torch.manual_seed(seed)
466
        yield
467
        torch.random.set_rng_state(torch_rng_state)
468
469
    @staticmethod
470
    def get_sitk_interpolator(interpolation: str) -> int:
471
        return get_sitk_interpolator(interpolation)
472
473
    @staticmethod
474
    def parse_bounds(bounds_parameters: TypeBounds) -> TypeSixBounds | None:
475
        if bounds_parameters is None:
476
            return None
477
        try:
478
            bounds_parameters = tuple(bounds_parameters)  # type: ignore[assignment,arg-type]
479
        except TypeError:
480
            bounds_parameters = (bounds_parameters,)  # type: ignore[assignment]
481
482
        # Check that numbers are integers
483
        for number in bounds_parameters:  # type: ignore[union-attr]
484
            if not isinstance(number, (int, np.integer)) or number < 0:
485
                message = (
486
                    'Bounds values must be integers greater or equal to zero,'
487
                    f' not "{bounds_parameters}" of type {type(number)}'
488
                )
489
                raise ValueError(message)
490
        bounds_parameters_tuple = tuple(int(n) for n in bounds_parameters)  # type: ignore[assignment,union-attr]
491
        bounds_parameters_length = len(bounds_parameters_tuple)
492
        if bounds_parameters_length == 6:
493
            return bounds_parameters_tuple  # type: ignore[return-value]
494
        if bounds_parameters_length == 1:
495
            return 6 * bounds_parameters_tuple  # type: ignore[return-value]
496
        if bounds_parameters_length == 3:
497
            repeat = np.repeat(bounds_parameters_tuple, 2).tolist()
498
            return tuple(repeat)  # type: ignore[return-value]
499
        message = (
500
            'Bounds parameter must be an integer or a tuple of'
501
            f' 3 or 6 integers, not {bounds_parameters_tuple}'
502
        )
503
        raise ValueError(message)
504
505
    @staticmethod
506
    def ones(tensor: torch.Tensor) -> torch.Tensor:
507
        return torch.ones_like(tensor, dtype=torch.bool)
508
509
    @staticmethod
510
    def mean(tensor: torch.Tensor) -> torch.Tensor:
511
        mask = tensor > tensor.float().mean()
512
        return mask
513
514
    def get_mask_from_masking_method(
515
        self,
516
        masking_method: TypeMaskingMethod,
517
        subject: Subject,
518
        tensor: torch.Tensor,
519
        labels: Sequence[int] | None = None,
520
    ) -> torch.Tensor:
521
        if masking_method is None:
522
            return self.ones(tensor)
523
        elif callable(masking_method):
524
            return masking_method(tensor)
525
        elif type(masking_method) is str:
526
            in_subject = masking_method in subject
527
            if in_subject and isinstance(subject[masking_method], LabelMap):
528
                if labels is None:
529
                    return subject[masking_method].data.bool()
530
                else:
531
                    mask_data = subject[masking_method].data
532
                    volumes = [mask_data == label for label in labels]
533
                    return torch.stack(volumes).sum(0).bool()
534
            possible_axis = masking_method.capitalize()
535
            if possible_axis in ANATOMICAL_AXES:
536
                return self.get_mask_from_anatomical_label(
537
                    possible_axis,
538
                    tensor,
539
                )
540
        elif type(masking_method) in (tuple, list, int):
541
            return self.get_mask_from_bounds(masking_method, tensor)  # type: ignore[arg-type]
542
        first_anat_axes = tuple(s[0] for s in ANATOMICAL_AXES)
543
        message = (
544
            'Masking method must be one of:\n'
545
            ' 1) A callable object, such as a function\n'
546
            ' 2) The name of a label map in the subject'
547
            f' ({subject.get_images_names()})\n'
548
            f' 3) An anatomical label {ANATOMICAL_AXES + first_anat_axes}\n'
549
            ' 4) A bounds parameter'
550
            ' (int, tuple of 3 ints, or tuple of 6 ints)\n'
551
            f' The passed value, "{masking_method}",'
552
            f' of type "{type(masking_method)}", is not valid'
553
        )
554
        raise ValueError(message)
555
556
    @staticmethod
557
    def get_mask_from_anatomical_label(
558
        anatomical_label: str,
559
        tensor: torch.Tensor,
560
    ) -> torch.Tensor:
561
        # Assume the image is in RAS orientation
562
        anatomical_label = anatomical_label.capitalize()
563
        if anatomical_label not in ANATOMICAL_AXES:
564
            message = (
565
                f'Anatomical label must be one of {ANATOMICAL_AXES}'
566
                f' not {anatomical_label}'
567
            )
568
            raise ValueError(message)
569
        mask = torch.zeros_like(tensor, dtype=torch.bool)
570
        _, width, height, depth = tensor.shape
571
        if anatomical_label == 'Right':
572
            mask[:, width // 2 :] = True
573
        elif anatomical_label == 'Left':
574
            mask[:, : width // 2] = True
575
        elif anatomical_label == 'Anterior':
576
            mask[:, :, height // 2 :] = True
577
        elif anatomical_label == 'Posterior':
578
            mask[:, :, : height // 2] = True
579
        elif anatomical_label == 'Superior':
580
            mask[:, :, :, depth // 2 :] = True
581
        elif anatomical_label == 'Inferior':
582
            mask[:, :, :, : depth // 2] = True
583
        return mask
584
585
    def get_mask_from_bounds(
586
        self,
587
        bounds_parameters: TypeBounds,
588
        tensor: torch.Tensor,
589
    ) -> torch.Tensor:
590
        bounds_parameters = self.parse_bounds(bounds_parameters)
591
        assert bounds_parameters is not None
592
        low = bounds_parameters[::2]
593
        high = bounds_parameters[1::2]
594
        i0, j0, k0 = low
595
        i1, j1, k1 = np.array(tensor.shape[1:]) - high
596
        mask = torch.zeros_like(tensor, dtype=torch.bool)
597
        mask[:, i0:i1, j0:j1, k0:k1] = True
598
        return mask
599
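In _use_seed (lines 460 to 467), torch.random.set_rng_state runs only if the body of the with block finishes normally; if it raises, the global PyTorch generator stays re-seeded. A minimal sketch of a try/finally variant follows. It uses Python's random module as a stand-in for the PyTorch RNG so it runs without torch; the name use_seed is hypothetical, and swapping in torch.random.get_rng_state, torch.manual_seed and torch.random.set_rng_state should give the same structure.

```python
import random
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def use_seed(seed: int) -> Iterator[None]:
    """Seed the RNG for the duration of the block, then restore its state."""
    saved_state = random.getstate()
    random.seed(seed)
    try:
        yield
    finally:
        # Runs both on normal exit and when the body raises, so the caller's
        # RNG stream is never left re-seeded.
        random.setstate(saved_state)


# Usage: the same seed yields the same draw inside the block.
with use_seed(42):
    first = random.random()
with use_seed(42):
    second = random.random()
print(first == second)
```

Because the restore sits in finally, an exception raised by the wrapped operation still propagates to the caller, but the surrounding random stream continues exactly where it left off.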