Passed
Push — master ( 753ab1...a5fd0f )
by Fernando
01:15
created

HistogramStandardization.parse_landmarks()   A

Complexity

Conditions 5

Size

Total Lines 17
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 5
eloc 14
nop 1
dl 0
loc 17
rs 9.2333
c 0
b 0
f 0
1
from pathlib import Path
2
from typing import Dict, Callable, Tuple, Sequence, Union, Optional
3
import torch
4
import numpy as np
5
import nibabel as nib
6
from tqdm import tqdm
7
from ....torchio import DATA, TypePath
8
from ....data.io import read_image
9
from ....data.subject import Subject
10
from .normalization_transform import NormalizationTransform, TypeMaskingMethod
11
12
# Lower and upper quantiles used as cutoff when the caller passes none
DEFAULT_CUTOFF = 0.01, 0.99
# Output range onto which the first and last landmarks are mapped
STANDARD_RANGE = 0, 100
14
# Landmarks are given either as a path to a saved dictionary, or as a
# dictionary mapping image names to landmark arrays or to paths of saved arrays
TypeLandmarks = Union[TypePath, Dict[str, Union[TypePath, np.ndarray]]]
15
16
17
class HistogramStandardization(NormalizationTransform):
    """Perform histogram standardization of intensity values.

    See example in :py:func:`torchio.transforms.HistogramStandardization.train`.

    Args:
        landmarks: Dictionary (or path to a PyTorch file with ``.pt`` or ``.pth``
            extension in which a dictionary has been saved) whose keys are
            image names in the sample and values are NumPy arrays or paths to
            NumPy arrays defining the landmarks after training with
            :py:meth:`torchio.transforms.HistogramStandardization.train`.
        masking_method: See
            :py:class:`~torchio.transforms.preprocessing.normalization_transform.NormalizationTransform`.
        p: Probability that this transform will be applied.

    Example:
        >>> import torch
        >>> from pathlib import Path
        >>> from torchio.transforms import HistogramStandardization
        >>>
        >>> landmarks = {
        ...     't1': 't1_landmarks.npy',
        ...     't2': 't2_landmarks.npy',
        ... }
        >>> transform = HistogramStandardization(landmarks)
        >>>
        >>> torch.save(landmarks, 'path_to_landmarks.pth')
        >>> transform = HistogramStandardization('path_to_landmarks.pth')
    """
    def __init__(
            self,
            landmarks: TypeLandmarks,
            masking_method: TypeMaskingMethod = None,
            p: float = 1,
            ):
        super().__init__(masking_method=masking_method, p=p)
        self.landmarks_dict = self.parse_landmarks(landmarks)

    @staticmethod
    def parse_landmarks(landmarks: TypeLandmarks) -> Dict[str, np.ndarray]:
        """Turn the ``landmarks`` argument into a dictionary of arrays.

        Args:
            landmarks: Either a path to a ``.pt``/``.pth`` file containing a
                saved dictionary, or the dictionary itself. Dictionary values
                that are paths are loaded with :func:`numpy.load`; any other
                value is kept as it is.

        Returns:
            A new dictionary with the same keys; the input dictionary is not
            modified.

        Raises:
            ValueError: If a path is given whose extension is neither ``.pt``
                nor ``.pth``.
        """
        if isinstance(landmarks, (str, Path)):
            path = Path(landmarks)
            if path.suffix not in ('.pt', '.pth'):
                message = (
                    'The landmarks file must have extension .pt or .pth,'
                    f' not "{path.suffix}"'
                )
                raise ValueError(message)
            # NOTE: torch.load unpickles the file, so only landmarks files
            # from a trusted source should be passed here
            landmarks_dict = torch.load(path)
        else:
            landmarks_dict = landmarks
        # Build a fresh dictionary so the caller's one is left untouched
        return {
            key: np.load(value) if isinstance(value, (str, Path)) else value
            for key, value in landmarks_dict.items()
        }

    def apply_normalization(
            self,
            sample: Subject,
            image_name: str,
            mask: torch.Tensor,
            ) -> None:
        """Standardize in place the data of one image of the sample.

        Args:
            sample: Sample containing the image to normalize.
            image_name: Key of the image in the sample and in the landmarks
                dictionary.
            mask: Mask forwarded to :func:`normalize`.

        Raises:
            KeyError: If ``image_name`` is not a key of the landmarks
                dictionary.
        """
        if image_name not in self.landmarks_dict:
            keys = tuple(self.landmarks_dict.keys())
            message = (
                f'Image name "{image_name}" should be a key in the'
                f' landmarks dictionary, whose keys are {keys}'
            )
            raise KeyError(message)
        image_dict = sample[image_name]
        landmarks = self.landmarks_dict[image_name]
        image_dict[DATA] = normalize(
            image_dict[DATA],
            landmarks,
            mask=mask,
        )

    @classmethod
    def train(
            cls,
            images_paths: Sequence[TypePath],
            cutoff: Optional[Tuple[float, float]] = None,
            mask_path: Optional[TypePath] = None,
            masking_function: Optional[Callable] = None,
            output_path: Optional[TypePath] = None,
            ) -> np.ndarray:
        """Extract average histogram landmarks from images used for training.

        Args:
            images_paths: List of image paths used to train.
            cutoff: Optional minimum and maximum quantile values,
                respectively, that are used to select a range of intensity of
                interest. Equivalent to :math:`pc_1` and :math:`pc_2` in
                `Nyúl and Udupa's paper <http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.204.102&rep=rep1&type=pdf>`_.
            mask_path: Optional path to a mask image to extract voxels used for
                training. Ignored if ``masking_function`` is given.
            masking_function: Optional function used to extract voxels used for
                training. It is called with the image data as a NumPy array.
            output_path: Optional file path with extension ``.txt`` or
                ``.npy``, where the landmarks will be saved. Nothing is
                written for any other extension.

        Returns:
            The landmarks averaged over all the training images.

        Example:

            >>> import numpy as np
            >>> from pathlib import Path
            >>> from torchio.transforms import HistogramStandardization
            >>>
            >>> t1_paths = ['subject_a_t1.nii', 'subject_b_t1.nii.gz']
            >>> t2_paths = ['subject_a_t2.nii', 'subject_b_t2.nii.gz']
            >>>
            >>> t1_landmarks_path = Path('t1_landmarks.npy')
            >>> t2_landmarks_path = Path('t2_landmarks.npy')
            >>>
            >>> t1_landmarks = (
            ...     np.load(t1_landmarks_path)
            ...     if t1_landmarks_path.is_file()
            ...     else HistogramStandardization.train(
            ...         t1_paths, output_path=t1_landmarks_path)
            ... )
            >>>
            >>> t2_landmarks = (
            ...     np.load(t2_landmarks_path)
            ...     if t2_landmarks_path.is_file()
            ...     else HistogramStandardization.train(
            ...         t2_paths, output_path=t2_landmarks_path)
            ... )
            >>>
            >>> landmarks_dict = {
            ...     't1': t1_landmarks,
            ...     't2': t2_landmarks,
            ... }
            >>>
            >>> transform = HistogramStandardization(landmarks_dict)
        """
        quantiles_cutoff = DEFAULT_CUTOFF if cutoff is None else cutoff
        percentiles_cutoff = 100 * np.array(quantiles_cutoff)
        percentiles = _get_percentiles(percentiles_cutoff)

        # The mask file does not depend on the image, so read it only once
        file_mask = None
        if masking_function is None and mask_path is not None:
            file_mask = nib.load(str(mask_path)).get_fdata() > 0

        percentiles_database = []
        for image_file_path in tqdm(images_paths):
            tensor, _ = read_image(image_file_path)
            data = tensor.numpy()
            if masking_function is not None:
                mask = masking_function(data)
            elif file_mask is not None:
                mask = file_mask
            else:
                # Plain ``bool``: the ``np.bool`` alias was removed from NumPy
                mask = np.ones_like(data, dtype=bool)
            percentile_values = np.percentile(data[mask], percentiles)
            percentiles_database.append(percentile_values)
        percentiles_database = np.vstack(percentiles_database)
        mapping = _get_average_mapping(percentiles_database)

        if output_path is not None:
            output_path = Path(output_path).expanduser()
            extension = output_path.suffix
            if extension == '.txt':
                modality = 'image'
                text = f'{modality} {" ".join(map(str, mapping))}'
                output_path.write_text(text)
            elif extension == '.npy':
                np.save(output_path, mapping)
        return mapping
