Pull Request: master (#222) by Fernando, created 01:01 (status: Passed)

torchio.transforms.augmentation.intensity.random_labels_to_image (rating: A)

Complexity
Total Complexity: 42

Size/Duplication
Total Lines: 256
Duplicated Lines: 0%

Importance
Changes: 0

Metric  Value
eloc    149
dl      0
loc     256
rs      9.0399
c       0
b       0
f       0
wmc     42

9 Methods

Rating  Name                                                     Duplication  Size  Complexity
A       RandomLabelsToImage.parse_default_gaussian_parameters()  0            12    3
A       RandomLabelsToImage.get_params()                         0            12    2
A       RandomLabelsToImage.__init__()                           0            17    1
A       RandomLabelsToImage.parse_gaussian_parameter()           0            18    4
B       RandomLabelsToImage.parse_gaussian_parameters()          0            23    7
C       RandomLabelsToImage.apply_transform()                    0            49    9
A       RandomLabelsToImage.parse_pv_label_maps()                0            14    4
A       RandomLabelsToImage.generate_tissue()                    0            10    1
C       RandomLabelsToImage.parse_keys()                         0            19    11
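The Total Complexity of 42 (also reported as wmc) equals the sum of the per-method complexities in the table above. The short script below confirms the arithmetic and ranks the methods so that the two C-rated hotspots stand out. The dictionary simply copies the table; reading wmc as this sum is an assumption about how Scrutinizer computes it.

```python
# Per-method cyclomatic complexity, copied from the methods table
method_complexity = {
    'parse_default_gaussian_parameters': 3,
    'get_params': 2,
    '__init__': 1,
    'parse_gaussian_parameter': 4,
    'parse_gaussian_parameters': 7,
    'apply_transform': 9,
    'parse_pv_label_maps': 4,
    'generate_tissue': 1,
    'parse_keys': 11,
}

# Weighted methods per class: sum of the per-method complexities
wmc = sum(method_complexity.values())
print(wmc)  # 42

# The two most complex methods are the first candidates for refactoring
hotspots = sorted(method_complexity, key=method_complexity.get, reverse=True)[:2]
print(hotspots)  # ['parse_keys', 'apply_transform']
```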

How to fix: Complexity

Complex classes like RandomLabelsToImage (in torchio.transforms.augmentation.intensity.random_labels_to_image) often do many different things. To break such a class down, identify a cohesive component within it. A common approach is to look for fields or methods that share the same prefix or suffix; here, the five parse_* methods form such a group.

Once you have determined which fields belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster.
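A minimal sketch of Extract Class applied to this module, assuming the parse_* methods are the cohesive component. GaussianParametersParser and its method names (parse_range, parse_defaults, parse) are hypothetical and not part of TorchIO; the validation logic mirrors the original parse_* methods, except that the default keys are compared as a set so their order does not matter. RandomLabelsToImage would then hold a parser instance instead of implementing the parsing itself.

```python
from typing import Dict, Optional, Sequence, Tuple, Union

Range = Tuple[float, float]
Number = Union[int, float]

# Same defaults as MEAN_RANGE and STD_RANGE in the module
MEAN_RANGE: Range = (0.1, 0.9)
STD_RANGE: Range = (0.01, 0.1)


class GaussianParametersParser:
    """Hypothetical component extracted from RandomLabelsToImage.

    It only validates user-supplied Gaussian parameters and never
    touches image data.
    """

    def __init__(self, pv_label_keys: Optional[Sequence[str]] = None):
        self.pv_label_keys = pv_label_keys

    @staticmethod
    def parse_range(value: Union[Number, Sequence[Number]], name: str) -> Range:
        # A single number n becomes the degenerate range (n, n)
        if isinstance(value, (int, float)):
            return value, value
        if len(value) != 2:
            raise ValueError(f'If {name} is a sequence, it must be of len 2, not {value}')
        low, high = value
        if low > high:
            raise ValueError(
                f'If {name} is a sequence, the second value must be'
                f' equal or greater than the first, not {value}')
        return low, high

    def parse_defaults(self, defaults: Optional[Dict]) -> Dict[str, Range]:
        if defaults is None:
            return {'mean': MEAN_RANGE, 'std': STD_RANGE}
        # Compare as sets so that key order does not matter
        if set(defaults) != {'mean', 'std'}:
            raise KeyError(f'Expected keys "mean" and "std", not {list(defaults)}')
        return {name: self.parse_range(defaults[name], name) for name in ('mean', 'std')}

    def parse(self, parameters: Optional[Dict], defaults: Dict[str, Range]) -> Dict:
        parameters = parameters or {}
        if self.pv_label_keys is not None and not set(self.pv_label_keys).issuperset(parameters):
            raise KeyError(f'Keys {list(parameters)} not in pv_label_keys {self.pv_label_keys}')
        parsed = {}
        for label, dictionary in parameters.items():
            # Fall back to the default range for a missing 'mean' or 'std'
            parsed[label] = {
                name: self.parse_range(dictionary[name], name) if name in dictionary else defaults[name]
                for name in ('mean', 'std')
            }
        return parsed


parser = GaussianParametersParser()
defaults = parser.parse_defaults(None)
params = parser.parse({1: {'mean': 0.5}}, defaults)
print(params)  # {1: {'mean': (0.5, 0.5), 'std': (0.01, 0.1)}}
```

Keeping the parser free of torch and of image data also makes it easy to unit-test on its own.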

from typing import Union, Tuple, Optional, Dict, Sequence
import warnings

import torch
import numpy as np

from ....torchio import DATA, TypeData, TypeRangeFloat, TypeNumber, AFFINE, INTENSITY
from ....data.subject import Subject
from ....data.image import Image
from .. import RandomTransform


MEAN_RANGE = (0.1, 0.9)
STD_RANGE = (0.01, 0.1)


class RandomLabelsToImage(RandomTransform):
    r"""Generate a random image from a label map or a set of
    partial volume (PV) label maps.

    Args:
        label_key: String designating the label map in the sample
            that will be used to generate the new image.
            Cannot be set at the same time as pv_label_keys.
        pv_label_keys: Sequence of strings designating the PV label maps in
            the sample that will be used to generate the new image.
            Cannot be set at the same time as label_key.
        image_key: String designating the key under which the new volume
            will be saved. If this key corresponds to an already existing
            volume, background voxels (where the label map is zero, or where
            all PV label maps are zero) will be filled with the corresponding
            voxels of the original volume.
        gaussian_parameters: Dictionary mapping each label to a dictionary
            with the keys ``'mean'`` and ``'std'``. Labels are the integer
            values of the label map if label_key is used, or the keys in
            pv_label_keys otherwise. For each value :math:`v`, if a tuple
            :math:`(a, b)` is provided then
            :math:`v \sim \mathcal{U}(a, b)`; if a single number is
            provided, it is used as is.
            If no value is given for a label, the value from
            default_gaussian_parameters will be used.
        default_gaussian_parameters: Dictionary containing the default
            mean and standard deviation used for all labels that are not
            defined in gaussian_parameters.
            Default values are :math:`(0.1, 0.9)` for the mean and
            :math:`(0.01, 0.1)` for the standard deviation.
        binarize: If ``True``, the PV label maps are binarized: each voxel
            is assigned to the PV label map with the highest value, and
            voxels where all PV label maps are zero are treated as
            background. Has no effect unless pv_label_keys is used.
        p: Probability that this transform will be applied.
        seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.

    .. note:: Images generated from label maps are unrealistic, so it is
        recommended to blur them. See
        :py:class:`~torchio.transforms.augmentation.RandomBlur`.

    Example:
        >>> import torchio
        >>> from torchio import RandomLabelsToImage, DATA, RescaleIntensity, Compose
        >>> from torchio.datasets import Colin27
        >>> colin = Colin27(2008)
        >>> # Using the default gaussian_parameters
        >>> transform = RandomLabelsToImage(label_key='cls')
        >>> # Using custom gaussian_parameters
        >>> label_values = colin['cls'][DATA].unique().round()
        >>> gaussian_parameters = {
        ...     int(label): {
        ...         'mean': i / len(label_values), 'std': 0.01
        ...     } for i, label in enumerate(label_values)
        ... }
        >>> transform = RandomLabelsToImage(label_key='cls', gaussian_parameters=gaussian_parameters)
        >>> # Filling the background of the simulated image with the original T1 image
        >>> rescale_transform = RescaleIntensity((0, 1), (1, 99))   # Rescale intensity before filling the background
        >>> simulation_transform = RandomLabelsToImage(
        ...     label_key='cls',
        ...     image_key='t1',
        ...     gaussian_parameters={0: {'mean': 0, 'std': 0}}
        ... )
        >>> transform = Compose([rescale_transform, simulation_transform])
        >>> transformed = transform(colin)  # 't1' now holds the simulated image, with background voxels from the original T1
    """
    def __init__(
            self,
            label_key: Optional[str] = None,
            pv_label_keys: Optional[Sequence[str]] = None,
            image_key: str = 'image',
            gaussian_parameters: Optional[Dict[Union[str, TypeNumber], Dict[str, TypeRangeFloat]]] = None,
            default_gaussian_parameters: Optional[Dict[str, TypeRangeFloat]] = None,
            binarize: bool = False,
            p: float = 1,
            seed: Optional[int] = None,
            ):
        super().__init__(p=p, seed=seed)
        self.label_key, self.pv_label_keys = self.parse_keys(label_key, pv_label_keys)
        self.default_gaussian_parameters = self.parse_default_gaussian_parameters(default_gaussian_parameters)
        self.gaussian_parameters = self.parse_gaussian_parameters(gaussian_parameters)
        self.image_key = image_key
        self.binarize = binarize

    @staticmethod
    def parse_keys(label_key, pv_label_keys):
        if label_key is not None and pv_label_keys is not None:
            raise ValueError('"label_key" and "pv_label_keys" can\'t be set at the same time.')
        if label_key is None and pv_label_keys is None:
            raise ValueError('One of "label_key" and "pv_label_keys" must be set.')
        if label_key is not None and not isinstance(label_key, str):
            raise TypeError(f'"label_key" must be a string, not {label_key}')
        if pv_label_keys is not None:
            try:
                iter(pv_label_keys)
            except TypeError:
                raise TypeError(f'"pv_label_keys" must be a sequence of strings, not {pv_label_keys}')
            for key in pv_label_keys:
                if not isinstance(key, str):
                    raise TypeError(f'Every key of "pv_label_keys" must be a string, found {key}')
            pv_label_keys = list(pv_label_keys)

        return label_key, pv_label_keys

    def apply_transform(self, sample: Subject) -> Subject:
        random_parameters_images_dict = {}
        original_image = sample.get(self.image_key)

        if self.pv_label_keys is not None:
            label_map, affine = self.parse_pv_label_maps(self.pv_label_keys, sample)
            n_labels, *image_shape = label_map.shape
            labels = self.pv_label_keys
            values = list(range(n_labels))

            if self.binarize:
                # Assign each voxel to the index of the PV label map with the highest value
                max_values, label_map = label_map.max(dim=0)
                # Mark voxels where all PV label maps are 0 as background (-1)
                label_map[max_values == 0] = -1

        else:
            label_map = sample[self.label_key][DATA][0]
            affine = sample[self.label_key][AFFINE]
            image_shape = label_map.shape
            values = label_map.unique()
            labels = [int(key) for key in values.round()]

        tissues = torch.zeros(image_shape)

        for i, label in enumerate(labels):
            mean, std = self.get_params(label)
            if self.pv_label_keys is not None and not self.binarize:
                # Non-binarized PV label maps are used as soft masks
                mask = label_map[i]
            else:
                mask = label_map == values[i]
            tissues += self.generate_tissue(mask, mean, std)

            random_parameters_images_dict[label] = {
                'mean': mean,
                'std': std
            }

        final_image = Image(type=INTENSITY, affine=affine, tensor=tissues)

        if original_image is not None:
            if self.pv_label_keys is None:
                # Label 0 is the background
                background = label_map <= 0
            elif self.binarize:
                # Only -1 is background: index 0 is the first PV label, not background
                background = label_map == -1
            else:
                background = label_map.sum(dim=0) <= 0
            background_indices = background.unsqueeze(0)
            final_image[DATA][background_indices] = original_image[DATA][background_indices]

        sample[self.image_key] = final_image
        sample.add_transform(self, random_parameters_images_dict)
        return sample

    def parse_default_gaussian_parameters(self, default_gaussian_parameters):
        if default_gaussian_parameters is None:
            return {'mean': MEAN_RANGE, 'std': STD_RANGE}

        # Compare as sets so that the order of the keys does not matter
        if set(default_gaussian_parameters.keys()) != {'mean', 'std'}:
            raise KeyError(f'Default gaussian parameters {list(default_gaussian_parameters.keys())} do not '
                           f'match {["mean", "std"]}')

        mean = self.parse_gaussian_parameter(default_gaussian_parameters['mean'], 'mean')
        std = self.parse_gaussian_parameter(default_gaussian_parameters['std'], 'std')

        return {'mean': mean, 'std': std}

    def parse_gaussian_parameters(self, gaussian_parameters):
        if gaussian_parameters is None:
            gaussian_parameters = {}

        if self.pv_label_keys is not None:
            if not set(self.pv_label_keys).issuperset(gaussian_parameters.keys()):
                raise KeyError(f'Found keys in gaussian parameters {gaussian_parameters.keys()} '
                               f'not in pv_label_keys {self.pv_label_keys}')

        parsed_gaussian_parameters = {}

        for label_key, dictionary in gaussian_parameters.items():
            if 'mean' in dictionary:
                mean = self.parse_gaussian_parameter(dictionary['mean'], 'mean')
            else:
                mean = self.default_gaussian_parameters['mean']
            if 'std' in dictionary:
                std = self.parse_gaussian_parameter(dictionary['std'], 'std')
            else:
                std = self.default_gaussian_parameters['std']
            parsed_gaussian_parameters[label_key] = {'mean': mean, 'std': std}

        return parsed_gaussian_parameters

    @staticmethod
    def parse_gaussian_parameter(
            nums_range: TypeRangeFloat,
            name: str,
            ) -> Tuple[float, float]:
        if isinstance(nums_range, (int, float)):
            return nums_range, nums_range

        if len(nums_range) != 2:
            raise ValueError(
                f'If {name} is a sequence,'
                f' it must be of len 2, not {nums_range}')
        min_value, max_value = nums_range
        if min_value > max_value:
            raise ValueError(
                f'If {name} is a sequence, the second value must be'
                f' equal or greater than the first, not {nums_range}')
        return min_value, max_value

    @staticmethod
    def parse_pv_label_maps(
            pv_label_keys: Sequence[str],
            sample: Subject,
            ) -> Tuple[TypeData, TypeData]:
        try:
            label_map = torch.cat([sample[key][DATA] for key in pv_label_keys], dim=0)
        except RuntimeError as e:
            raise RuntimeError(
                'PV label maps have different shapes; make sure they all have the same shape.'
            ) from e
        affine = sample[pv_label_keys[0]][AFFINE]
        for key in pv_label_keys[1:]:
            if not np.array_equal(affine, sample[key][AFFINE]):
                # Warn instead of raising: the affine of the first PV label map is used
                warnings.warn(
                    'PV label maps with different affines were found;'
                    ' the affine of the first one will be used.',
                    RuntimeWarning,
                )
        return label_map, affine

    def get_params(
            self,
            label: Union[str, TypeNumber]
            ) -> Tuple[float, float]:
        if label in self.gaussian_parameters:
            mean_range, std_range = self.gaussian_parameters[label]['mean'], self.gaussian_parameters[label]['std']
        else:
            mean_range, std_range = self.default_gaussian_parameters['mean'], self.default_gaussian_parameters['std']

        mean = torch.FloatTensor(1).uniform_(*mean_range).item()
        std = torch.FloatTensor(1).uniform_(*std_range).item()
        return mean, std

    @staticmethod
    def generate_tissue(
            data: TypeData,
            mean: TypeNumber,
            std: TypeNumber,
            ) -> TypeData:
        # Sample voxel-wise Gaussian noise with the given mean and std,
        # then weight it by the (binary or PV) mask of the tissue
        data_shape = data.shape
        gaussian = torch.randn(data_shape) * std + mean
        return gaussian * data
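To make the generation step concrete, here is a minimal NumPy re-implementation of the label-map path of apply_transform, plus the PV binarization rule. It is a sketch under stated assumptions: labels_to_image, binarize_pv and the params layout are hypothetical and not TorchIO API, the arrays have no channel dimension, and NumPy's random Generator stands in for torch. The binarization example shows why binarized PV maps must mark background with -1 rather than test `<= 0`: the first voxel belongs to PV map index 0 and is real tissue.

```python
import numpy as np


def labels_to_image(label_map, params, rng, original=None):
    """Hypothetical NumPy sketch of RandomLabelsToImage for a single label map.

    params maps each integer label to ((mean_low, mean_high), (std_low, std_high)).
    """
    image = np.zeros(label_map.shape)
    sampled = {}
    for label in np.unique(label_map):
        (mean_low, mean_high), (std_low, std_high) = params[int(label)]
        # One mean and one std per label, drawn uniformly (as in get_params())
        mean = rng.uniform(mean_low, mean_high)
        std = rng.uniform(std_low, std_high)
        mask = label_map == label
        # Voxel-wise Gaussian noise kept only inside this label (as in generate_tissue())
        image += (rng.standard_normal(label_map.shape) * std + mean) * mask
        sampled[int(label)] = {'mean': mean, 'std': std}
    if original is not None:
        # Background voxels (label <= 0) are copied from the original image
        background = label_map <= 0
        image[background] = original[background]
    return image, sampled


def binarize_pv(pv_maps):
    """Assign each voxel to the index of the PV map with the highest value.

    Voxels where every PV map is 0 get -1, so index 0 remains a real label.
    """
    labels = pv_maps.argmax(axis=0)
    labels[pv_maps.max(axis=0) == 0] = -1
    return labels


rng = np.random.default_rng(0)
label_map = np.array([[0, 1], [2, 2]])
# std = 0 makes the output deterministic for this demo
params = {0: ((0, 0), (0, 0)), 1: ((0.5, 0.5), (0, 0)), 2: ((0.9, 0.9), (0, 0))}
original = np.full((2, 2), 7.0)
image, sampled = labels_to_image(label_map, params, rng, original)
print(image.tolist())  # [[7.0, 0.5], [0.9, 0.9]]

# Two PV maps over three voxels
pv_maps = np.array([[0.7, 0.2, 0.0], [0.3, 0.8, 0.0]])
print(binarize_pv(pv_maps).tolist())  # [0, 1, -1]
```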