| 1 |  |  | import warnings | 
            
                                                                                                            
                            
            
                                    
            
            
                | 2 |  |  | from numbers import Number | 
            
                                                                                                            
                            
            
                                    
            
            
                | 3 |  |  | from typing import Union | 
            
                                                                                                            
                            
            
                                    
            
            
                | 4 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 5 |  |  | import numpy as np | 
            
                                                                                                            
                            
            
                                    
            
            
                | 6 |  |  | import torch | 
            
                                                                                                            
                            
            
                                    
            
            
                | 7 |  |  | from nibabel.affines import apply_affine | 
            
                                                                                                            
                            
            
                                    
            
            
                | 8 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 9 |  |  | from ....data.image import LabelMap | 
            
                                                                                                            
                            
            
                                    
            
            
                | 10 |  |  | from ....data.subject import Subject | 
            
                                                                                                            
                            
            
                                    
            
            
                | 11 |  |  | from .bounds_transform import BoundsTransform | 
            
                                                                                                            
                            
            
                                    
            
            
                | 12 |  |  | from .bounds_transform import TypeBounds | 
            
                                                                                                            
                            
            
                                    
            
            
                | 13 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 14 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 15 |  |  | class Pad(BoundsTransform): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 16 |  |  |     r"""Pad an image. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 17 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 18 |  |  |     Args: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 19 |  |  |         padding: Tuple | 
            
                                                                                                            
                            
            
                                    
            
            
                | 20 |  |  |             :math:`(w_{ini}, w_{fin}, h_{ini}, h_{fin}, d_{ini}, d_{fin})` | 
            
                                                                                                            
                            
            
                                    
            
            
                | 21 |  |  |             defining the number of values padded to the edges of each axis. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 22 |  |  |             If the initial shape of the image is | 
            
                                                                                                            
                            
            
                                    
            
            
                | 23 |  |  |             :math:`W \times H \times D`, the final shape will be | 
            
                                                                                                            
                            
            
                                    
            
            
                | 24 |  |  |             :math:`(w_{ini} + W + w_{fin}) \times (h_{ini} + H + h_{fin}) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 25 |  |  |             \times (d_{ini} + D + d_{fin})`. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 26 |  |  |             If only three values :math:`(w, h, d)` are provided, then | 
            
                                                                                                            
                            
            
                                    
            
            
                | 27 |  |  |             :math:`w_{ini} = w_{fin} = w`, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 28 |  |  |             :math:`h_{ini} = h_{fin} = h` and | 
            
                                                                                                            
                            
            
                                    
            
            
                | 29 |  |  |             :math:`d_{ini} = d_{fin} = d`. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 30 |  |  |             If only one value :math:`n` is provided, then | 
            
                                                                                                            
                            
            
                                    
            
            
                | 31 |  |  |             :math:`w_{ini} = w_{fin} = h_{ini} = h_{fin} = | 
            
                                                                                                            
                            
            
                                    
            
            
                | 32 |  |  |             d_{ini} = d_{fin} = n`. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 33 |  |  |         padding_mode: See possible modes in `NumPy docs`_. If it is a number, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 34 |  |  |             the mode will be set to ``'constant'``. If it is ``'mean'``, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 35 |  |  |             ``'maximum'``, ``'median'`` or ``'minimum'``, the statistic will be | 
            
                                                                                                            
                            
            
                                    
            
            
                | 36 |  |  |             computed from the whole volume, unlike in NumPy, which computes it | 
            
                                                                                                            
                            
            
                                    
            
            
                | 37 |  |  |             along the padded axis. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 38 |  |  |         **kwargs: See :class:`~torchio.transforms.Transform` for additional | 
            
                                                                                                            
                            
            
                                    
            
            
                | 39 |  |  |             keyword arguments. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 40 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 41 |  |  |     .. seealso:: If you want to pass the output shape instead, please use | 
            
                                                                                                            
                            
            
                                    
            
            
                | 42 |  |  |         :class:`~torchio.transforms.CropOrPad` instead. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 43 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 44 |  |  |     .. _NumPy docs: https://numpy.org/doc/stable/reference/generated/numpy.pad.html | 
            
                                                                                                            
                            
            
                                    
            
            
                | 45 |  |  |     """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 46 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 47 |  |  |     PADDING_MODES = ( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 48 |  |  |         'empty', | 
            
                                                                                                            
                            
            
                                    
            
            
                | 49 |  |  |         'edge', | 
            
                                                                                                            
                            
            
                                    
            
            
                | 50 |  |  |         'wrap', | 
            
                                                                                                            
                            
            
                                    
            
            
                | 51 |  |  |         'constant', | 
            
                                                                                                            
                            
            
                                    
            
            
                | 52 |  |  |         'linear_ramp', | 
            
                                                                                                            
                            
            
                                    
            
            
                | 53 |  |  |         'maximum', | 
            
                                                                                                            
                            
            
                                    
            
            
                | 54 |  |  |         'mean', | 
            
                                                                                                            
                            
            
                                    
            
            
                | 55 |  |  |         'median', | 
            
                                                                                                            
                            
            
                                    
            
            
                | 56 |  |  |         'minimum', | 
            
                                                                                                            
                            
            
                                    
            
            
                | 57 |  |  |         'reflect', | 
            
                                                                                                            
                            
            
                                    
            
            
                | 58 |  |  |         'symmetric', | 
            
                                                                                                            
                            
            
                                    
            
            
                | 59 |  |  |     ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 60 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 61 |  |  |     def __init__( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 62 |  |  |         self, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 63 |  |  |         padding: TypeBounds, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 64 |  |  |         padding_mode: Union[str, float] = 0, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 65 |  |  |         **kwargs, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 66 |  |  |     ): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 67 |  |  |         super().__init__(padding, **kwargs) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 68 |  |  |         self.padding = padding | 
            
                                                                                                            
                            
            
                                    
            
            
                | 69 |  |  |         self.check_padding_mode(padding_mode) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 70 |  |  |         self.padding_mode = padding_mode | 
            
                                                                                                            
                            
            
                                    
            
            
                | 71 |  |  |         self.args_names = ['padding', 'padding_mode'] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 72 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 73 |  |  |     @classmethod | 
            
                                                                                                            
                            
            
                                    
            
            
                | 74 |  |  |     def check_padding_mode(cls, padding_mode): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 75 |  |  |         is_number = isinstance(padding_mode, Number) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 76 |  |  |         is_callable = callable(padding_mode) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 77 |  |  |         if not (padding_mode in cls.PADDING_MODES or is_number or is_callable): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 78 |  |  |             message = ( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 79 |  |  |                 f'Padding mode "{padding_mode}" not valid. Valid options are' | 
            
                                                                                                            
                            
            
                                    
            
            
                | 80 |  |  |                 f' {list(cls.PADDING_MODES)}, a number or a function' | 
            
                                                                                                            
                            
            
                                    
            
            
                | 81 |  |  |             ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 82 |  |  |             raise KeyError(message) | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 83 |  |  |  | 
            
                                                                        
                            
            
                                    
            
            
                | 84 |  |  |     def apply_transform(self, subject: Subject) -> Subject: | 
            
                                                                        
                            
            
                                    
            
            
                | 85 |  |  |         assert self.bounds_parameters is not None | 
            
                                                                        
                            
            
                                    
            
            
                | 86 |  |  |         low = self.bounds_parameters[::2] | 
            
                                                                        
                            
            
                                    
            
            
                | 87 |  |  |         for image in self.get_images(subject): | 
            
                                                                        
                            
            
                                    
            
            
                | 88 |  |  |             if isinstance(image, LabelMap) and self.padding_mode == 'mean': | 
            
                                                                        
                            
            
                                    
            
            
                | 89 |  |  |                 message = ( | 
            
                                                                        
                            
            
                                    
            
            
                | 90 |  |  |                     'Padding mode "mean" might create non-integer values in label maps' | 
            
                                                                        
                            
            
                                    
            
            
                | 91 |  |  |                 ) | 
            
                                                                        
                            
            
                                    
            
            
                | 92 |  |  |                 warnings.warn(message, RuntimeWarning, stacklevel=2) | 
            
                                                                        
                            
            
                                    
            
            
                | 93 |  |  |             new_origin = apply_affine(image.affine, -np.array(low)) | 
            
                                                                        
                            
            
                                    
            
            
                | 94 |  |  |             new_affine = image.affine.copy() | 
            
                                                                        
                            
            
                                    
            
            
                | 95 |  |  |             new_affine[:3, 3] = new_origin | 
            
                                                                        
                            
            
                                    
            
            
                | 96 |  |  |             mode: str | float = 'constant' | 
            
                                                                        
                            
            
                                    
            
            
                | 97 |  |  |             constant: torch.Tensor | float | None = None | 
            
                                                                        
                            
            
                                    
            
            
                | 98 |  |  |             if isinstance(self.padding_mode, Number): | 
            
                                                                        
                            
            
                                    
            
            
                | 99 |  |  |                 constant = self.padding_mode  # type: ignore[assignment] | 
            
                                                                        
                            
            
                                    
            
            
                | 100 |  |  |             elif self.padding_mode == 'maximum': | 
            
                                                                        
                            
            
                                    
            
            
                | 101 |  |  |                 constant = image.data.max() | 
            
                                                                        
                            
            
                                    
            
            
                | 102 |  |  |             elif self.padding_mode == 'mean': | 
            
                                                                        
                            
            
                                    
            
            
                | 103 |  |  |                 constant = image.data.float().mean() | 
            
                                                                        
                            
            
                                    
            
            
                | 104 |  |  |             elif self.padding_mode == 'median': | 
            
                                                                        
                            
            
                                    
            
            
                | 105 |  |  |                 constant = torch.quantile(image.data.float(), 0.5) | 
            
                                                                        
                            
            
                                    
            
            
                | 106 |  |  |             elif self.padding_mode == 'minimum': | 
            
                                                                        
                            
            
                                    
            
            
                | 107 |  |  |                 constant = image.data.min() | 
            
                                                                        
                            
            
                                    
            
            
                | 108 |  |  |             else: | 
            
                                                                        
                            
            
                                    
            
            
                | 109 |  |  |                 constant = None | 
            
                                                                        
                            
            
                                    
            
            
                | 110 |  |  |                 mode = self.padding_mode | 
            
                                                                        
                            
            
                                    
            
            
                | 111 |  |  |             pad_params = self.bounds_parameters | 
            
                                                                        
                            
            
                                    
            
            
                | 112 |  |  |             paddings = (0, 0), pad_params[:2], pad_params[2:4], pad_params[4:] | 
            
                                                                        
                            
            
                                    
            
            
                | 113 |  |  |             padded = np.pad(image.data, paddings, mode=mode, constant_values=constant)  # type: ignore[call-overload] | 
            
                                                                        
                            
            
                                    
            
            
                | 114 |  |  |             image.set_data(torch.as_tensor(padded)) | 
            
                                                                        
                            
            
                                    
            
            
                | 115 |  |  |             image.affine = new_affine | 
            
                                                                        
                            
            
                                    
            
            
                | 116 |  |  |         return subject | 
            
                                                                                                            
                            
            
                                    
            
            
                | 117 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 118 |  |  |     def inverse(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 119 |  |  |         from .crop import Crop | 
            
                                                                                                            
                            
            
                                    
            
            
                | 120 |  |  |  | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 121 |  |  |         return Crop(self.padding) | 
            
                                                        
            
                                    
            
            
                | 122 |  |  |  |