182
183
184
def _standardize_cutoff(cutoff: np.ndarray) -> np.ndarray:
    """Standardize the cutoff values given in the configuration.

    Computes percentile landmark normalization by default.

    The lower cutoff is clamped to ``[0, 0.09]`` and the upper cutoff to
    ``[0.91, 1]``.

    Args:
        cutoff: Lower and upper quantiles, as an array or any sequence that
            can be converted to one.

    Returns:
        A new floating point array with the clamped values. The input is
        never modified.
    """
    # Copy into a float array: ``np.asarray`` would alias an ndarray argument
    # (mutating the caller's data) and an integer array would truncate the
    # 0.09 and 0.91 bounds to 0
    cutoff = np.array(cutoff, dtype=float)
    cutoff[0] = min(max(0., cutoff[0]), 0.09)
    cutoff[1] = max(min(1., cutoff[1]), 0.91)
    return cutoff
196
197
198
def _get_average_mapping(
        percentiles_database: np.ndarray,
        standard_range: Optional[Tuple[float, float]] = None,
        ) -> np.ndarray:
    """Map the landmarks of the database to the chosen range.

    Args:
        percentiles_database: Percentiles database over which to perform the
            averaging.
        standard_range: Optional minimum and maximum of the output range.
            If ``None``, the module-level ``STANDARD_RANGE`` is used.

    Returns:
        One averaged landmark per percentile.
    """
    # Assuming percentiles_database.shape == (num_data_points, num_percentiles)
    s1, s2 = STANDARD_RANGE if standard_range is None else standard_range
    pc1 = percentiles_database[:, 0]
    pc2 = percentiles_database[:, -1]
    # One linear model per image, sending its first and last percentiles
    # to s1 and s2, respectively
    slopes = (s2 - s1) / (pc2 - pc1)
    # NOTE: if an image has pc1 == pc2 the slope is infinite, and nan_to_num
    # replaces it with a very large finite number
    slopes = np.nan_to_num(slopes)
    intercepts = np.mean(s1 - slopes * pc1)
    num_images = len(percentiles_database)
    final_map = slopes.dot(percentiles_database) / num_images + intercepts
    return final_map
