| 1 |  |  | from __future__ import annotations | 
            
                                                                                                            
                            
            
                                    
            
            
                | 2 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 3 |  |  | import copy | 
            
                                                                                                            
                            
            
                                    
            
            
                | 4 |  |  | import pprint | 
            
                                                                                                            
                            
            
                                    
            
            
                | 5 |  |  | from collections.abc import Sequence | 
            
                                                                                                            
                            
            
                                    
            
            
                | 6 |  |  | from typing import TYPE_CHECKING | 
            
                                                                                                            
                            
            
                                    
            
            
                | 7 |  |  | from typing import Any | 
            
                                                                                                            
                            
            
                                    
            
            
                | 8 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 9 |  |  | import numpy as np | 
            
                                                                                                            
                            
            
                                    
            
            
                | 10 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 11 |  |  | from ..constants import INTENSITY | 
            
                                                                                                            
                            
            
                                    
            
            
                | 12 |  |  | from ..constants import TYPE | 
            
                                                                                                            
                            
            
                                    
            
            
                | 13 |  |  | from ..utils import get_subclasses | 
            
                                                                                                            
                            
            
                                    
            
            
                | 14 |  |  | from .image import Image | 
            
                                                                                                            
                            
            
                                    
            
            
                | 15 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 16 |  |  | if TYPE_CHECKING: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 17 |  |  |     from ..transforms import Compose | 
            
                                                                                                            
                            
            
                                    
            
            
                | 18 |  |  |     from ..transforms import Transform | 
            
                                                                                                            
                            
            
                                    
            
            
                | 19 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 20 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 21 |  |  | class Subject(dict): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 22 |  |  |     """Class to store information about the images corresponding to a subject. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 23 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 24 |  |  |     Args: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 25 |  |  |         *args: If provided, a dictionary of items. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 26 |  |  |         **kwargs: Items that will be added to the subject sample. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 27 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 28 |  |  |     Example: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 29 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 30 |  |  |         >>> import torchio as tio | 
            
                                                                                                            
                            
            
                                    
            
            
                | 31 |  |  |         >>> # One way: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 32 |  |  |         >>> subject = tio.Subject( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 33 |  |  |         ...     one_image=tio.ScalarImage('path_to_image.nii.gz'), | 
            
                                                                                                            
                            
            
                                    
            
            
                | 34 |  |  |         ...     a_segmentation=tio.LabelMap('path_to_seg.nii.gz'), | 
            
                                                                                                            
                            
            
                                    
            
            
                | 35 |  |  |         ...     age=45, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 36 |  |  |         ...     name='John Doe', | 
            
                                                                                                            
                            
            
                                    
            
            
                | 37 |  |  |         ...     hospital='Hospital Juan Negrín', | 
            
                                                                                                            
                            
            
                                    
            
            
                | 38 |  |  |         ... ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 39 |  |  |         >>> # If you want to create the mapping before, or have spaces in the keys: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 40 |  |  |         >>> subject_dict = { | 
            
                                                                                                            
                            
            
                                    
            
            
                | 41 |  |  |         ...     'one image': tio.ScalarImage('path_to_image.nii.gz'), | 
            
                                                                                                            
                            
            
                                    
            
            
                | 42 |  |  |         ...     'a segmentation': tio.LabelMap('path_to_seg.nii.gz'), | 
            
                                                                                                            
                            
            
                                    
            
            
                | 43 |  |  |         ...     'age': 45, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 44 |  |  |         ...     'name': 'John Doe', | 
            
                                                                                                            
                            
            
                                    
            
            
                | 45 |  |  |         ...     'hospital': 'Hospital Juan Negrín', | 
            
                                                                                                            
                            
            
                                    
            
            
                | 46 |  |  |         ... } | 
            
                                                                                                            
                            
            
                                    
            
            
                | 47 |  |  |         >>> subject = tio.Subject(subject_dict) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 48 |  |  |     """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 49 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 50 |  |  |     def __init__(self, *args, **kwargs: dict[str, Any]): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 51 |  |  |         if args: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 52 |  |  |             if len(args) == 1 and isinstance(args[0], dict): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 53 |  |  |                 kwargs.update(args[0]) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 54 |  |  |             else: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 55 |  |  |                 message = 'Only one dictionary as positional argument is allowed' | 
            
                                                                                                            
                            
            
                                    
            
            
                | 56 |  |  |                 raise ValueError(message) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 57 |  |  |         super().__init__(**kwargs) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 58 |  |  |         self._parse_images(self.get_images(intensity_only=False)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 59 |  |  |         self.update_attributes()  # this allows me to do e.g. subject.t1 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 60 |  |  |         self.applied_transforms: list[tuple[str, dict]] = [] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 61 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 62 |  |  |     def __repr__(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 63 |  |  |         num_images = len(self.get_images(intensity_only=False)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 64 |  |  |         string = ( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 65 |  |  |             f'{self.__class__.__name__}' | 
            
                                                                                                            
                            
            
                                    
            
            
                | 66 |  |  |             f'(Keys: {tuple(self.keys())}; images: {num_images})' | 
            
                                                                                                            
                            
            
                                    
            
            
                | 67 |  |  |         ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 68 |  |  |         return string | 
            
                                                                                                            
                            
            
                                    
            
            
                | 69 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 70 |  |  |     def __len__(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 71 |  |  |         return len(self.get_images(intensity_only=False)) | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 72 |  |  |  | 
            
                                                                        
                            
            
                                    
            
            
                | 73 |  |  |     def __getitem__(self, item): | 
            
                                                                        
                            
            
                                    
            
            
                | 74 |  |  |         if isinstance(item, (slice, int, tuple)): | 
            
                                                                        
                            
            
                                    
            
            
                | 75 |  |  |             try: | 
            
                                                                        
                            
            
                                    
            
            
                | 76 |  |  |                 self.check_consistent_spatial_shape() | 
            
                                                                        
                            
            
                                    
            
            
                | 77 |  |  |             except RuntimeError as e: | 
            
                                                                        
                            
            
                                    
            
            
                | 78 |  |  |                 message = ( | 
            
                                                                        
                            
            
                                    
            
            
                | 79 |  |  |                     'To use indexing, all images in the subject must have the' | 
            
                                                                        
                            
            
                                    
            
            
                | 80 |  |  |                     ' same spatial shape' | 
            
                                                                        
                            
            
                                    
            
            
                | 81 |  |  |                 ) | 
            
                                                                        
                            
            
                                    
            
            
                | 82 |  |  |                 raise RuntimeError(message) from e | 
            
                                                                        
                            
            
                                    
            
            
                | 83 |  |  |             copied = copy.deepcopy(self) | 
            
                                                                        
                            
            
                                    
            
            
                | 84 |  |  |             for image_name, image in copied.items(): | 
            
                                                                        
                            
            
                                    
            
            
                | 85 |  |  |                 copied[image_name] = image[item] | 
            
                                                                        
                            
            
                                    
            
            
                | 86 |  |  |             return copied | 
            
                                                                        
                            
            
                                    
            
            
                | 87 |  |  |         else: | 
            
                                                                        
                            
            
                                    
            
            
                | 88 |  |  |             return super().__getitem__(item) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 89 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 90 |  |  |     @staticmethod | 
            
                                                                                                            
                            
            
                                    
            
            
                | 91 |  |  |     def _parse_images(images: list[Image]) -> None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 92 |  |  |         # Check that it's not empty | 
            
                                                                                                            
                            
            
                                    
            
            
                | 93 |  |  |         if not images: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 94 |  |  |             raise TypeError('A subject without images cannot be created') | 
            
                                                                                                            
                            
            
                                    
            
            
                | 95 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 96 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 97 |  |  |     def shape(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 98 |  |  |         """Return shape of first image in subject. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 99 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 100 |  |  |         Consistency of shapes across images in the subject is checked first. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 101 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 102 |  |  |         Example: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 103 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 104 |  |  |             >>> import torchio as tio | 
            
                                                                                                            
                            
            
                                    
            
            
                | 105 |  |  |             >>> colin = tio.datasets.Colin27() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 106 |  |  |             >>> colin.shape | 
            
                                                                                                            
                            
            
                                    
            
            
                | 107 |  |  |             (1, 181, 217, 181) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 108 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 109 |  |  |         self.check_consistent_attribute('shape') | 
            
                                                                                                            
                            
            
                                    
            
            
                | 110 |  |  |         return self.get_first_image().shape | 
            
                                                                                                            
                            
            
                                    
            
            
                | 111 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 112 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 113 |  |  |     def spatial_shape(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 114 |  |  |         """Return spatial shape of first image in subject. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 115 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 116 |  |  |         Consistency of spatial shapes across images in the subject is checked | 
            
                                                                                                            
                            
            
                                    
            
            
                | 117 |  |  |         first. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 118 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 119 |  |  |         Example: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 120 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 121 |  |  |             >>> import torchio as tio | 
            
                                                                                                            
                            
            
                                    
            
            
                | 122 |  |  |             >>> colin = tio.datasets.Colin27() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 123 |  |  |             >>> colin.spatial_shape | 
            
                                                                                                            
                            
            
                                    
            
            
                | 124 |  |  |             (181, 217, 181) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 125 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 126 |  |  |         self.check_consistent_spatial_shape() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 127 |  |  |         return self.get_first_image().spatial_shape | 
            
                                                                                                            
                            
            
                                    
            
            
                | 128 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 129 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 130 |  |  |     def spacing(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 131 |  |  |         """Return spacing of first image in subject. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 132 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 133 |  |  |         Consistency of spacings across images in the subject is checked first. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 134 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 135 |  |  |         Example: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 136 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 137 |  |  |             >>> import torchio as tio | 
            
                                                                                                            
                            
            
                                    
            
            
                | 138 |  |  |             >>> colin = tio.datasets.Slicer() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 139 |  |  |             >>> colin.spacing | 
            
                                                                                                            
                            
            
                                    
            
            
                | 140 |  |  |             (1.0, 1.0, 1.2999954223632812) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 141 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 142 |  |  |         self.check_consistent_attribute('spacing') | 
            
                                                                                                            
                            
            
                                    
            
            
                | 143 |  |  |         return self.get_first_image().spacing | 
            
                                                                                                            
                            
            
                                    
            
            
                | 144 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 145 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 146 |  |  |     def history(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 147 |  |  |         # Kept for backwards compatibility | 
            
                                                                                                            
                            
            
                                    
            
            
                | 148 |  |  |         return self.get_applied_transforms() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 149 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 150 |  |  |     def is_2d(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 151 |  |  |         return all(i.is_2d() for i in self.get_images(intensity_only=False)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 152 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 153 |  |  |     def get_applied_transforms( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 154 |  |  |         self, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 155 |  |  |         ignore_intensity: bool = False, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 156 |  |  |         image_interpolation: str | None = None, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 157 |  |  |     ) -> list[Transform]: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 158 |  |  |         from ..transforms.intensity_transform import IntensityTransform | 
            
                                                                                                            
                            
            
                                    
            
            
                | 159 |  |  |         from ..transforms.transform import Transform | 
            
                                                                                                            
                            
            
                                    
            
            
                | 160 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 161 |  |  |         name_to_transform = {cls.__name__: cls for cls in get_subclasses(Transform)} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 162 |  |  |         transforms_list = [] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 163 |  |  |         for transform_name, arguments in self.applied_transforms: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 164 |  |  |             transform = name_to_transform[transform_name](**arguments) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 165 |  |  |             if ignore_intensity and isinstance(transform, IntensityTransform): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 166 |  |  |                 continue | 
            
                                                                                                            
                            
            
                                    
            
            
                | 167 |  |  |             resamples = hasattr(transform, 'image_interpolation') | 
            
                                                                                                            
                            
            
                                    
            
            
                | 168 |  |  |             if resamples and image_interpolation is not None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 169 |  |  |                 parsed = transform.parse_interpolation(image_interpolation) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 170 |  |  |                 transform.image_interpolation = parsed | 
            
                                                                                                            
                            
            
                                    
            
            
                | 171 |  |  |             transforms_list.append(transform) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 172 |  |  |         return transforms_list | 
            
                                                                                                            
                            
            
                                    
            
            
                | 173 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 174 |  |  |     def get_composed_history( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 175 |  |  |         self, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 176 |  |  |         ignore_intensity: bool = False, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 177 |  |  |         image_interpolation: str | None = None, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 178 |  |  |     ) -> Compose: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 179 |  |  |         from ..transforms.augmentation.composition import Compose | 
            
                                                                                                            
                            
            
                                    
            
            
                | 180 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 181 |  |  |         transforms = self.get_applied_transforms( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 182 |  |  |             ignore_intensity=ignore_intensity, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 183 |  |  |             image_interpolation=image_interpolation, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 184 |  |  |         ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 185 |  |  |         return Compose(transforms) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 186 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 187 |  |  |     def get_inverse_transform( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 188 |  |  |         self, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 189 |  |  |         warn: bool = True, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 190 |  |  |         ignore_intensity: bool = False, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 191 |  |  |         image_interpolation: str | None = None, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 192 |  |  |     ) -> Compose: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 193 |  |  |         """Get a reversed list of the inverses of the applied transforms. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 194 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 195 |  |  |         Args: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 196 |  |  |             warn: Issue a warning if some transforms are not invertible. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 197 |  |  |             ignore_intensity: If ``True``, all instances of | 
            
                                                                                                            
                            
            
                                    
            
            
                | 198 |  |  |                 :class:`~torchio.transforms.intensity_transform.IntensityTransform` | 
            
                                                                                                            
                            
            
                                    
            
            
                | 199 |  |  |                 will be ignored. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 200 |  |  |             image_interpolation: Modify interpolation for scalar images inside | 
            
                                                                                                            
                            
            
                                    
            
            
                | 201 |  |  |                 transforms that perform resampling. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 202 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 203 |  |  |         history_transform = self.get_composed_history( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 204 |  |  |             ignore_intensity=ignore_intensity, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 205 |  |  |             image_interpolation=image_interpolation, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 206 |  |  |         ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 207 |  |  |         inverse_transform = history_transform.inverse(warn=warn) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 208 |  |  |         return inverse_transform | 
            
                                                                                                            
                            
            
                                    
            
            
                | 209 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 210 |  |  |     def apply_inverse_transform(self, **kwargs) -> Subject: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 211 |  |  |         """Apply the inverse of all applied transforms, in reverse order. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 212 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 213 |  |  |         Args: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 214 |  |  |             **kwargs: Keyword arguments passed on to | 
            
                                                                                                            
                            
            
                                    
            
            
                | 215 |  |  |                 :meth:`~torchio.data.subject.Subject.get_inverse_transform`. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 216 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 217 |  |  |         inverse_transform = self.get_inverse_transform(**kwargs) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 218 |  |  |         transformed: Subject | 
            
                                                                                                            
                            
            
                                    
            
            
                | 219 |  |  |         transformed = inverse_transform(self)  # type: ignore[assignment] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 220 |  |  |         transformed.clear_history() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 221 |  |  |         return transformed | 
            
                                                                                                            
                            
            
                                    
            
            
                | 222 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def clear_history(self) -> None:
                    """Reset ``applied_transforms`` to a brand-new empty list.

                    The attribute is rebound rather than emptied in place, so a list
                    object obtained from it before this call keeps its contents.
                    """
                    self.applied_transforms = list()
            
                                                                                                            
                            
            
                                    
            
            
                | 225 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 226 |  |  |     def check_consistent_attribute( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 227 |  |  |         self, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 228 |  |  |         attribute: str, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 229 |  |  |         relative_tolerance: float = 1e-6, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 230 |  |  |         absolute_tolerance: float = 1e-6, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 231 |  |  |         message: str | None = None, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 232 |  |  |     ) -> None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 233 |  |  |         r"""Check for consistency of an attribute across all images. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 234 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 235 |  |  |         Args: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 236 |  |  |             attribute: Name of the image attribute to check | 
            
                                                                                                            
                            
            
                                    
            
            
                | 237 |  |  |             relative_tolerance: Relative tolerance for :func:`numpy.allclose()` | 
            
                                                                                                            
                            
            
                                    
            
            
                | 238 |  |  |             absolute_tolerance: Absolute tolerance for :func:`numpy.allclose()` | 
            
                                                                                                            
                            
            
                                    
            
            
                | 239 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 240 |  |  |         Example: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 241 |  |  |             >>> import numpy as np | 
            
                                                                                                            
                            
            
                                    
            
            
                | 242 |  |  |             >>> import torch | 
            
                                                                                                            
                            
            
                                    
            
            
                | 243 |  |  |             >>> import torchio as tio | 
            
                                                                                                            
                            
            
                                    
            
            
                | 244 |  |  |             >>> scalars = torch.randn(1, 512, 512, 100) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 245 |  |  |             >>> mask = torch.tensor(scalars > 0).type(torch.int16) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 246 |  |  |             >>> af1 = np.eye([0.8, 0.8, 2.50000000000001, 1]) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 247 |  |  |             >>> af2 = np.eye([0.8, 0.8, 2.49999999999999, 1])  # small difference here (e.g. due to different reader) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 248 |  |  |             >>> subject = tio.Subject( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 249 |  |  |             ...   image = tio.ScalarImage(tensor=scalars, affine=af1), | 
            
                                                                                                            
                            
            
                                    
            
            
                | 250 |  |  |             ...   mask = tio.LabelMap(tensor=mask, affine=af2) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 251 |  |  |             ... ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 252 |  |  |             >>> subject.check_consistent_attribute('spacing')  # no error as tolerances are > 0 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 253 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 254 |  |  |         .. note:: To check that all values for a specific attribute are close | 
            
                                                                                                            
                            
            
                                    
            
            
                | 255 |  |  |             between all images in the subject, :func:`numpy.allclose()` is used. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 256 |  |  |             This function returns ``True`` if | 
            
                                                                                                            
                            
            
                                    
            
            
                | 257 |  |  |             :math:`|a_i - b_i| \leq t_{abs} + t_{rel} * |b_i|`, where | 
            
                                                                                                            
                            
            
                                    
            
            
                | 258 |  |  |             :math:`a_i` and :math:`b_i` are the :math:`i`-th element of the same | 
            
                                                                                                            
                            
            
                                    
            
            
                | 259 |  |  |             attribute of two images being compared, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 260 |  |  |             :math:`t_{abs}` is the ``absolute_tolerance`` and | 
            
                                                                                                            
                            
            
                                    
            
            
                | 261 |  |  |             :math:`t_{rel}` is the ``relative_tolerance``. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 262 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 263 |  |  |         message = ( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 264 |  |  |             f'More than one value for "{attribute}" found in subject images:\n{{}}' | 
            
                                                                                                            
                            
            
                                    
            
            
                | 265 |  |  |         ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 266 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 267 |  |  |         names_images = self.get_images_dict(intensity_only=False).items() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 268 |  |  |         try: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 269 |  |  |             first_attribute = None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 270 |  |  |             first_image = None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 271 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 272 |  |  |             for image_name, image in names_images: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 273 |  |  |                 if first_attribute is None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 274 |  |  |                     first_attribute = getattr(image, attribute) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 275 |  |  |                     first_image = image_name | 
            
                                                                                                            
                            
            
                                    
            
            
                | 276 |  |  |                     continue | 
            
                                                                                                            
                            
            
                                    
            
            
                | 277 |  |  |                 current_attribute = getattr(image, attribute) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 278 |  |  |                 all_close = np.allclose( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 279 |  |  |                     current_attribute, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 280 |  |  |                     first_attribute, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 281 |  |  |                     rtol=relative_tolerance, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 282 |  |  |                     atol=absolute_tolerance, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 283 |  |  |                 ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 284 |  |  |                 if not all_close: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 285 |  |  |                     message = message.format( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 286 |  |  |                         pprint.pformat( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 287 |  |  |                             { | 
            
                                                                                                            
                            
            
                                    
            
            
                | 288 |  |  |                                 first_image: first_attribute, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 289 |  |  |                                 image_name: current_attribute, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 290 |  |  |                             } | 
            
                                                                                                            
                            
            
                                    
            
            
                | 291 |  |  |                         ), | 
            
                                                                                                            
                            
            
                                    
            
            
                | 292 |  |  |                     ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 293 |  |  |                     raise RuntimeError(message) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 294 |  |  |         except TypeError: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 295 |  |  |             # fallback for non-numeric values | 
            
                                                                                                            
                            
            
                                    
            
            
                | 296 |  |  |             values_dict = {} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 297 |  |  |             for image_name, image in names_images: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 298 |  |  |                 values_dict[image_name] = getattr(image, attribute) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 299 |  |  |             num_unique_values = len(set(values_dict.values())) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 300 |  |  |             if num_unique_values > 1: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 301 |  |  |                 message = message.format(pprint.pformat(values_dict)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 302 |  |  |                 raise RuntimeError(message) from None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 303 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def check_consistent_spatial_shape(self) -> None:
                    """Check that every image reports the same ``spatial_shape``.

                    Delegates to ``check_consistent_attribute`` with default tolerances.
                    """
                    self.check_consistent_attribute(attribute='spatial_shape')
            
                                                                                                            
                            
            
                                    
            
            
                | 306 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def check_consistent_orientation(self) -> None:
                    """Check that every image reports the same ``orientation``.

                    Delegates to ``check_consistent_attribute`` with default tolerances.
                    """
                    self.check_consistent_attribute(attribute='orientation')
            
                                                                                                            
                            
            
                                    
            
            
                | 309 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def check_consistent_affine(self) -> None:
                    """Check that every image reports the same ``affine``.

                    Delegates to ``check_consistent_attribute`` with default tolerances.
                    """
                    self.check_consistent_attribute(attribute='affine')
            
                                                                                                            
                            
            
                                    
            
            
                | 312 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def check_consistent_space(self) -> None:
                    """Check that all images in the subject lie in the same space.

                    Spacing, direction and origin are compared (in that order) with
                    ``check_consistent_attribute``, followed by the spatial shape with
                    ``check_consistent_spatial_shape``.

                    Raises:
                        RuntimeError: If any of those checks raises ``RuntimeError``; the
                            original error is chained as ``__cause__``.
                    """
                    try:
                        for attribute_name in ('spacing', 'direction', 'origin'):
                            self.check_consistent_attribute(attribute_name)
                        self.check_consistent_spatial_shape()
                    except RuntimeError as error:
                        hint = (
                            'As described above, some images in the subject are not in the'
                            ' same space. You probably can use the transforms ToCanonical'
                            ' and Resample to fix this, as explained at'
                            ' https://github.com/TorchIO-project/torchio/issues/647#issuecomment-913025695'
                        )
                        raise RuntimeError(hint) from error
            
                                                                                                            
                            
            
                                    
            
            
                | 327 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def get_images_names(self) -> list[str]:
                    """Return the names of every image in the subject.

                    Images of all types are listed, in the subject's iteration order.
                    """
                    all_images = self.get_images_dict(intensity_only=False)
                    return list(all_images)
            
                                                                                                            
                            
            
                                    
            
            
                | 330 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def get_images_dict(
                    self,
                    intensity_only=True,
                    include: Sequence[str] | None = None,
                    exclude: Sequence[str] | None = None,
                ) -> dict[str, Image]:
                    """Return a mapping from image name to image for selected entries.

                    Entries whose value is not an :class:`Image` are always skipped.

                    Args:
                        intensity_only: If ``True``, keep only images whose ``TYPE``
                            entry equals ``INTENSITY``.
                        include: If not ``None``, keep only names contained in it.
                        exclude: If not ``None``, drop names contained in it.

                    Returns:
                        A new dictionary in the subject's iteration order.
                    """

                    def is_selected(name, value) -> bool:
                        # Checks run in a fixed order: type first, so ``value[TYPE]``
                        # is only read for real Image instances.
                        if not isinstance(value, Image):
                            return False
                        if intensity_only and value[TYPE] != INTENSITY:
                            return False
                        if include is not None and name not in include:
                            return False
                        if exclude is not None and name in exclude:
                            return False
                        return True

                    return {
                        name: value
                        for name, value in self.items()
                        if is_selected(name, value)
                    }
            
                                                                                                            
                            
            
                                    
            
            
                | 349 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def get_images(
                    self,
                    intensity_only=True,
                    include: Sequence[str] | None = None,
                    exclude: Sequence[str] | None = None,
                ) -> list[Image]:
                    """Return the selected images as a list.

                    The filtering arguments have the same meaning as in
                    :meth:`get_images_dict`; only the values are returned, in order.
                    """
                    selected = self.get_images_dict(
                        intensity_only=intensity_only,
                        include=include,
                        exclude=exclude,
                    )
                    return [*selected.values()]
            
                                                                                                            
                            
            
                                    
            
            
                | 362 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def get_image(self, image_name: str) -> Image:
                    """Return the image stored under ``image_name``.

                    Raises:
                        KeyError: If the subject holds no :class:`Image` under that name.
                    """
                    every_image = self.get_images_dict(intensity_only=False)
                    return every_image[image_name]
            
                                                                                                            
                            
            
                                    
            
            
                | 366 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def get_first_image(self) -> Image:
                    """Return the first image in the subject, whatever its type.

                    Raises:
                        IndexError: If the subject contains no images.
                    """
                    all_images = self.get_images(intensity_only=False)
                    return all_images[0]
            
                                                                                                            
                            
            
                                    
            
            
                | 369 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def add_transform(
                    self,
                    transform: Transform,
                    parameters_dict: dict,
                ) -> None:
                    """Record that ``transform`` was applied with ``parameters_dict``.

                    The pair ``(transform.name, parameters_dict)`` is appended to
                    ``self.applied_transforms``; the dictionary is stored as-is.
                    """
                    record = (transform.name, parameters_dict)
                    self.applied_transforms.append(record)
            
                                                                                                            
                            
            
                                    
            
            
                | 376 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def load(self) -> None:
                    """Load every image of the subject, of any type, into memory."""
                    images_to_load = self.get_images(intensity_only=False)
                    for current in images_to_load:
                        current.load()
            
                                                                                                            
                            
            
                                    
            
            
                | 381 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def unload(self) -> None:
                    """Unload every image of the subject, of any type."""
                    images_to_unload = self.get_images(intensity_only=False)
                    for current in images_to_unload:
                        current.unload()
            
                                                                                                            
                            
            
                                    
            
            
                | 386 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def update_attributes(self) -> None:
                    """Mirror every entry of the subject as an instance attribute.

                    After this call, an entry can be read with attribute notation,
                    e.g. ``subject.t1`` as well as ``subject['t1']``. Instance
                    attributes with the same names are overwritten.
                    """
                    # vars(self) is the instance __dict__; updating it bypasses
                    # __setattr__ and descriptors, exactly like __dict__.update.
                    instance_namespace = vars(self)
                    instance_namespace.update(self)
            
                                                                                                            
                            
            
                                    
            
            
                | 390 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                @staticmethod
                def _check_image_name(image_name):
                    """Ensure ``image_name`` is a string and return it unchanged.

                    Raises:
                        ValueError: If ``image_name`` is not an instance of ``str``.
                    """
                    if isinstance(image_name, str):
                        return image_name
                    message = (
                        f'The image name must be a string, but it has type "{type(image_name)}"'
                    )
                    raise ValueError(message)
            
                                                                                                            
                            
            
                                    
            
            
                | 399 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def add_image(self, image: Image, image_name: str) -> None:
                    """Store ``image`` in the subject under the key ``image_name``.

                    Attributes are refreshed afterwards, so the image can also be
                    read as ``subject.<image_name>``.

                    Raises:
                        ValueError: If ``image`` is not an :class:`Image` or
                            ``image_name`` is not a string.
                    """
                    if isinstance(image, Image):
                        self._check_image_name(image_name)
                        self[image_name] = image
                        self.update_attributes()
                        return
                    message = (
                        'Image must be an instance of torchio.Image,'
                        f' but its type is "{type(image)}"'
                    )
                    raise ValueError(message)
            
                                                                                                            
                            
            
                                    
            
            
                | 411 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def remove_image(self, image_name: str) -> None:
                    """Remove an image from the subject instance.

                    The entry is deleted from the mapping. The matching instance
                    attribute created by :meth:`update_attributes`, if present, is
                    dropped as well.

                    Args:
                        image_name: Key of the image to remove.

                    Raises:
                        ValueError: If ``image_name`` is not a string.
                        KeyError: If the subject has no entry named ``image_name``.
                    """
                    self._check_image_name(image_name)
                    del self[image_name]
                    # The attribute mirror only exists if update_attributes() ran
                    # after the entry was added. delattr() would raise AttributeError
                    # here, after the mapping entry is already gone, leaving the
                    # subject half-updated. delattr() could also reach class-level
                    # attributes or properties with the same name. Popping from the
                    # instance namespace with a default avoids both problems.
                    self.__dict__.pop(image_name, None)
            
                                                                                                            
                            
            
                                    
            
            
                | 417 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def plot(self, **kwargs) -> None:
                    """Plot the subject's images using matplotlib.

                    Args:
                        **kwargs: Keyword arguments forwarded to ``plot_subject``,
                            and documented as being passed on to
                            :meth:`~torchio.Image.plot`.
                    """
                    # Imported here rather than at module level to avoid a
                    # circular import.
                    from ..visualization import plot_subject

                    plot_subject(self, **kwargs)
            
                                                        
            
                                    
            
            
                | 428 |  |  |  |