Passed
Pull Request — master (#164)
by Fernando
01:24
created

HistogramStandardization.parse_landmarks_dict()   A

Complexity

Conditions 4

Size

Total Lines 10
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 9
nop 1
dl 0
loc 10
rs 9.95
c 0
b 0
f 0
1
from pathlib import Path
2
from typing import Dict, Callable, Tuple, Sequence, Union, Optional
3
import torch
4
import numpy as np
5
import nibabel as nib
6
from tqdm import tqdm
7
from ....torchio import DATA, TypePath
8
from ....data.io import read_image
9
from ....data.subject import Subject
10
from .normalization_transform import NormalizationTransform, TypeMaskingMethod
11
12
# Default (lower, upper) quantile cutoffs, used by ``train`` and ``normalize``
# whenever the caller does not pass an explicit ``cutoff``.
DEFAULT_CUTOFF = 0.01, 0.99
13
# Target (min, max) intensity range onto which the first and last landmarks
# are mapped when the average mapping is computed.
STANDARD_RANGE = 0, 100
14
# Accepted forms for the landmarks argument: a path to a saved dictionary, or
# a dictionary whose values are landmark arrays or paths to them.
TypeLandmarks = Union[TypePath, Dict[str, Union[TypePath, np.ndarray]]]
15
16
17
class HistogramStandardization(NormalizationTransform):
    """Perform histogram standardization of intensity values.

    See example in :py:func:`torchio.transforms.HistogramStandardization.train`.

    Args:
        landmarks_dict: Dictionary or path to a dictionary in which keys are
            image names in the sample and values are NumPy arrays or paths to
            NumPy arrays defining the landmarks after training with
            :py:meth:`torchio.transforms.HistogramStandardization.train`.
        masking_method: See
            :py:class:`~torchio.transforms.preprocessing.normalization_transform.NormalizationTransform`.
        p: Probability that this transform will be applied.

    Example:
        >>> from pathlib import Path
        >>> from torchio.transforms import HistogramStandardization
        >>>
        >>> landmarks_dict = {
        ...     't1': Path('t1_landmarks.npy'),
        ...     't2': Path('t2_landmarks.npy'),
        ... }
        >>>
        >>> transform = HistogramStandardization(landmarks_dict)
        >>> transform = HistogramStandardization('path_to_landmarks_dict.pth')
    """
    def __init__(
            self,
            landmarks_dict: TypeLandmarks,
            masking_method: TypeMaskingMethod = None,
            p: float = 1,
            ):
        super().__init__(masking_method=masking_method, p=p)
        self.landmarks_dict = self.parse_landmarks_dict(landmarks_dict)

    @staticmethod
    def parse_landmarks_dict(landmarks: TypeLandmarks) -> Dict[str, np.ndarray]:
        """Resolve the landmarks argument into a name-to-array dictionary.

        Args:
            landmarks: Either a path (``str`` or ``Path``) to a dictionary
                saved so that :func:`torch.load` can read it, or a dictionary
                whose values are arrays or paths readable by :func:`np.load`.

        Returns:
            A new dictionary with the same keys, in which every value that
            was a path has been replaced by the array loaded from it. The
            dictionary passed by the caller is left untouched.
        """
        if isinstance(landmarks, (str, Path)):
            # NOTE: torch.load unpickles the file, so only load landmarks
            # files that come from a trusted source.
            raw_landmarks = torch.load(landmarks)
        else:
            raw_landmarks = landmarks
        return {
            key: np.load(value) if isinstance(value, (str, Path)) else value
            for key, value in raw_landmarks.items()
        }

    def apply_normalization(
            self,
            sample: Subject,
            image_name: str,
            mask: torch.Tensor,
            ) -> None:
        """Standardize the histogram of one image of the sample in place.

        Args:
            sample: Sample holding the image, indexed by ``image_name``.
            image_name: Name of the image; must be a key of the landmarks
                dictionary given at construction.
            mask: Mask forwarded to :func:`normalize` to select the voxels
                used to compute the image percentiles.

        Raises:
            KeyError: If ``image_name`` has no entry in the landmarks
                dictionary.
        """
        if image_name not in self.landmarks_dict:
            keys = tuple(self.landmarks_dict)
            message = (
                f'Image name "{image_name}" should be a key in the'
                f' landmarks dictionary, whose keys are {keys}'
            )
            raise KeyError(message)
        image_dict = sample[image_name]
        landmarks = self.landmarks_dict[image_name]
        image_dict[DATA] = normalize(
            image_dict[DATA],
            landmarks,
            mask=mask,
        )

    @classmethod
    def train(
            cls,
            images_paths: Sequence[TypePath],
            cutoff: Optional[Tuple[float, float]] = None,
            mask_path: Optional[TypePath] = None,
            masking_function: Optional[Callable] = None,
            output_path: Optional[TypePath] = None,
            ) -> np.ndarray:
        """Extract average histogram landmarks from images used for training.

        Args:
            images_paths: List of image paths used to train.
            cutoff: Optional minimum and maximum quantile values,
                respectively, that are used to select a range of intensity of
                interest. Equivalent to :math:`pc_1` and :math:`pc_2` in
                `Nyúl and Udupa's paper <http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.204.102&rep=rep1&type=pdf>`_.
            mask_path: Optional path to a mask image to extract voxels used for
                training. Ignored when ``masking_function`` is given.
            masking_function: Optional function used to extract voxels used for
                training. It is called with the image array and its result is
                used to index that array.
            output_path: Optional file path with extension ``.txt`` or
                ``.npy``, where the landmarks will be saved. Nothing is
                written for any other extension.

        Returns:
            Array with the averaged landmarks.

        Example:

            >>> from pathlib import Path
            >>> import numpy as np
            >>> from torchio.transforms import HistogramStandardization
            >>>
            >>> t1_paths = ['subject_a_t1.nii', 'subject_b_t1.nii.gz']
            >>> t2_paths = ['subject_a_t2.nii', 'subject_b_t2.nii.gz']
            >>>
            >>> t1_landmarks_path = Path('t1_landmarks.npy')
            >>> t2_landmarks_path = Path('t2_landmarks.npy')
            >>>
            >>> t1_landmarks = (
            ...     t1_landmarks_path
            ...     if t1_landmarks_path.is_file()
            ...     else HistogramStandardization.train(t1_paths)
            ... )
            >>> t2_landmarks = (
            ...     t2_landmarks_path
            ...     if t2_landmarks_path.is_file()
            ...     else HistogramStandardization.train(t2_paths)
            ... )
            >>>
            >>> landmarks_dict = {
            ...     't1': t1_landmarks,
            ...     't2': t2_landmarks,
            ... }
            >>>
            >>> transform = HistogramStandardization(landmarks_dict)
        """
        quantiles_cutoff = DEFAULT_CUTOFF if cutoff is None else cutoff
        percentiles_cutoff = 100 * np.array(quantiles_cutoff)
        percentiles = _get_percentiles(percentiles_cutoff)

        # The mask file does not depend on the image, so read it only once
        # instead of on every iteration.
        fixed_mask = None
        if masking_function is None and mask_path is not None:
            fixed_mask = nib.load(str(mask_path)).get_fdata() > 0

        percentiles_database = []
        for image_file_path in tqdm(images_paths):
            tensor, _ = read_image(image_file_path)
            data = tensor.numpy()
            if masking_function is not None:
                mask = masking_function(data)
            elif fixed_mask is not None:
                mask = fixed_mask
            else:
                # Builtin bool: the np.bool alias was removed in NumPy 1.24
                mask = np.ones_like(data, dtype=bool)
            percentile_values = np.percentile(data[mask], percentiles)
            percentiles_database.append(percentile_values)
        percentiles_database = np.vstack(percentiles_database)
        mapping = _get_average_mapping(percentiles_database)

        if output_path is not None:
            output_path = Path(output_path).expanduser()
            extension = output_path.suffix
            if extension == '.txt':
                modality = 'image'
                text = f'{modality} {" ".join(map(str, mapping))}'
                output_path.write_text(text)
            elif extension == '.npy':
                np.save(output_path, mapping)
        return mapping
168
169
170
def _standardize_cutoff(cutoff: np.ndarray) -> np.ndarray:
171
    """Standardize the cutoff values given in the configuration.
172
173
    Computes percentile landmark normalization by default.
174
175
    """
176
    cutoff = np.asarray(cutoff)
177
    cutoff[0] = max(0., cutoff[0])
178
    cutoff[1] = min(1., cutoff[1])
179
    cutoff[0] = np.min([cutoff[0], 0.09])
180
    cutoff[1] = np.max([cutoff[1], 0.91])
181
    return cutoff
182
183
184
def _get_average_mapping(percentiles_database: np.ndarray) -> np.ndarray:
    """Average the database landmarks after rescaling to the standard range.

    For every row, a slope is computed that stretches the span between its
    first and last percentile onto ``STANDARD_RANGE``. The result is the
    slope-weighted mean of the rows plus the mean of the per-row intercepts.

    Args:
        percentiles_database: Percentiles database over which to perform the
            averaging.

    Returns:
        One averaged landmark per column of the database.
    """
    # Assuming percentiles_database.shape == (num_data_points, num_percentiles)
    lowest = percentiles_database[:, 0]
    highest = percentiles_database[:, -1]
    range_min, range_max = STANDARD_RANGE
    # nan_to_num replaces the non-finite slopes that arise when a row has
    # identical first and last percentiles
    scale = np.nan_to_num((range_max - range_min) / (highest - lowest))
    offset = np.mean(range_min - scale * lowest)
    num_images = len(percentiles_database)
    return scale.dot(percentiles_database) / num_images + offset
201
202
203
def _get_percentiles(percentiles_cutoff: Tuple[float, float]) -> np.ndarray:
204
    quartiles = np.arange(25, 100, 25).tolist()
205
    deciles = np.arange(10, 100, 10).tolist()
206
    all_percentiles = list(percentiles_cutoff) + quartiles + deciles
207
    percentiles = sorted(set(all_percentiles))
208
    return np.array(percentiles)
209
210
211
def normalize(
        tensor: torch.Tensor,
        landmarks: np.ndarray,
        mask: Optional[np.ndarray],
        cutoff: Optional[Tuple[float, float]] = None,
        epsilon: float = 1e-5,
        ) -> torch.Tensor:
    """Map image intensities onto trained landmarks, piecewise linearly.

    Percentiles of the masked voxels are computed and paired with the
    corresponding entries of ``landmarks``. Between each pair of consecutive
    percentiles, a linear model maps intensities onto the interval between
    the matching landmarks. Every voxel, masked or not, is then transformed
    with the model of the interval it falls in.

    Args:
        tensor: Image to standardize.
        landmarks: Trained landmarks. They are indexed with positions up to
            12, so at least 13 values are expected.
        mask: Selects the voxels used to compute the percentiles. If
            ``None``, all voxels are used.
        cutoff: Optional (lower, upper) quantile cutoffs. ``DEFAULT_CUTOFF``
            is used when ``None``.
        epsilon: Consecutive percentiles closer than this are treated as
            identical.

    Returns:
        A new ``float32`` tensor with the same shape as ``tensor``.
    """
    cutoff_ = DEFAULT_CUTOFF if cutoff is None else cutoff
    array = tensor.numpy()
    shape = array.shape
    data = array.reshape(-1).astype(np.float32)

    if mask is None:
        # Builtin bool: the np.bool alias was removed in NumPy 1.24
        mask = np.ones_like(data, bool)
    mask = mask.reshape(-1)

    # Positions of the landmarks that are used. With the default cutoff,
    # the skipped positions 3 and 9 are the 25th and 75th percentiles.
    range_to_use = [0, 1, 2, 4, 5, 6, 7, 8, 10, 11, 12]

    quantiles_cutoff = _standardize_cutoff(cutoff_)
    percentiles_cutoff = 100 * np.array(quantiles_cutoff)
    percentiles = _get_percentiles(percentiles_cutoff)
    percentile_values = np.percentile(data[mask], percentiles)

    # Apply linear histogram standardization
    range_mapping = landmarks[range_to_use]
    range_perc = percentile_values[range_to_use]
    diff_mapping = np.diff(range_mapping)
    diff_perc = np.diff(range_perc)

    # Handling the case where two landmarks are the same
    # for a given input image. This usually happens when
    # image background is not removed from the image.
    # An infinite width gives that interval a slope of zero.
    diff_perc[diff_perc < epsilon] = np.inf

    # Row 0 holds the slopes and row 1 the intercepts, one column per interval
    affine_map = np.zeros([2, len(range_to_use) - 1])

    # Compute slopes of the linear models
    affine_map[0] = diff_mapping / diff_perc

    # Compute intercepts of the linear models
    affine_map[1] = range_mapping[:-1] - affine_map[0] * range_perc[:-1]

    # Only the interior percentiles are bin edges, so values beyond the
    # outermost percentiles fall in the first or last interval.
    bin_id = np.digitize(data, range_perc[1:-1], right=False)
    lin_img = affine_map[0, bin_id]
    aff_img = affine_map[1, bin_id]
    new_img = lin_img * data + aff_img
    new_img = new_img.reshape(shape)
    new_img = new_img.astype(np.float32)
    new_img = torch.from_numpy(new_img)
    return new_img
264
265
266
# Module-level aliases that expose the training classmethod as plain names.
train = train_histogram = HistogramStandardization.train
267