Passed
Push — master ( 3c93a7...8f7525 )
by Fernando
01:09
created

RandomSwap.apply_transform()   A

Complexity

Conditions 2

Size

Total Lines 7
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 7
nop 2
dl 0
loc 7
rs 10
c 0
b 0
f 0
1
from typing import Optional, Tuple, Union, List
2
import torch
3
import numpy as np
4
from ....data.subject import Subject
5
from ....utils import to_tuple
6
from ....torchio import DATA, TypeTuple, TypeData, TypeTripletInt
7
from ... import IntensityTransform
8
from .. import RandomTransform
9
10
11
class RandomSwap(RandomTransform, IntensityTransform):
    r"""Randomly swap patches within an image.

    This is typically used in `context restoration for self-supervised learning
    <https://www.sciencedirect.com/science/article/pii/S1361841518304699>`_.

    Args:
        patch_size: Tuple of integers :math:`(w, h, d)` to swap patches
            of size :math:`w \times h \times d`.
            If a single number :math:`n` is provided, :math:`w = h = d = n`.
        num_iterations: Number of times that two patches will be swapped.
            Must be a non-negative ``int``.
        p: Probability that this transform will be applied.
        seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
        keys: See :py:class:`~torchio.transforms.Transform`.
    """
    def __init__(
            self,
            patch_size: TypeTuple = 15,
            num_iterations: int = 100,
            p: float = 1,
            seed: Optional[int] = None,
            keys: Optional[List[str]] = None,
            ):
        super().__init__(p=p, seed=seed, keys=keys)
        # NOTE(review): presumably ``to_tuple`` turns a scalar into ``(n,)``;
        # NumPy then broadcasts that single size against 3D indices.
        self.patch_size = np.array(to_tuple(patch_size))
        self.num_iterations = self.parse_num_iterations(num_iterations)

    @staticmethod
    def parse_num_iterations(num_iterations):
        """Validate ``num_iterations`` and return it unchanged.

        Raises:
            TypeError: If ``num_iterations`` is not an ``int``.
            ValueError: If ``num_iterations`` is negative.
        """
        if not isinstance(num_iterations, int):
            raise TypeError('num_iterations must be an int,'
                            f' not {num_iterations}')
        if num_iterations < 0:
            raise ValueError('num_iterations must be non-negative,'
                             f' not {num_iterations}')
        return num_iterations

    @staticmethod
    def get_params(
            tensor: torch.Tensor,
            patch_size: np.ndarray,
            num_iterations: int,
            ) -> List[Tuple[np.ndarray, np.ndarray]]:
        """Sample the pairs of patch locations that will be swapped.

        Args:
            tensor: Image tensor; only its last three dimensions (the spatial
                shape) are used.
            patch_size: Patch size, forwarded to
                :func:`get_random_indices_from_shape`.
            num_iterations: Number of pairs to sample.

        Returns:
            A list of ``num_iterations`` tuples ``(first_ini, second_ini)``
            with the initial indices of the two patches of each swap.

        Raises:
            ValueError: If the patch does not fit in the image, or if no
                location different from the first one could be sampled
                (e.g. the image only admits a single patch location).
        """
        spatial_shape = tensor.shape[-3:]
        # Bounds the resampling below; without it, an image admitting a
        # single patch location would make the loop spin forever.
        max_attempts = 1000
        locations = []
        for _ in range(num_iterations):
            first_ini = get_random_indices_from_shape(
                spatial_shape,
                patch_size,
            )[0]
            # Both patches have the same size, so the second patch can only
            # be contained in the first one when both locations are equal.
            # That is the only rejected case: partially overlapping patches
            # are accepted.
            for _attempt in range(max_attempts):
                second_ini = get_random_indices_from_shape(
                    spatial_shape,
                    patch_size,
                )[0]
                if not np.array_equal(first_ini, second_ini):
                    break
            else:
                message = (
                    'Could not sample two different locations for patches'
                    f' of size {patch_size} in an image of spatial shape'
                    f' {tuple(spatial_shape)}'
                )
                raise ValueError(message)
            locations.append((first_ini, second_ini))
        return locations

    def apply_transform(self, sample: Subject) -> Subject:
        """Swap random patches in each image selected by ``get_images``.

        Patch locations are sampled independently for every image. Each
        image's data is replaced by the swapped tensor and ``sample`` is
        returned.
        """
        for image in self.get_images(sample):
            tensor = image[DATA]
            locations = self.get_params(
                tensor, self.patch_size, self.num_iterations)
            image[DATA] = swap(tensor, self.patch_size, locations)
        return sample