215
216
217
def _get_percentiles(percentiles_cutoff: Tuple[float, float]) -> np.ndarray:
    """Return the sorted percentiles at which landmarks are computed.

    Args:
        percentiles_cutoff: Lower and upper cutoff, expressed as percentiles.

    Returns:
        Sorted array, without duplicates, made of the two cutoff values, the
        quartiles (25, 50, 75) and the deciles (10, 20, ..., 90).
    """
    quartiles = np.arange(25, 100, 25).tolist()
    deciles = np.arange(10, 100, 10).tolist()
    all_percentiles = list(percentiles_cutoff) + quartiles + deciles
    # The set removes values present in more than one group (e.g. 50)
    percentiles = sorted(set(all_percentiles))
    return np.array(percentiles)
223
224
225
def normalize(
        tensor: torch.Tensor,
        landmarks: np.ndarray,
        mask: Optional[np.ndarray],
        cutoff: Optional[Tuple[float, float]] = None,
        epsilon: float = 1e-5,
        ) -> torch.Tensor:
    """Apply piecewise linear histogram standardization to a tensor.

    Percentiles of the masked values are computed and each interval between
    consecutive percentiles is mapped linearly onto the interval between the
    corresponding landmarks.

    Args:
        tensor: Input data.
        landmarks: Landmarks the percentiles of the data are mapped to.
        mask: Selects the values used to compute the percentiles. If ``None``,
            all values are used. The mapping itself is applied to every value.
        cutoff: Optional lower and upper quantiles. If ``None``,
            ``DEFAULT_CUTOFF`` is used.
        epsilon: Consecutive percentiles closer than this are considered
            equal.

    Returns:
        A ``float32`` tensor with the same shape as the input.
    """
    cutoff_ = DEFAULT_CUTOFF if cutoff is None else cutoff
    array = tensor.numpy()
    shape = array.shape
    data = array.reshape(-1).astype(np.float32)

    if mask is None:
        # Plain ``bool``: the ``np.bool`` alias was removed from NumPy
        mask = np.ones_like(data, bool)
    mask = mask.reshape(-1)

    # Indices 3 and 9 are left out. With the 13 percentiles computed below
    # these look like the quartiles at 25 and 75 -- TODO confirm
    range_to_use = [0, 1, 2, 4, 5, 6, 7, 8, 10, 11, 12]

    quantiles_cutoff = _standardize_cutoff(cutoff_)
    percentiles_cutoff = 100 * np.array(quantiles_cutoff)
    percentiles = _get_percentiles(percentiles_cutoff)
    percentile_values = np.percentile(data[mask], percentiles)

    # Apply linear histogram standardization
    range_mapping = landmarks[range_to_use]
    range_perc = percentile_values[range_to_use]
    diff_mapping = np.diff(range_mapping)
    diff_perc = np.diff(range_perc)

    # Handling the case where two landmarks are the same
    # for a given input image. This usually happens when
    # image background is not removed from the image.
    # An infinite difference gives a slope of zero for that interval.
    diff_perc[diff_perc < epsilon] = np.inf

    affine_map = np.zeros([2, len(range_to_use) - 1])

    # Compute slopes of the linear models
    affine_map[0] = diff_mapping / diff_perc

    # Compute intercepts of the linear models
    affine_map[1] = range_mapping[:-1] - affine_map[0] * range_perc[:-1]

    # Only the inner percentiles are bin edges, so values outside the outer
    # percentiles use the first or the last linear model
    bin_id = np.digitize(data, range_perc[1:-1], right=False)
    lin_img = affine_map[0, bin_id]
    aff_img = affine_map[1, bin_id]
    new_img = lin_img * data + aff_img
    new_img = new_img.reshape(shape)
    new_img = new_img.astype(np.float32)
    new_img = torch.from_numpy(new_img)
    return new_img
278
279
280
# Module-level aliases of the HistogramStandardization.train class method
train = train_histogram = HistogramStandardization.train
281