Passed
Pull Request — main (#1308)
by
unknown
01:47
created

Crop.crop_image()   A

Complexity

Conditions 2

Size

Total Lines 22
Code Lines 17

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 17
nop 4
dl 0
loc 22
rs 9.55
c 0
b 0
f 0
1
import copy
2
import nibabel as nib
3
import numpy as np
4
5
from ....data.subject import Subject
6
from ....data.image import Image
7
from .bounds_transform import BoundsTransform
8
from .bounds_transform import TypeBounds
9
10
11
class Crop(BoundsTransform):
    r"""Crop an image.

    Args:
        cropping: Tuple
            :math:`(w_{ini}, w_{fin}, h_{ini}, h_{fin}, d_{ini}, d_{fin})`
            defining the number of values cropped from the edges of each axis.
            If the initial shape of the image is
            :math:`W \times H \times D`, the final shape will be
            :math:`(- w_{ini} + W - w_{fin}) \times (- h_{ini} + H - h_{fin})
            \times (- d_{ini} + D - d_{fin})`.
            If only three values :math:`(w, h, d)` are provided, then
            :math:`w_{ini} = w_{fin} = w`,
            :math:`h_{ini} = h_{fin} = h` and
            :math:`d_{ini} = d_{fin} = d`.
            If only one value :math:`n` is provided, then
            :math:`w_{ini} = w_{fin} = h_{ini} = h_{fin}
            = d_{ini} = d_{fin} = n`.
        copy: bool, optional
            This transform overrides the ``copy`` argument of the base
            transform and copies only the cropped patch, instead of the whole
            image.
            If ``True``, the cropped images are copied into a new subject and
            the input subject is left untouched.
            If ``False``, the images are cropped in place. Default: ``True``.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    .. seealso:: If you want to pass the output shape instead, please use
        :class:`~torchio.transforms.CropOrPad` instead.
    """

    def __init__(self, cropping: TypeBounds, copy=True, **kwargs):
        # Inside this method ``copy`` is the boolean argument and shadows the
        # ``copy`` module; the public keyword name is kept for compatibility.
        self.copy_patch = copy
        # The base transform is told not to copy the subject; copying (of the
        # cropped patch only) is handled in apply_transform instead.
        super().__init__(cropping, copy=False, **kwargs)
        self.cropping = cropping
        self.args_names = ['cropping']

    def apply_transform(self, sample: Subject) -> Subject:
        """Crop the selected images of ``sample``.

        Returns:
            A new subject holding the cropped images when ``copy_patch`` is
            ``True``; otherwise ``sample`` itself, with its selected images
            cropped in place.
        """
        assert self.bounds_parameters is not None
        # bounds_parameters interleaves (ini, fin) per axis.
        low = self.bounds_parameters[::2]
        high = self.bounds_parameters[1::2]
        index_ini = low
        index_fin = np.array(sample.spatial_shape) - high

        if not self.copy_patch:
            for image in self.get_images(sample):
                self.crop_image(image, index_ini, index_fin)
            return sample

        # Resolve the image keys to crop once instead of once per item.
        keys_to_crop = set(
            sample.get_images_dict(
                intensity_only=False,
                include=self.include,
                exclude=self.exclude,
            )
        )
        # Selected images are rebuilt from the cropped patch only; every other
        # entry (including images filtered out by include/exclude) is
        # deep-copied so the new subject shares nothing with ``sample``.
        attributes = {
            key: (
                self.crop_image(value, index_ini, index_fin)
                if key in keys_to_crop
                else copy.deepcopy(value)
            )
            for key, value in sample.items()
        }
        # NOTE(review): this assumes type(sample) can be rebuilt from its
        # items passed as keyword arguments; a Subject subclass with a custom
        # __init__ signature may fail here — TODO confirm.
        cropped_sample = type(sample)(**attributes)

        # Keep the transform history of the input subject.
        cropped_sample.applied_transforms = copy.deepcopy(
            sample.applied_transforms,
        )
        cropped_sample.update_attributes()
        return cropped_sample

    def crop_image(
        self,
        image: Image,
        index_ini: tuple,
        index_fin: tuple,
    ) -> Image:
        """Crop ``image`` to voxels ``[index_ini, index_fin)`` on each axis.

        The translation column of the affine is set to the world coordinate
        of voxel ``index_ini``, so cropped voxels keep their physical position.

        Returns:
            A new image of the same class when ``copy_patch`` is ``True``;
            otherwise ``image`` itself, modified in place.
        """
        new_origin = nib.affines.apply_affine(image.affine, index_ini)
        new_affine = image.affine.copy()
        new_affine[:3, 3] = new_origin
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        # clone() copies the patch so the result does not keep a view on the
        # full-size data (presumably a torch tensor).
        cropped_data = image.data[:, i0:i1, j0:j1, k0:k1].clone()

        if self.copy_patch:
            # NOTE(review): only tensor, affine, type and path are carried
            # over to the new image; any other per-image metadata is not.
            return type(image)(
                tensor=cropped_data,
                affine=new_affine,
                type=image.type,
                path=image.path,
            )

        image.set_data(cropped_data)
        image.affine = new_affine
        return image

    def inverse(self):
        """Return a ``Pad`` transform built from the same bounds as this crop."""
        from .pad import Pad

        return Pad(self.cropping)
104