Passed
Pull Request — master (#222)
by Fernando
03:12
created

torchio.transforms.augmentation.intensity.random_labels_to_image (rating: A)

Complexity

Total Complexity 42

Size/Duplication

Total Lines 263
Duplicated Lines 0 %

Importance

Changes 0
Metric  Value
eloc    149
dl      0
loc     263
rs      9.0399
c       0
b       0
f       0
wmc     42

9 Methods

Rating  Name                                                     Duplication  Size  Complexity
A       RandomLabelsToImage.parse_default_gaussian_parameters()  0            12    3
A       RandomLabelsToImage.get_params()                         0            12    2
A       RandomLabelsToImage.__init__()                           0            17    1
A       RandomLabelsToImage.parse_gaussian_parameter()           0            18    4
B       RandomLabelsToImage.parse_gaussian_parameters()          0            23    7
C       RandomLabelsToImage.apply_transform()                    0            49    9
A       RandomLabelsToImage.parse_pv_label_maps()                0            14    4
A       RandomLabelsToImage.generate_tissue()                    0            10    1
C       RandomLabelsToImage.parse_keys()                         0            19    11
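The Total Complexity and the wmc metric (presumably Weighted Methods per Class, the sum of a class's method complexities) both equal 42, which matches the per-method values in the table. A quick check, with the numbers copied from the table:

```python
# Complexity of each RandomLabelsToImage method, copied from the table above.
method_complexity = {
    "parse_default_gaussian_parameters": 3,
    "get_params": 2,
    "__init__": 1,
    "parse_gaussian_parameter": 4,
    "parse_gaussian_parameters": 7,
    "apply_transform": 9,
    "parse_pv_label_maps": 4,
    "generate_tissue": 1,
    "parse_keys": 11,
}

# Weighted Methods per Class: sum of the per-method complexities.
wmc = sum(method_complexity.values())
print(wmc)  # 42

# The method contributing most to the total (rated C above).
print(max(method_complexity, key=method_complexity.get))  # parse_keys
```

parse_keys and apply_transform together account for 20 of the 42 points, so they are the natural first targets when reducing complexity.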

How to fix: Complexity

Complex classes like RandomLabelsToImage (defined in torchio.transforms.augmentation.intensity.random_labels_to_image) often do many different things. To break such a class down, identify a cohesive component within it. A common approach is to look for fields or methods that share the same prefix or suffix, such as the parse_* methods of this class.

Once you have determined which fields belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster.
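As a sketch of Extract Class applied here, assuming the mean/std handling is the cohesive component: the hypothetical GaussianParameters class below gathers the logic of parse_default_gaussian_parameters, parse_gaussian_parameters, parse_gaussian_parameter and get_params in one place. It uses only the standard library (random.uniform instead of torch) so that it runs on its own. The class name, its methods (parse_range, sample) and the omission of the pv_label_keys check are illustrative choices, not torchio's API.

```python
import random
from typing import Dict, Hashable, Optional, Sequence, Tuple, Union

Range = Tuple[float, float]
RangeLike = Union[float, Sequence[float]]

PARAMETER_NAMES = ('mean', 'std')


class GaussianParameters:
    """Hypothetical component extracted from RandomLabelsToImage.

    Owns the parsing and sampling of per-label (mean, std) ranges.
    """

    MEAN_RANGE: Range = (0.1, 0.9)
    STD_RANGE: Range = (0.01, 0.1)

    def __init__(
            self,
            per_label: Optional[Dict[Hashable, Dict[str, RangeLike]]] = None,
            default: Optional[Dict[str, RangeLike]] = None,
            ):
        if default is None:
            default = {'mean': self.MEAN_RANGE, 'std': self.STD_RANGE}
        # Compare as sets so that key order does not matter
        if set(default) != set(PARAMETER_NAMES):
            raise KeyError(f'Default parameters must have keys {PARAMETER_NAMES}, got {list(default)}')
        self.default = {name: self.parse_range(default[name], name) for name in PARAMETER_NAMES}
        self.per_label = {}
        for label, params in (per_label or {}).items():
            # Entries missing for a label fall back to the default range
            self.per_label[label] = {
                name: self.parse_range(params[name], name) if name in params else self.default[name]
                for name in PARAMETER_NAMES
            }

    @staticmethod
    def parse_range(value: RangeLike, name: str) -> Range:
        # A single number means a fixed value, i.e. the range (v, v)
        if isinstance(value, (int, float)):
            return float(value), float(value)
        if len(value) != 2:
            raise ValueError(f'If {name} is a sequence, it must be of len 2, not {value}')
        low, high = value
        if low > high:
            raise ValueError(f'If {name} is a sequence, the second value must be'
                             f' equal or greater than the first, not {value}')
        return float(low), float(high)

    def sample(self, label: Hashable) -> Tuple[float, float]:
        ranges = self.per_label.get(label, self.default)
        # Draw mean and std uniformly from their ranges, as get_params does
        mean = random.uniform(*ranges['mean'])
        std = random.uniform(*ranges['std'])
        return mean, std
```

RandomLabelsToImage would then hold a single GaussianParameters instance and call sample(label) inside apply_transform, which would move four of its nine methods out of the class.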

import warnings
from typing import Union, Tuple, Optional, Dict, Sequence

import torch
import numpy as np

from ....torchio import DATA, TypeData, TypeRangeFloat, TypeNumber, AFFINE, INTENSITY
from ....data.subject import Subject
from ....data.image import ScalarImage
from .. import RandomTransform


class RandomLabelsToImage(RandomTransform):
    r"""Generate an image from a segmentation.

    Based on the work by `Billot et al., A Learning Strategy for
    Contrast-agnostic MRI Segmentation <https://arxiv.org/abs/2003.01995>`_.

    Args:
        label_key: String designating the label map in the sample
            that will be used to generate the new image.
            Cannot be set at the same time as :py:attr:`pv_label_keys`.
        pv_label_keys: Sequence of strings designating the partial-volume (PV)
            label maps in the sample that will be used to generate the new
            image. Cannot be set at the same time as :py:attr:`label_key`.
        image_key: String designating the key to which the new volume will be
            saved. If this key corresponds to an already existing volume,
            voxels that have a value of 0 in the label maps will be filled with
            the corresponding values in the original volume.
        gaussian_parameters: Dictionary containing the mean and standard
            deviation for each label. For each value :math:`v`, if a tuple
            :math:`(a, b)` is provided then :math:`v \sim \mathcal{U}(a, b)`.
            If no value is given for a label, the value from
            :py:attr:`default_gaussian_parameters` will be used.
        default_gaussian_parameters: Dictionary containing the default
            mean and standard deviation used for all labels that are not
            defined in :py:attr:`gaussian_parameters`.
            Default values are ``(0.1, 0.9)`` for the mean and
            ``(0.01, 0.1)`` for the standard deviation.
        binarize: If ``True``, PV label maps will be binarized.
            Has no effect if PV label maps are not used.
            Each voxel is assigned to the PV label map with the highest
            value at that voxel.
        p: Probability that this transform will be applied.
        seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.

    .. note:: It is recommended to blur the new images to make the result more
        realistic. See
        :py:class:`~torchio.transforms.augmentation.intensity.random_blur.RandomBlur`.

    Example:
        >>> import torchio
        >>> from torchio import RandomLabelsToImage, DATA, RescaleIntensity, Compose
        >>> from torchio.datasets import Colin27
        >>> colin = Colin27(2008)
        >>> # Using the default gaussian_parameters
        >>> transform = RandomLabelsToImage(label_key='cls')
        >>> # Using custom gaussian_parameters
        >>> label_values = colin['cls'][DATA].unique().round().long()
        >>> # Keys must be Python ints (not tensors) to match the labels found in the map
        >>> gaussian_parameters = {
        ...     label: {
        ...         'mean': i / len(label_values),
        ...         'std': 0.01
        ...     }
        ...     for i, label in enumerate(label_values.tolist())
        ... }
        >>> transform = RandomLabelsToImage(label_key='cls', gaussian_parameters=gaussian_parameters)
        >>> transformed = transform(colin)  # transformed has a new key 'image' with the simulated image
        >>> # Filling the background (label 0) of the simulated image with the original T1 image
        >>> rescale_transform = RescaleIntensity((0, 1), (1, 99))   # Rescale intensity before filling the background
        >>> simulation_transform = RandomLabelsToImage(
        ...     label_key='cls',
        ...     image_key='t1',
        ...     gaussian_parameters={0: {'mean': 0, 'std': 0}}
        ... )
        >>> transform = Compose([rescale_transform, simulation_transform])
        >>> transformed = transform(colin)  # in transformed, the key 't1' has been replaced with the simulated image
    """

    MEAN_RANGE = (0.1, 0.9)
    STD_RANGE = (0.01, 0.1)

    def __init__(
            self,
            label_key: Optional[str] = None,
            pv_label_keys: Optional[Sequence[str]] = None,
            image_key: str = 'image',
            gaussian_parameters: Optional[Dict[Union[str, TypeNumber], Dict[str, TypeRangeFloat]]] = None,
            default_gaussian_parameters: Optional[Dict[str, TypeRangeFloat]] = None,
            binarize: bool = False,
            p: float = 1,
            seed: Optional[int] = None,
            ):
        super().__init__(p=p, seed=seed)
        self.label_key, self.pv_label_keys = self.parse_keys(label_key, pv_label_keys)
        self.default_gaussian_parameters = self.parse_default_gaussian_parameters(default_gaussian_parameters)
        self.gaussian_parameters = self.parse_gaussian_parameters(gaussian_parameters)
        self.image_key = image_key
        self.binarize = binarize

    @staticmethod
    def parse_keys(label_key, pv_label_keys):
        if label_key is not None and pv_label_keys is not None:
            raise ValueError('"label_key" and "pv_label_keys" can\'t be set at the same time.')
        if label_key is None and pv_label_keys is None:
            raise ValueError('One of "label_key" and "pv_label_keys" must be set.')
        if label_key is not None and not isinstance(label_key, str):
            raise TypeError(f'"label_key" must be a string, not {label_key}')
        if pv_label_keys is not None:
            # A single string is iterable too, so reject it explicitly
            if isinstance(pv_label_keys, str):
                raise TypeError(f'"pv_label_keys" must be a sequence of strings, not {pv_label_keys}')
            try:
                iter(pv_label_keys)
            except TypeError:
                raise TypeError(f'"pv_label_keys" must be a sequence of strings, not {pv_label_keys}')
            for key in pv_label_keys:
                if not isinstance(key, str):
                    raise TypeError(f'Every key of "pv_label_keys" must be a string, found {key}')
            pv_label_keys = list(pv_label_keys)

        return label_key, pv_label_keys

    def apply_transform(self, sample: Subject) -> Subject:
        random_parameters_images_dict = {}
        original_image = sample.get(self.image_key)

        if self.pv_label_keys is not None:
            label_map, affine = self.parse_pv_label_maps(self.pv_label_keys, sample)
            n_labels, *image_shape = label_map.shape
            labels = self.pv_label_keys
            values = list(range(n_labels))

            if self.binarize:
                # Take the label with the highest value in each voxel
                max_label, label_map = label_map.max(dim=0)
                # Mark voxels where all labels are 0 as unlabelled (-1)
                label_map[max_label == 0] = -1

        else:
            label_map = sample[self.label_key][DATA][0]
            # Scrutinizer issue on this line (Comprehensibility Best Practice):
            # "The variable DATA does not seem to be defined." DATA is imported
            # from ....torchio at the top of the module.
            affine = sample[self.label_key][AFFINE]
            image_shape = label_map.shape
            values = label_map.unique()
            labels = [int(key) for key in values.round()]

        tissues = torch.zeros(image_shape)

        for i, label in enumerate(labels):
            mean, std = self.get_params(label)
            if self.pv_label_keys is not None and not self.binarize:
                mask = label_map[i]
            else:
                mask = label_map == values[i]
            tissues += self.generate_tissue(mask, mean, std)

            random_parameters_images_dict[label] = {
                'mean': mean,
                'std': std
            }

        final_image = ScalarImage(affine=affine, tensor=tissues)

        if original_image is not None:
            if self.pv_label_keys is not None and not self.binarize:
                label_map = label_map.sum(dim=0)
            bg_mask = label_map.unsqueeze(0) <= 0
            final_image[DATA][bg_mask] = original_image[DATA][bg_mask]

        sample.add_image(final_image, self.image_key)
        sample.add_transform(self, random_parameters_images_dict)
        return sample

    def parse_default_gaussian_parameters(self, default_gaussian_parameters):
        if default_gaussian_parameters is None:
            return {'mean': self.MEAN_RANGE, 'std': self.STD_RANGE}

        # Compare as sets so that the order of the keys does not matter
        if set(default_gaussian_parameters.keys()) != {'mean', 'std'}:
            raise KeyError(f'Default gaussian parameters must have exactly the keys "mean" and "std", '
                           f'found {list(default_gaussian_parameters.keys())}')

        mean = self.parse_gaussian_parameter(default_gaussian_parameters['mean'], 'mean')
        std = self.parse_gaussian_parameter(default_gaussian_parameters['std'], 'std')

        return {'mean': mean, 'std': std}

    def parse_gaussian_parameters(self, gaussian_parameters):
        if gaussian_parameters is None:
            gaussian_parameters = {}

        if self.pv_label_keys is not None:
            if not set(self.pv_label_keys).issuperset(gaussian_parameters.keys()):
                raise KeyError(f'Found keys in gaussian parameters {gaussian_parameters.keys()} '
                               f'not in pv_label_keys {self.pv_label_keys}')

        parsed_gaussian_parameters = {}

        for label_key, dictionary in gaussian_parameters.items():
            if 'mean' in dictionary:
                mean = self.parse_gaussian_parameter(dictionary['mean'], 'mean')
            else:
                mean = self.default_gaussian_parameters['mean']
            if 'std' in dictionary:
                std = self.parse_gaussian_parameter(dictionary['std'], 'std')
            else:
                std = self.default_gaussian_parameters['std']
            parsed_gaussian_parameters.update({label_key: {'mean': mean, 'std': std}})

        return parsed_gaussian_parameters

    @staticmethod
    def parse_gaussian_parameter(
            nums_range: TypeRangeFloat,
            name: str,
            ) -> Tuple[float, float]:
        if isinstance(nums_range, (int, float)):
            return nums_range, nums_range

        if len(nums_range) != 2:
            raise ValueError(
                f'If {name} is a sequence,'
                f' it must be of len 2, not {nums_range}')
        min_value, max_value = nums_range
        if min_value > max_value:
            raise ValueError(
                f'If {name} is a sequence, the second value must be'
                f' equal or greater than the first, not {nums_range}')
        return min_value, max_value

    @staticmethod
    def parse_pv_label_maps(
            pv_label_keys: Sequence[str],
            sample: dict,
            ) -> Tuple[TypeData, TypeData]:
        try:
            label_map = torch.cat([sample[key][DATA] for key in pv_label_keys], dim=0)
        except RuntimeError:
            raise RuntimeError('PV label maps have different shapes; make sure they all have the same shape.')
        affine = sample[pv_label_keys[0]][AFFINE]
        for key in pv_label_keys[1:]:
            if not np.array_equal(affine, sample[key][AFFINE]):
                # Warn instead of raising: differing affines are suspicious but not fatal
                warnings.warn('Be careful, PV label maps with different affines were found.', RuntimeWarning)
        return label_map, affine

    def get_params(
            self,
            label: Union[str, TypeNumber]
            ) -> Tuple[float, float]:
        if label in self.gaussian_parameters:
            mean_range, std_range = self.gaussian_parameters[label]['mean'], self.gaussian_parameters[label]['std']
        else:
            mean_range, std_range = self.default_gaussian_parameters['mean'], self.default_gaussian_parameters['std']

        mean = torch.FloatTensor(1).uniform_(*mean_range).item()
        std = torch.FloatTensor(1).uniform_(*std_range).item()
        return mean, std

    @staticmethod
    def generate_tissue(
            data: TypeData,
            mean: TypeNumber,
            std: TypeNumber,
            ) -> TypeData:
        # Create the simulated tissue using a gaussian random variable
        data_shape = data.shape
        gaussian = torch.randn(data_shape) * std + mean
        return gaussian * data
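Stripped of the torchio plumbing (Subject, ScalarImage, keys), the technique in apply_transform and generate_tissue reduces to a few array operations. The sketch below re-implements it with NumPy under simplifying assumptions: the helper names binarize_pv_maps and labels_to_image are hypothetical, labels are hard (the soft, non-binarized PV path is not covered), and one mean/std range is shared by all labels instead of per-label gaussian_parameters. It illustrates the idea; it is not torchio's API.

```python
import numpy as np


def binarize_pv_maps(pv_maps: np.ndarray) -> np.ndarray:
    """Hypothetical helper: hard label map from (n_labels, *spatial) PV maps.

    Each voxel gets the index of the PV map with the highest value; voxels
    where every map is 0 get -1, mirroring the binarize branch in the listing.
    """
    label_map = pv_maps.argmax(axis=0)
    label_map[pv_maps.max(axis=0) == 0] = -1
    return label_map


def labels_to_image(label_map, mean_range=(0.1, 0.9), std_range=(0.01, 0.1),
                    original=None, rng=None):
    """Hypothetical helper: simulate an image from an integer label map.

    For every label >= 0, a mean and std are drawn uniformly from the given
    ranges and that label's voxels are sampled from N(mean, std). Voxels with
    label <= 0 are then copied from `original`, if one is given.
    """
    rng = np.random.default_rng() if rng is None else rng
    image = np.zeros(label_map.shape, dtype=np.float64)
    params = {}
    for label in np.unique(label_map):
        if label < 0:  # -1 marks voxels with no label (see binarize_pv_maps)
            continue
        mean = rng.uniform(*mean_range)
        std = rng.uniform(*std_range)
        mask = label_map == label
        image[mask] = rng.normal(mean, std, size=int(mask.sum()))
        params[int(label)] = (mean, std)
    if original is not None:
        background = label_map <= 0
        image[background] = original[background]
    return image, params


rng = np.random.default_rng(0)
pv_maps = rng.random((3, 8, 8, 8))           # three fake PV maps
label_map = binarize_pv_maps(pv_maps)        # label indices 0, 1, 2
image, params = labels_to_image(label_map, rng=rng)
print(image.shape)  # (8, 8, 8)
```

Sampling only inside each mask gives the same distribution as generate_tissue in the listing, which draws a full-volume Gaussian and multiplies it by the mask; for hard masks the two differ only in how many random numbers are drawn.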