Passed
Pull Request — master (#621)
by
unknown
03:38
created

_get_percentiles()   A

Complexity

Conditions 1

Size

Total Lines 6
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 6
dl 0
loc 6
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
from pathlib import Path
2
from typing import Dict, Callable, Tuple, Sequence, Union, Optional
3
4
import sys
5
import torch
6
import numpy as np
7
from tqdm import tqdm
8
9
from ....typing import TypePath
10
from ....data.io import read_image
11
from ....data.subject import Subject
12
from .normalization_transform import NormalizationTransform, TypeMaskingMethod
13
14
# Default (lower, upper) quantiles, used by ``HistogramStandardization.train``
# and ``normalize`` when no ``cutoff`` argument is given.
DEFAULT_CUTOFF = 0.01, 0.99
15
# Target intensity range (s1, s2) onto which ``_get_average_mapping`` scales
# the first and last landmark of every training image.
STANDARD_RANGE = 0, 100
16
# Accepted forms for the ``landmarks`` argument: a single path, or a dictionary
# mapping names to either a path or a NumPy array.
TypeLandmarks = Union[TypePath, Dict[str, Union[TypePath, np.ndarray]]]
17
18
19
class HistogramStandardization(NormalizationTransform):
    """Perform histogram standardization of intensity values.

    Implementation of `New variants of a method of MRI scale
    standardization <https://ieeexplore.ieee.org/document/836373>`_.

    See example in :func:`torchio.transforms.HistogramStandardization.train`.

    Args:
        landmarks: Dictionary (or path to a PyTorch file with ``.pt`` or ``.pth``
            extension in which a dictionary has been saved) whose keys are
            image names in the subject and values are NumPy arrays or paths to
            NumPy arrays defining the landmarks after training with
            :meth:`torchio.transforms.HistogramStandardization.train`.
        masking_method: See
            :class:`~torchio.transforms.preprocessing.intensity.NormalizationTransform`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> landmarks = {
        ...     't1': 't1_landmarks.npy',
        ...     't2': 't2_landmarks.npy',
        ... }
        >>> transform = tio.HistogramStandardization(landmarks)
        >>> torch.save(landmarks, 'path_to_landmarks.pth')
        >>> transform = tio.HistogramStandardization('path_to_landmarks.pth')
    """  # noqa: E501
    def __init__(
            self,
            landmarks: TypeLandmarks,
            masking_method: TypeMaskingMethod = None,
            **kwargs
            ):
        super().__init__(masking_method=masking_method, **kwargs)
        # The argument is kept exactly as received; the parsed copy (with all
        # paths resolved to arrays) lives in ``landmarks_dict``.
        self.landmarks = landmarks
        self.landmarks_dict = self._parse_landmarks(landmarks)
        self.args_names = 'landmarks', 'masking_method'

    @staticmethod
    def _parse_landmarks(landmarks: TypeLandmarks) -> Dict[str, np.ndarray]:
        """Resolve ``landmarks`` into a dictionary of NumPy arrays.

        Args:
            landmarks: Path to a ``.pt``/``.pth`` file containing a saved
                dictionary, or the dictionary itself. Dictionary values may
                be arrays or paths to ``.npy`` files.

        Returns:
            A new dictionary in which every path value has been replaced by
            the array loaded from it. The input dictionary is not modified.

        Raises:
            ValueError: If ``landmarks`` is a path whose extension is neither
                ``.pt`` nor ``.pth``.
        """
        if isinstance(landmarks, (str, Path)):
            path = Path(landmarks)
            if path.suffix not in ('.pt', '.pth'):
                message = (
                    'The landmarks file must have extension .pt or .pth,'
                    f' not "{path.suffix}"'
                )
                raise ValueError(message)
            # NOTE(review): torch.load unpickles the file, so it must only be
            # used on landmarks files coming from a trusted source.
            loaded = torch.load(path)
        else:
            loaded = landmarks
        # Build a fresh dictionary instead of overwriting values in place, so
        # the dictionary owned by the caller is left untouched.
        return {
            key: np.load(value) if isinstance(value, (str, Path)) else value
            for key, value in loaded.items()
        }

    def apply_normalization(
            self,
            subject: Subject,
            image_name: str,
            mask: torch.Tensor,
            ) -> None:
        """Standardize the histogram of one image of the subject in place.

        Args:
            subject: Subject containing the image to normalize.
            image_name: Key of the image in ``subject``; it must also be a
                key of the landmarks dictionary.
            mask: Mask selecting the voxels used to compute the percentiles
                of the image.

        Raises:
            KeyError: If there are no landmarks for ``image_name``.
        """
        if image_name not in self.landmarks_dict:
            keys = tuple(self.landmarks_dict.keys())
            message = (
                f'Image name "{image_name}" should be a key in the'
                f' landmarks dictionary, whose keys are {keys}'
            )
            raise KeyError(message)
        image = subject[image_name]
        landmarks = self.landmarks_dict[image_name]
        normalized = normalize(image.data, landmarks, mask=mask)
        image.set_data(normalized)

    @classmethod
    def train(
            cls,
            images_paths: Sequence[TypePath],
            cutoff: Optional[Tuple[float, float]] = None,
            mask_path: Optional[Union[Sequence[TypePath], TypePath]] = None,
            masking_function: Optional[Callable] = None,
            output_path: Optional[TypePath] = None,
            ) -> np.ndarray:
        """Extract average histogram landmarks from images used for training.

        Args:
            images_paths: List of image paths used to train.
            cutoff: Optional minimum and maximum quantile values,
                respectively, that are used to select a range of intensity of
                interest. Equivalent to :math:`pc_1` and :math:`pc_2` in
                `Nyúl and Udupa's paper <http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.204.102&rep=rep1&type=pdf>`_.
                The values are clamped exactly as done at normalization
                time (lower quantile at most 0.09, upper quantile at least
                0.91) so that the trained landmarks always match the
                landmarks expected when the transform is applied.
            mask_path: Path (or list of paths) to a binary image that will be
                used to select the voxels use to compute the stats during
                histogram training. If ``None``, all voxels in the image will
                be used.
            masking_function: Function used to extract voxels used for
                histogram training.
            output_path: Optional file path with extension ``.txt`` or
                ``.npy``, where the landmarks will be saved. With any other
                extension nothing is written and a warning is printed.

        Returns:
            Array with the averaged landmarks.

        Raises:
            ValueError: If a list of masks is given whose length differs
                from the number of images, or if every image was skipped so
                that no landmarks could be computed.

        Example:

            >>> import torch
            >>> import numpy as np
            >>> from pathlib import Path
            >>> from torchio.transforms import HistogramStandardization
            >>>
            >>> t1_paths = ['subject_a_t1.nii', 'subject_b_t1.nii.gz']
            >>> t2_paths = ['subject_a_t2.nii', 'subject_b_t2.nii.gz']
            >>>
            >>> t1_landmarks_path = Path('t1_landmarks.npy')
            >>> t2_landmarks_path = Path('t2_landmarks.npy')
            >>>
            >>> t1_landmarks = (
            ...     t1_landmarks_path
            ...     if t1_landmarks_path.is_file()
            ...     else HistogramStandardization.train(t1_paths)
            ... )
            >>> torch.save(t1_landmarks, t1_landmarks_path)
            >>>
            >>> t2_landmarks = (
            ...     t2_landmarks_path
            ...     if t2_landmarks_path.is_file()
            ...     else HistogramStandardization.train(t2_paths)
            ... )
            >>> torch.save(t2_landmarks, t2_landmarks_path)
            >>>
            >>> landmarks_dict = {
            ...     't1': t1_landmarks,
            ...     't2': t2_landmarks,
            ... }
            >>>
            >>> transform = HistogramStandardization(landmarks_dict)
        """  # noqa: E501
        # A ``str`` is itself a ``Sequence``; a single mask path given as a
        # string must not be mistaken for a list of paths.
        is_masks_list = (
            isinstance(mask_path, Sequence)
            and not isinstance(mask_path, (str, Path))
        )
        if is_masks_list and len(mask_path) != len(images_paths):
            message = (
                f'Different number of images ({len(images_paths)})'
                f' and mask ({len(mask_path)}) paths found'
            )
            raise ValueError(message)
        # Clamp the cutoff the same way ``normalize`` does; otherwise a
        # cutoff inside the decile range would yield a different number of
        # landmarks than the one ``normalize`` indexes.
        quantiles_cutoff = _standardize_cutoff(
            DEFAULT_CUTOFF if cutoff is None else cutoff
        )
        percentiles_cutoff = 100 * np.array(quantiles_cutoff)
        percentiles_database = []
        percentiles = _get_percentiles(percentiles_cutoff)
        for i, image_file_path in enumerate(tqdm(images_paths)):
            tensor, _ = read_image(image_file_path)
            if masking_function is not None:
                mask = masking_function(tensor)
            elif mask_path is None:
                mask = np.ones_like(tensor, dtype=bool)
            else:
                path = mask_path[i] if is_masks_list else mask_path
                mask, _ = read_image(path)
                mask = mask.numpy() > 0

            # Skip constant images: all their percentiles are identical
            if tensor.min() == tensor.max():
                print("Warning. Empty tensor", image_file_path, "skipped.", file=sys.stderr)
                continue
            array = tensor.numpy()
            percentile_values = np.percentile(array[mask], percentiles)
            percentiles_database.append(percentile_values)
        if not percentiles_database:
            message = (
                'No landmarks could be computed: there were no images to'
                ' train on, or all of them were skipped'
            )
            raise ValueError(message)
        percentiles_database = np.vstack(percentiles_database)
        mapping = _get_average_mapping(percentiles_database)

        if output_path is not None:
            output_path = Path(output_path).expanduser()
            extension = output_path.suffix
            if extension == '.txt':
                modality = 'image'
                text = f'{modality} {" ".join(map(str, mapping))}'
                output_path.write_text(text)
            elif extension == '.npy':
                np.save(output_path, mapping)
            else:
                # Keep returning the mapping, but do not fail silently
                print(
                    f'Warning. Unsupported extension "{extension}",'
                    ' landmarks were not saved.',
                    file=sys.stderr,
                )
        return mapping
202
203
204
def _standardize_cutoff(cutoff: np.ndarray) -> np.ndarray:
205
    """Standardize the cutoff values given in the configuration.
206
207
    Computes percentile landmark normalization by default.
208
209
    """
210
    cutoff = np.asarray(cutoff)
211
    cutoff[0] = max(0, cutoff[0])
212
    cutoff[1] = min(1, cutoff[1])
213
    cutoff[0] = np.min([cutoff[0], 0.09])
214
    cutoff[1] = np.max([cutoff[1], 0.91])
215
    return cutoff
216
217
218
def _get_average_mapping(percentiles_database: np.ndarray) -> np.ndarray:
219
    """Map the landmarks of the database to the chosen range.
220
221
    Args:
222
        percentiles_database: Percentiles database over which to perform the
223
            averaging.
224
    """
225
    # Assuming percentiles_database.shape == (num_data_points, num_percentiles)
226
    pc1 = percentiles_database[:, 0]
227
    pc2 = percentiles_database[:, -1]
228
    s1, s2 = STANDARD_RANGE
229
    slopes = (s2 - s1) / (pc2 - pc1)
230
    slopes = np.nan_to_num(slopes)
231
    intercepts = np.mean(s1 - slopes * pc1)
232
    num_images = len(percentiles_database)
233
    final_map = slopes.dot(percentiles_database) / num_images + intercepts
234
    return final_map
235
236
237
def _get_percentiles(percentiles_cutoff: Tuple[float, float]) -> np.ndarray:
238
    quartiles = np.arange(25, 100, 25).tolist()
239
    deciles = np.arange(10, 100, 10).tolist()
240
    all_percentiles = list(percentiles_cutoff) + quartiles + deciles
241
    percentiles = sorted(set(all_percentiles))
242
    return np.array(percentiles)
243
244
245
def normalize(
        tensor: torch.Tensor,
        landmarks: np.ndarray,
        mask: Optional[np.ndarray],
        cutoff: Optional[Tuple[float, float]] = None,
        epsilon: float = 1e-5,
        ) -> torch.Tensor:
    """Apply piecewise-linear histogram standardization to a tensor.

    Percentiles of the masked voxels are computed and each interval between
    consecutive percentiles is mapped linearly onto the interval between the
    corresponding landmarks.

    Args:
        tensor: Image data to normalize.
        landmarks: Trained landmarks, indexed at positions 0 to 12.
        mask: Mask with as many elements as ``tensor`` selecting the voxels
            used to compute the percentiles. Any nonzero value selects the
            voxel. If ``None``, all voxels are used.
        cutoff: Lower and upper quantiles. If ``None``, ``DEFAULT_CUTOFF``
            is used.
        epsilon: Two consecutive percentiles closer than this value are
            considered equal.

    Returns:
        New ``float32`` tensor with the same shape as ``tensor``.
    """
    cutoff_ = DEFAULT_CUTOFF if cutoff is None else cutoff
    array = tensor.numpy()
    mapping = landmarks

    shape = array.shape
    data = array.reshape(-1).astype(np.float32)

    if mask is None:
        mask = np.ones_like(data, dtype=bool)
    else:
        # Force a boolean mask: indexing with an integer mask (e.g. 0/1
        # uint8) would select the voxels at positions 0 and 1 instead of
        # the masked voxels.
        mask = np.asarray(mask).reshape(-1).astype(bool)

    # Positions, within the 13 sorted percentiles, of the two cutoffs and of
    # the deciles; positions 3 and 9 (25th and 75th percentiles) are unused.
    range_to_use = [0, 1, 2, 4, 5, 6, 7, 8, 10, 11, 12]

    quantiles_cutoff = _standardize_cutoff(cutoff_)
    percentiles_cutoff = 100 * np.array(quantiles_cutoff)
    percentiles = _get_percentiles(percentiles_cutoff)
    percentile_values = np.percentile(data[mask], percentiles)

    # Apply linear histogram standardization
    range_mapping = mapping[range_to_use]
    range_perc = percentile_values[range_to_use]
    diff_mapping = np.diff(range_mapping)
    diff_perc = np.diff(range_perc)

    # Handling the case where two landmarks are the same
    # for a given input image. This usually happens when
    # image background is not removed from the image.
    # Dividing by infinity below gives these intervals a zero slope.
    diff_perc[diff_perc < epsilon] = np.inf

    affine_map = np.zeros([2, len(range_to_use) - 1])

    # Compute slopes of the linear models
    affine_map[0] = diff_mapping / diff_perc

    # Compute intercepts of the linear models
    affine_map[1] = range_mapping[:-1] - affine_map[0] * range_perc[:-1]

    # Only the inner percentiles are used as bin edges, so values beyond the
    # first or last percentile use the first or last linear model.
    bin_id = np.digitize(data, range_perc[1:-1], right=False)
    lin_img = affine_map[0, bin_id]
    aff_img = affine_map[1, bin_id]
    new_img = lin_img * data + aff_img
    new_img = new_img.reshape(shape)
    new_img = new_img.astype(np.float32)
    new_img = torch.as_tensor(new_img)
    return new_img
298
299
300
# train_histogram kept for backward compatibility
301
# Module-level aliases so the training routine can be called as
# ``train(...)`` or ``train_histogram(...)`` without naming the class.
train = train_histogram = HistogramStandardization.train
302