HistogramStandardization.train()   C

Complexity

Conditions 11

Size

Total Lines 101
Code Lines 43

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric  Value
eloc    43
dl      0
loc     101
rs      5.4
c       0
b       0
f       0
cc      11
nop     6

How to fix

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method.
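As a sketch of Extract Method applied to this method: the nested `if`/`else` in `train()` that picks each image's mask can move into its own function, so the loop body only asks for the voxels to use. The helper `select_training_voxels` below is hypothetical (it is not part of torchio), and it takes in-memory NumPy arrays instead of the file paths that `train()` reads with `read_image`.

```python
from typing import Callable, Optional, Sequence, Union

import numpy as np

# Hypothetical type: no mask, one mask shared by all images, or one per image
MaskSource = Union[np.ndarray, Sequence[np.ndarray], None]


def select_training_voxels(
    array: np.ndarray,
    index: int,
    masks: MaskSource = None,
    masking_function: Optional[Callable[[np.ndarray], np.ndarray]] = None,
) -> np.ndarray:
    """Return the intensities used to compute one image's percentiles.

    Hypothetical helper mirroring the mask-selection branch of train(),
    but working on arrays instead of file paths.
    """
    if masking_function is not None:
        # A masking function takes precedence over any stored mask
        mask = np.asarray(masking_function(array), dtype=bool)
    elif masks is None:
        # No mask: every voxel contributes
        mask = np.ones_like(array, dtype=bool)
    elif isinstance(masks, np.ndarray):
        # A single array is one mask shared by every image
        mask = masks > 0
    else:
        # A list or tuple holds one mask per image
        mask = np.asarray(masks[index]) > 0
    return array[mask]
```

Inside `train()`, the loop body would then reduce to roughly `np.percentile(select_training_voxels(array, i, masks, masking_function), percentiles)`, and the flattened `elif` arms remove one level of nesting from the condition count reported above.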

Complexity

Complex classes and methods like torchio.transforms.preprocessing.intensity.histogram_standardization.HistogramStandardization.train() often do a lot of different things. To break such a unit down, we need to identify a cohesive component within it. A common approach to find such a component is to look for fields, methods, or variables that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
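In `train()`, the locals `percentiles_cutoff`, `percentiles` and `percentiles_database` share a prefix, which by the heuristic above hints at a cohesive component. A minimal Extract Class sketch, assuming a hypothetical `PercentilesDatabase` class (not part of torchio) that accumulates one row of percentiles per image and reproduces the averaging done by `_get_average_mapping`:

```python
from typing import List, Sequence

import numpy as np

STANDARD_RANGE = 0, 100  # same constant as in the listing


class PercentilesDatabase:
    """Hypothetical class grouping the percentiles_* state of train()."""

    def __init__(self, percentiles: Sequence[float]):
        self.percentiles = np.asarray(percentiles, dtype=float)
        self._rows: List[np.ndarray] = []

    def add(self, voxels: np.ndarray) -> None:
        """Store the percentiles of one training image's selected voxels."""
        self._rows.append(np.percentile(voxels, self.percentiles))

    def average_mapping(self) -> np.ndarray:
        """Average the landmarks, as _get_average_mapping() does."""
        database = np.vstack(self._rows)  # (num_images, num_percentiles)
        pc1, pc2 = database[:, 0], database[:, -1]
        s1, s2 = STANDARD_RANGE
        # Per-image slope mapping [pc1, pc2] onto [s1, s2]
        slopes = np.nan_to_num((s2 - s1) / (pc2 - pc1))
        intercepts = np.mean(s1 - slopes * pc1)
        return slopes.dot(database) / len(database) + intercepts
```

`train()` would then create one `PercentilesDatabase`, call `add` once per image and return `average_mapping()`; the branch that writes the `.txt` or `.npy` file could be extracted the same way.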

from pathlib import Path
from typing import Dict, Callable, Tuple, Sequence, Union, Optional

import torch
import numpy as np
from tqdm import tqdm

from ....typing import TypePath
from ....data.io import read_image
from ....data.subject import Subject
from .normalization_transform import NormalizationTransform, TypeMaskingMethod

DEFAULT_CUTOFF = 0.01, 0.99
STANDARD_RANGE = 0, 100
TypeLandmarks = Union[TypePath, Dict[str, Union[TypePath, np.ndarray]]]


class HistogramStandardization(NormalizationTransform):
    """Perform histogram standardization of intensity values.

    Implementation of `New variants of a method of MRI scale
    standardization <https://ieeexplore.ieee.org/document/836373>`_.

    See example in :func:`torchio.transforms.HistogramStandardization.train`.

    Args:
        landmarks: Dictionary (or path to a PyTorch file with ``.pt`` or ``.pth``
            extension in which a dictionary has been saved) whose keys are
            image names in the subject and values are NumPy arrays or paths to
            NumPy arrays defining the landmarks after training with
            :meth:`torchio.transforms.HistogramStandardization.train`.
        masking_method: See
            :class:`~torchio.transforms.preprocessing.intensity.NormalizationTransform`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> landmarks = {
        ...     't1': 't1_landmarks.npy',
        ...     't2': 't2_landmarks.npy',
        ... }
        >>> transform = tio.HistogramStandardization(landmarks)
        >>> torch.save(landmarks, 'path_to_landmarks.pth')
        >>> transform = tio.HistogramStandardization('path_to_landmarks.pth')
    """  # noqa: E501

    def __init__(
            self,
            landmarks: TypeLandmarks,
            masking_method: TypeMaskingMethod = None,
            **kwargs
            ):
        super().__init__(masking_method=masking_method, **kwargs)
        self.landmarks = landmarks
        self.landmarks_dict = self._parse_landmarks(landmarks)
        self.args_names = 'landmarks', 'masking_method'

    @staticmethod
    def _parse_landmarks(landmarks: TypeLandmarks) -> Dict[str, np.ndarray]:
        if isinstance(landmarks, (str, Path)):
            path = Path(landmarks)
            if path.suffix not in ('.pt', '.pth'):
                message = (
                    'The landmarks file must have extension .pt or .pth,'
                    f' not "{path.suffix}"'
                )
                raise ValueError(message)
            landmarks_dict = torch.load(path)
        else:
            landmarks_dict = landmarks
        for key, value in landmarks_dict.items():
            if isinstance(value, (str, Path)):
                landmarks_dict[key] = np.load(value)
        return landmarks_dict

    def apply_normalization(
            self,
            subject: Subject,
            image_name: str,
            mask: torch.Tensor,
            ) -> None:
        if image_name not in self.landmarks_dict:
            keys = tuple(self.landmarks_dict.keys())
            message = (
                f'Image name "{image_name}" should be a key in the'
                f' landmarks dictionary, whose keys are {keys}'
            )
            raise KeyError(message)
        image = subject[image_name]
        landmarks = self.landmarks_dict[image_name]
        normalized = normalize(image.data, landmarks, mask=mask)
        image.set_data(normalized)

    @classmethod
    def train(
            cls,
            images_paths: Sequence[TypePath],
            cutoff: Optional[Tuple[float, float]] = None,
            mask_path: Optional[Union[Sequence[TypePath], TypePath]] = None,
            masking_function: Optional[Callable] = None,
            output_path: Optional[TypePath] = None,
            ) -> np.ndarray:
        """Extract average histogram landmarks from images used for training.

        Args:
            images_paths: List of image paths used to train.
            cutoff: Optional minimum and maximum quantile values,
                respectively, that are used to select a range of intensity of
                interest. Equivalent to :math:`pc_1` and :math:`pc_2` in
                `Nyúl and Udupa's paper <http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.204.102&rep=rep1&type=pdf>`_.
            mask_path: Path (or list of paths, one per image) to a binary
                image that will be used to select the voxels used to compute
                the statistics during histogram training. If ``None``, all
                voxels in the image will be used.
            masking_function: Function that receives the image tensor and
                returns a boolean mask of the voxels used for histogram
                training. If given, ``mask_path`` is ignored.
            output_path: Optional file path with extension ``.txt`` or
                ``.npy``, where the landmarks will be saved. Files with any
                other extension are not written.

        Example:

            >>> from pathlib import Path
            >>> from torchio.transforms import HistogramStandardization
            >>>
            >>> t1_paths = ['subject_a_t1.nii', 'subject_b_t1.nii.gz']
            >>> t2_paths = ['subject_a_t2.nii', 'subject_b_t2.nii.gz']
            >>>
            >>> t1_landmarks_path = Path('t1_landmarks.npy')
            >>> t2_landmarks_path = Path('t2_landmarks.npy')
            >>>
            >>> # Train only if no landmarks have been saved yet; passing
            >>> # output_path stores the result with np.save
            >>> t1_landmarks = (
            ...     t1_landmarks_path
            ...     if t1_landmarks_path.is_file()
            ...     else HistogramStandardization.train(
            ...         t1_paths, output_path=t1_landmarks_path)
            ... )
            >>>
            >>> t2_landmarks = (
            ...     t2_landmarks_path
            ...     if t2_landmarks_path.is_file()
            ...     else HistogramStandardization.train(
            ...         t2_paths, output_path=t2_landmarks_path)
            ... )
            >>>
            >>> # Values may be NumPy arrays or paths to .npy files
            >>> landmarks_dict = {
            ...     't1': t1_landmarks,
            ...     't2': t2_landmarks,
            ... }
            >>>
            >>> transform = HistogramStandardization(landmarks_dict)
        """  # noqa: E501
        # A str is itself a Sequence, so exclude it explicitly; otherwise a
        # single mask path would be treated as a list of paths
        is_masks_list = (
            isinstance(mask_path, Sequence)
            and not isinstance(mask_path, str)
        )
        if is_masks_list and len(mask_path) != len(images_paths):
            message = (
                f'Different number of images ({len(images_paths)})'
                f' and mask ({len(mask_path)}) paths found'
            )
            raise ValueError(message)
        quantiles_cutoff = DEFAULT_CUTOFF if cutoff is None else cutoff
        percentiles_cutoff = 100 * np.array(quantiles_cutoff)
        percentiles_database = []
        percentiles = _get_percentiles(percentiles_cutoff)
        for i, image_file_path in enumerate(tqdm(images_paths)):
            tensor, _ = read_image(image_file_path)
            # Select the voxels used to compute this image's percentiles
            if masking_function is not None:
                mask = masking_function(tensor)
            elif mask_path is None:
                mask = np.ones_like(tensor, dtype=bool)
            else:
                path = mask_path[i] if is_masks_list else mask_path
                mask, _ = read_image(path)
                mask = mask.numpy() > 0
            array = tensor.numpy()
            percentile_values = np.percentile(array[mask], percentiles)
            percentiles_database.append(percentile_values)
        percentiles_database = np.vstack(percentiles_database)
        mapping = _get_average_mapping(percentiles_database)

        if output_path is not None:
            output_path = Path(output_path).expanduser()
            extension = output_path.suffix
            if extension == '.txt':
                modality = 'image'
                text = f'{modality} {" ".join(map(str, mapping))}'
                output_path.write_text(text)
            elif extension == '.npy':
                np.save(output_path, mapping)
        return mapping


def _standardize_cutoff(cutoff: np.ndarray) -> np.ndarray:
    """Standardize the cutoff quantiles given in the configuration.

    Clip both values to [0, 1], then make sure the lower one is at most 0.09
    and the upper one at least 0.91, so that both lie outside the deciles
    used as landmarks.
    """
    # Copy so that an array passed by the caller is not modified in place
    cutoff = np.array(cutoff, dtype=float)
    cutoff[0] = max(0, cutoff[0])
    cutoff[1] = min(1, cutoff[1])
    cutoff[0] = np.min([cutoff[0], 0.09])
    cutoff[1] = np.max([cutoff[1], 0.91])
    return cutoff


def _get_average_mapping(percentiles_database: np.ndarray) -> np.ndarray:
    """Map the landmarks of the database to the chosen range.

    Each image's landmarks are mapped linearly so that its first and last
    percentiles land on ``STANDARD_RANGE``; the result is the mean over
    all images.

    Args:
        percentiles_database: Percentiles database over which to perform the
            averaging.
    """
    # Assuming percentiles_database.shape == (num_data_points, num_percentiles)
    pc1 = percentiles_database[:, 0]
    pc2 = percentiles_database[:, -1]
    s1, s2 = STANDARD_RANGE
    slopes = (s2 - s1) / (pc2 - pc1)
    slopes = np.nan_to_num(slopes)
    intercepts = np.mean(s1 - slopes * pc1)
    num_images = len(percentiles_database)
    final_map = slopes.dot(percentiles_database) / num_images + intercepts
    return final_map


def _get_percentiles(percentiles_cutoff: Tuple[float, float]) -> np.ndarray:
    """Return the sorted, unique landmark percentiles.

    These are the two cutoff percentiles plus the quartiles and deciles.
    """
    quartiles = np.arange(25, 100, 25).tolist()
    deciles = np.arange(10, 100, 10).tolist()
    all_percentiles = list(percentiles_cutoff) + quartiles + deciles
    percentiles = sorted(set(all_percentiles))
    return np.array(percentiles)


def normalize(
        tensor: torch.Tensor,
        landmarks: np.ndarray,
        mask: Optional[np.ndarray],
        cutoff: Optional[Tuple[float, float]] = None,
        epsilon: float = 1e-5,
        ) -> torch.Tensor:
    """Map intensities piecewise-linearly onto the trained landmarks.

    The percentiles of ``tensor`` (computed within ``mask``) are matched to
    ``landmarks`` with one linear model per interval between landmarks.
    """
    cutoff_ = DEFAULT_CUTOFF if cutoff is None else cutoff
    array = tensor.numpy()
    mapping = landmarks

    data = array
    shape = data.shape
    data = data.reshape(-1).astype(np.float32)

    if mask is None:
        mask = np.ones_like(data, dtype=bool)
    mask = mask.reshape(-1)

    # Positions of the cutoffs and deciles among the 13 sorted percentiles;
    # the quartiles (indices 3 and 9) are not used as landmarks
    range_to_use = [0, 1, 2, 4, 5, 6, 7, 8, 10, 11, 12]

    quantiles_cutoff = _standardize_cutoff(cutoff_)
    percentiles_cutoff = 100 * np.array(quantiles_cutoff)
    percentiles = _get_percentiles(percentiles_cutoff)
    percentile_values = np.percentile(data[mask], percentiles)

    # Apply linear histogram standardization
    range_mapping = mapping[range_to_use]
    range_perc = percentile_values[range_to_use]
    diff_mapping = np.diff(range_mapping)
    diff_perc = np.diff(range_perc)

    # Handle the case where two landmarks are equal for a given input image,
    # which usually happens when the background has not been removed. An
    # infinite difference gives that interval a slope of zero.
    diff_perc[diff_perc < epsilon] = np.inf

    affine_map = np.zeros([2, len(range_to_use) - 1])

    # Compute slopes of the linear models
    affine_map[0] = diff_mapping / diff_perc

    # Compute intercepts of the linear models
    affine_map[1] = range_mapping[:-1] - affine_map[0] * range_perc[:-1]

    # Index of the linear model for each voxel; values outside the cutoffs
    # are extrapolated with the first or last model
    bin_id = np.digitize(data, range_perc[1:-1], right=False)
    lin_img = affine_map[0, bin_id]
    aff_img = affine_map[1, bin_id]
    new_img = lin_img * data + aff_img
    new_img = new_img.reshape(shape)
    new_img = new_img.astype(np.float32)
    new_img = torch.as_tensor(new_img)
    return new_img


# train_histogram kept for backward compatibility
train = train_histogram = HistogramStandardization.train
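For readers who want the algorithm without torchio's plumbing, here is a condensed NumPy sketch of what `train()` and `normalize()` compute, under simplifying assumptions: images are in-memory arrays, all voxels are used (no masks), and the landmarks are only the two cutoffs plus the deciles, so there are no quartiles to skip. The function names are hypothetical, and the sketch does not guard against a constant training image (where pc1 equals pc2). Each training image is mapped linearly so that its pc1 and pc2 land on the standard range, which is algebraically the average that `_get_average_mapping` computes; a new image is then mapped piecewise-linearly between its own landmarks and the trained ones.

```python
import numpy as np

DEFAULT_CUTOFF = 0.01, 0.99
STANDARD_RANGE = 0, 100


def get_percentiles(cutoff=DEFAULT_CUTOFF) -> np.ndarray:
    """Cutoff percentiles plus deciles (the listing also includes quartiles)."""
    deciles = np.arange(10, 100, 10)
    return np.concatenate([[100 * cutoff[0]], deciles, [100 * cutoff[1]]])


def train_landmarks(images, cutoff=DEFAULT_CUTOFF) -> np.ndarray:
    """Average the landmarks of all images after mapping pc1, pc2 to s1, s2."""
    percentiles = get_percentiles(cutoff)
    # One row of landmark intensities per training image
    database = np.vstack([np.percentile(im, percentiles) for im in images])
    pc1, pc2 = database[:, :1], database[:, -1:]
    s1, s2 = STANDARD_RANGE
    scaled = s1 + (database - pc1) * (s2 - s1) / (pc2 - pc1)
    return scaled.mean(axis=0)


def standardize(image, landmarks, cutoff=DEFAULT_CUTOFF, epsilon=1e-5):
    """Map intensities piecewise-linearly onto the trained landmarks."""
    source = np.percentile(image, get_percentiles(cutoff))
    diffs = np.diff(source)
    # Repeated landmarks (e.g. a large background) get a zero slope
    diffs[diffs < epsilon] = np.inf
    slopes = np.diff(landmarks) / diffs
    intercepts = landmarks[:-1] - slopes * source[:-1]
    # Segment index per voxel; values below pc1 or above pc2 are
    # extrapolated with the first or last segment
    segment = np.digitize(image, source[1:-1])
    return slopes[segment] * image + intercepts[segment]
```

In this sketch, two images that differ by an increasing affine intensity change (for example `x` and `3 * x + 5`) are mapped to the same output, which is the effect the standardization aims for.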