82
83
84
def swap(
        tensor: torch.Tensor,
        patch_size: TypeTuple,
        locations: List[Tuple[np.ndarray, np.ndarray]],
        ) -> torch.Tensor:
    """Return a copy of ``tensor`` in which each pair of patches is swapped.

    Args:
        tensor: Image tensor. It is indexed as ``[:, i, j, k]`` by
            :func:`crop` and :func:`insert`, so it is expected to be 4D.
        patch_size: Size of every patch; added to each initial index to get
            the final (exclusive) index.
        locations: Sequence of ``(first_ini, second_ini)`` pairs holding the
            initial indices of the two patches to swap. Pairs are applied in
            order, so later swaps see the result of earlier ones.

    Returns:
        A new tensor; the input ``tensor`` is not modified.
    """
    tensor = tensor.clone()
    patch_size = np.array(patch_size)
    for first_ini, second_ini in locations:
        first_fin = first_ini + patch_size
        second_fin = second_ini + patch_size
        # crop() returns slicing views into ``tensor``. Both patches are
        # cloned before writing so that, when the two regions overlap,
        # writing one patch cannot alter the data of the other.
        first_patch = crop(tensor, first_ini, first_fin).clone()
        second_patch = crop(tensor, second_ini, second_fin).clone()
        insert(tensor, first_patch, second_ini)
        insert(tensor, second_patch, first_ini)
    return tensor
99
100
101
def insert(tensor: TypeData, patch: TypeData, index_ini: np.ndarray) -> None:
    """Write ``patch`` into ``tensor`` in place, starting at ``index_ini``.

    The region written is ``[:, i:i + pi, j:j + pj, k:k + pk]``, where
    ``(i, j, k)`` is ``index_ini`` and ``(pi, pj, pk)`` are the last three
    dimensions of ``patch``. Nothing is returned.
    """
    patch_spatial_shape = np.array(patch.shape[-3:])
    index_fin = index_ini + patch_spatial_shape
    (i_start, j_start, k_start), (i_stop, j_stop, k_stop) = (
        index_ini,
        index_fin,
    )
    region = (
        slice(None),
        slice(i_start, i_stop),
        slice(j_start, j_stop),
        slice(k_start, k_stop),
    )
    tensor[region] = patch
106
107
108
def crop(
        image: Union[np.ndarray, torch.Tensor],
        index_ini: np.ndarray,
        index_fin: np.ndarray,
        ) -> Union[np.ndarray, torch.Tensor]:
    """Slice the spatial box ``[index_ini, index_fin)`` out of ``image``.

    The first dimension is kept whole and the next three are sliced, i.e.
    ``image[:, i0:i1, j0:j1, k0:k1]``. The result is whatever basic slicing
    returns for the array type (no explicit copy is made here).
    """
    (i_start, j_start, k_start), (i_stop, j_stop, k_stop) = (
        index_ini,
        index_fin,
    )
    slice_i = slice(i_start, i_stop)
    slice_j = slice(j_start, j_stop)
    slice_k = slice(k_start, k_stop)
    return image[:, slice_i, slice_j, slice_k]
116
117
118
def get_random_indices_from_shape(
        spatial_shape: TypeTripletInt,
        patch_size: TypeTripletInt,
        ) -> Tuple[np.ndarray, np.ndarray]:
    """Sample a random patch location that fits inside ``spatial_shape``.

    Args:
        spatial_shape: Spatial shape of the image.
        patch_size: Patch size per dimension. A single-element sequence is
            broadcast by NumPy to every dimension.

    Returns:
        A tuple ``(index_ini, index_fin)`` of ``int64`` arrays. Each entry of
        ``index_ini`` is drawn uniformly from ``0`` to
        ``spatial_shape - patch_size`` (both inclusive), and
        ``index_fin = index_ini + patch_size``, so slicing
        ``index_ini:index_fin`` always stays inside the image.

    Raises:
        ValueError: If the patch is larger than the image along any
            dimension.
    """
    shape_array = np.array(spatial_shape)
    patch_size_array = np.array(patch_size)
    max_index_ini = shape_array - patch_size_array
    if (max_index_ini < 0).any():
        message = (
            f'Patch size {patch_size} cannot be'
            f' larger than image spatial shape {spatial_shape}'
        )
        raise ValueError(message)
    # torch.randint excludes ``high``; adding one makes the last valid start
    # position (where the patch touches the image border) reachable.
    coordinates = [
        int(torch.randint(max_coordinate + 1, size=(1,)).item())
        for max_coordinate in max_index_ini.tolist()
    ]
    # int64 rather than uint16 so that large spatial sizes cannot overflow.
    index_ini = np.array(coordinates, dtype=np.int64)
    index_fin = index_ini + patch_size_array
    return index_ini, index_fin
142