Passed
Pull Request — master (#353)
by Fernando
01:18
created

HistogramStandardization.parse_landmarks()   A

Complexity

Conditions 5

Size

Total Lines 17
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 5
eloc 14
nop 1
dl 0
loc 17
rs 9.2333
c 0
b 0
f 0
1
from pathlib import Path
2
from typing import Dict, Callable, Tuple, Sequence, Union, Optional
3
4
import torch
5
import numpy as np
6
from tqdm import tqdm
7
8
from ....torchio import DATA, TypePath
9
from ....data.io import read_image
10
from ....data.subject import Subject
11
from .normalization_transform import NormalizationTransform, TypeMaskingMethod
12
13
# Default (lower, upper) quantiles selecting the intensity range of interest.
DEFAULT_CUTOFF = (0.01, 0.99)
# Range onto which the first and last training landmarks are mapped.
STANDARD_RANGE = (0, 100)
15
# Either a path to a saved (``.pt``/``.pth``) dictionary, or a dictionary that
# maps image names to landmark arrays or to paths readable by ``np.load``.
TypeLandmarks = Union[
    TypePath,
    Dict[str, Union[TypePath, np.ndarray]],
]
16
17
18
class HistogramStandardization(NormalizationTransform):
    """Perform histogram standardization of intensity values.

    Implementation of `New variants of a method of MRI scale
    standardization <https://ieeexplore.ieee.org/document/836373>`_.

    See example in :py:func:`torchio.transforms.HistogramStandardization.train`.

    Args:
        landmarks: Dictionary (or path to a PyTorch file with ``.pt`` or ``.pth``
            extension in which a dictionary has been saved) whose keys are
            image names in the subject and values are NumPy arrays or paths to
            NumPy arrays defining the landmarks after training with
            :py:meth:`torchio.transforms.HistogramStandardization.train`.
            A dictionary passed here is not modified.
        masking_method: See
            :py:class:`~torchio.transforms.preprocessing.normalization_transform.NormalizationTransform`.
        p: Probability that this transform will be applied.

    Example:
        >>> import torch
        >>> from pathlib import Path
        >>> from torchio.transforms import HistogramStandardization
        >>>
        >>> landmarks = {
        ...     't1': 't1_landmarks.npy',
        ...     't2': 't2_landmarks.npy',
        ... }
        >>> transform = HistogramStandardization(landmarks)
        >>>
        >>> torch.save(landmarks, 'path_to_landmarks.pth')
        >>> transform = HistogramStandardization('path_to_landmarks.pth')
    """
    def __init__(
            self,
            landmarks: TypeLandmarks,
            masking_method: TypeMaskingMethod = None,
            p: float = 1,
            ):
        super().__init__(masking_method=masking_method, p=p)
        # Keep the argument as given (e.g. paths) so it can be reported via args_names.
        self.landmarks = landmarks
        self.landmarks_dict = self._parse_landmarks(landmarks)
        self.args_names = 'landmarks', 'masking_method'

    @staticmethod
    def _parse_landmarks(landmarks: TypeLandmarks) -> Dict[str, np.ndarray]:
        """Return a new dictionary mapping image names to landmark arrays.

        Args:
            landmarks: Path to a ``.pt``/``.pth`` file containing a dictionary,
                or a dictionary whose values are arrays or paths loadable
                with ``np.load``.

        Raises:
            ValueError: If ``landmarks`` is a path whose suffix is not
                ``.pt`` or ``.pth``.
        """
        if isinstance(landmarks, (str, Path)):
            path = Path(landmarks)
            if path.suffix not in ('.pt', '.pth'):
                message = (
                    'The landmarks file must have extension .pt or .pth,'
                    f' not "{path.suffix}"'
                )
                raise ValueError(message)
            # NOTE(security): torch.load unpickles the file, which can execute
            # arbitrary code; only load landmark files from trusted sources.
            landmarks_dict = torch.load(path)
        else:
            landmarks_dict = landmarks
        # Build a fresh dict instead of assigning into the input, so the
        # caller's dictionary (also stored in self.landmarks) keeps its paths.
        return {
            key: np.load(value) if isinstance(value, (str, Path)) else value
            for key, value in landmarks_dict.items()
        }

    def apply_normalization(
            self,
            subject: Subject,
            image_name: str,
            mask: torch.Tensor,
            ) -> None:
        """Replace the data of ``subject[image_name]`` with its standardized version.

        Raises:
            KeyError: If there are no landmarks for ``image_name``.
        """
        if image_name not in self.landmarks_dict:
            keys = tuple(self.landmarks_dict.keys())
            message = (
                f'Image name "{image_name}" should be a key in the'
                f' landmarks dictionary, whose keys are {keys}'
            )
            raise KeyError(message)
        image_dict = subject[image_name]
        landmarks = self.landmarks_dict[image_name]
        # No cutoff is passed, so normalize() uses DEFAULT_CUTOFF.
        image_dict[DATA] = normalize(
            image_dict[DATA],
            landmarks,
            mask=mask,
        )

    @classmethod
    def train(
            cls,
            images_paths: Sequence[TypePath],
            cutoff: Optional[Tuple[float, float]] = None,
            mask_path: Optional[TypePath] = None,
            masking_function: Optional[Callable] = None,
            output_path: Optional[TypePath] = None,
            ) -> np.ndarray:
        """Extract average histogram landmarks from images used for training.

        Args:
            images_paths: List of image paths used to train.
            cutoff: Optional minimum and maximum quantile values,
                respectively, that are used to select a range of intensity of
                interest. Equivalent to :math:`pc_1` and :math:`pc_2` in
                `Nyúl and Udupa's paper <http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.204.102&rep=rep1&type=pdf>`_.
                Note that the transform itself always standardizes images
                using the default cutoff, so landmarks trained with a
                different cutoff may not line up with it.
            mask_path: Optional path to a mask image to extract voxels used for
                training. The same mask is used for every image. It is ignored
                if ``masking_function`` is given.
            masking_function: Optional function used to extract voxels used for
                training.
            output_path: Optional file path with extension ``.txt`` or
                ``.npy``, where the landmarks will be saved.

        Returns:
            Array of averaged landmarks, one value per percentile.

        Raises:
            ValueError: If ``output_path`` is given and its extension is
                neither ``.txt`` nor ``.npy``. This is checked before any
                image is read.

        Example:
            >>> from pathlib import Path
            >>> from torchio.transforms import HistogramStandardization
            >>>
            >>> t1_paths = ['subject_a_t1.nii', 'subject_b_t1.nii.gz']
            >>> t2_paths = ['subject_a_t2.nii', 'subject_b_t2.nii.gz']
            >>>
            >>> t1_landmarks_path = Path('t1_landmarks.npy')
            >>> t2_landmarks_path = Path('t2_landmarks.npy')
            >>>
            >>> if not t1_landmarks_path.is_file():
            ...     _ = HistogramStandardization.train(
            ...         t1_paths,
            ...         output_path=t1_landmarks_path,
            ...     )
            >>>
            >>> if not t2_landmarks_path.is_file():
            ...     _ = HistogramStandardization.train(
            ...         t2_paths,
            ...         output_path=t2_landmarks_path,
            ...     )
            >>>
            >>> landmarks_dict = {
            ...     't1': t1_landmarks_path,
            ...     't2': t2_landmarks_path,
            ... }
            >>>
            >>> transform = HistogramStandardization(landmarks_dict)
        """
        # Validate the output path before doing any (potentially slow) work,
        # instead of silently skipping the save afterwards.
        if output_path is not None:
            output_path = Path(output_path).expanduser()
            if output_path.suffix not in ('.txt', '.npy'):
                message = (
                    'The output path must have extension .txt or .npy,'
                    f' not "{output_path.suffix}"'
                )
                raise ValueError(message)

        quantiles_cutoff = DEFAULT_CUTOFF if cutoff is None else cutoff
        percentiles_cutoff = 100 * np.array(quantiles_cutoff)
        percentiles = _get_percentiles(percentiles_cutoff)

        # The mask file is the same for every image: read it only once.
        shared_mask = None
        if masking_function is None and mask_path is not None:
            mask_tensor, _ = read_image(mask_path)
            shared_mask = mask_tensor.numpy() > 0

        percentiles_database = []
        for image_file_path in tqdm(images_paths):
            tensor, _ = read_image(image_file_path)
            data = tensor.numpy()
            if masking_function is not None:
                mask = masking_function(data)
            elif shared_mask is not None:
                mask = shared_mask
            else:
                # Builtin bool: the np.bool alias was removed in NumPy 1.24.
                mask = np.ones_like(data, dtype=bool)
            percentile_values = np.percentile(data[mask], percentiles)
            percentiles_database.append(percentile_values)
        percentiles_database = np.vstack(percentiles_database)
        mapping = _get_average_mapping(percentiles_database)

        if output_path is not None:
            if output_path.suffix == '.txt':
                modality = 'image'
                text = f'{modality} {" ".join(map(str, mapping))}'
                output_path.write_text(text)
            else:  # '.npy', validated above
                np.save(output_path, mapping)
        return mapping
188
189
190
def _standardize_cutoff(cutoff: np.ndarray) -> np.ndarray:
191
    """Standardize the cutoff values given in the configuration.
192
193
    Computes percentile landmark normalization by default.
194
195
    """
196
    cutoff = np.asarray(cutoff)
197
    cutoff[0] = max(0., cutoff[0])
198
    cutoff[1] = min(1., cutoff[1])
199
    cutoff[0] = np.min([cutoff[0], 0.09])
200
    cutoff[1] = np.max([cutoff[1], 0.91])
201
    return cutoff
202
203
204
def _get_average_mapping(percentiles_database: np.ndarray) -> np.ndarray:
    """Map the landmarks of the database to the chosen range.

    For each image, a linear map sends its first percentile value (column 0)
    to ``STANDARD_RANGE[0]`` and its last (column -1) to ``STANDARD_RANGE[1]``.
    The mapped percentiles are then averaged over all images.

    Args:
        percentiles_database: Percentiles database over which to perform the
            averaging, one row per image and one column per percentile.

    Returns:
        1D array with one averaged landmark per percentile column.
    """
    pc1 = percentiles_database[:, 0]
    pc2 = percentiles_database[:, -1]
    s1, s2 = STANDARD_RANGE
    with np.errstate(divide='ignore', invalid='ignore'):
        slopes = (s2 - s1) / (pc2 - pc1)
    # An image whose extreme percentiles coincide gives an infinite slope.
    # np.nan_to_num would turn it into ~1.8e308, which overflows the average
    # to inf/NaN. Use a zero slope instead, so such an image contributes s1
    # to every landmark. For an all-zero image this gives the same result as
    # before.
    slopes[~np.isfinite(slopes)] = 0
    intercepts = np.mean(s1 - slopes * pc1)
    num_images = len(percentiles_database)
    final_map = slopes.dot(percentiles_database) / num_images + intercepts
    return final_map
221
222
223
def _get_percentiles(percentiles_cutoff: Tuple[float, float]) -> np.ndarray:
224
    quartiles = np.arange(25, 100, 25).tolist()
225
    deciles = np.arange(10, 100, 10).tolist()
226
    all_percentiles = list(percentiles_cutoff) + quartiles + deciles
227
    percentiles = sorted(set(all_percentiles))
228
    return np.array(percentiles)
229
230
231
def normalize(
        tensor: torch.Tensor,
        landmarks: np.ndarray,
        mask: Optional[np.ndarray],
        cutoff: Optional[Tuple[float, float]] = None,
        epsilon: float = 1e-5,
        ) -> torch.Tensor:
    """Standardize intensities with a piecewise-linear landmark mapping.

    The percentiles of the (masked) intensities are computed. Each interval
    between consecutive selected percentiles is then linearly mapped onto the
    corresponding interval of ``landmarks``.

    Args:
        tensor: Image intensities. Converted with ``.numpy()``, so presumably
            a CPU tensor.
        landmarks: Trained landmarks, one value per percentile returned by
            ``_get_percentiles``.
        mask: Optional mask with as many elements as ``tensor``. Truthy
            entries select the voxels used to compute the percentiles.
            ``None`` selects all voxels.
        cutoff: Optional ``(lower, upper)`` quantiles. Defaults to
            ``DEFAULT_CUTOFF`` and is clamped by ``_standardize_cutoff``.
        epsilon: Percentile gaps smaller than this are given slope 0 to
            avoid dividing by (almost) zero.

    Returns:
        A float32 tensor with the same shape as ``tensor``.

    Raises:
        ValueError: If ``landmarks`` does not have one value per percentile.
    """
    cutoff_ = DEFAULT_CUTOFF if cutoff is None else cutoff
    array = tensor.numpy()
    shape = array.shape
    data = array.reshape(-1).astype(np.float32)

    if mask is None:
        # Builtin bool: the np.bool alias was removed in NumPy 1.24.
        mask = np.ones_like(data, dtype=bool)
    # Coerce to a flat boolean array so that, e.g., a 0/1 integer mask selects
    # voxels instead of being interpreted as integer indices.
    mask = np.asarray(mask, dtype=bool).reshape(-1)

    quantiles_cutoff = _standardize_cutoff(cutoff_)
    percentiles_cutoff = 100 * np.array(quantiles_cutoff)
    percentiles = _get_percentiles(percentiles_cutoff)
    percentile_values = np.percentile(data[mask], percentiles)

    mapping = np.asarray(landmarks)
    if mapping.shape != percentiles.shape:
        message = (
            f'Expected {len(percentiles)} landmarks, but got an array with'
            f' shape {mapping.shape}. Were the landmarks trained with a'
            ' different cutoff?'
        )
        raise ValueError(message)

    # After _standardize_cutoff, the sorted percentiles are
    # [pc1, 10, 20, 25, 30, 40, 50, 60, 70, 75, 80, 90, pc2].
    # The quartiles 25 and 75 (indices 3 and 9) are left out of the mapping.
    range_to_use = [0, 1, 2, 4, 5, 6, 7, 8, 10, 11, 12]

    # Apply linear histogram standardization
    range_mapping = mapping[range_to_use]
    range_perc = percentile_values[range_to_use]
    diff_mapping = np.diff(range_mapping)
    diff_perc = np.diff(range_perc)

    # Handling the case where two landmarks are the same
    # for a given input image. This usually happens when
    # image background is not removed from the image.
    diff_perc[diff_perc < epsilon] = np.inf

    # Row 0: slope and row 1: intercept of each linear segment.
    affine_map = np.zeros([2, len(range_to_use) - 1])
    affine_map[0] = diff_mapping / diff_perc
    affine_map[1] = range_mapping[:-1] - affine_map[0] * range_perc[:-1]

    # Only interior landmarks are bin edges, so values below the first or
    # above the last landmark extrapolate the first or last segment.
    bin_id = np.digitize(data, range_perc[1:-1], right=False)
    new_img = affine_map[0, bin_id] * data + affine_map[1, bin_id]
    new_img = new_img.reshape(shape).astype(np.float32)
    return torch.from_numpy(new_img)
284
285
286
# Module-level shortcuts to the landmark-training classmethod (same object).
train_histogram = HistogramStandardization.train
train = train_histogram
287