Passed
Push — master ( 7b848f...cac223 )
by Fernando
02:39
created

HistogramStandardization.train()   B

Complexity

Conditions 8

Size

Total Lines 84
Code Lines 35

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 35
dl 0
loc 84
rs 7.1733
c 0
b 0
f 0
cc 8
nop 6

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include: Extract Method.

1
"""
2
Adapted from NiftyNet
3
"""
4
5
from pathlib import Path
6
from typing import Dict, Callable, Tuple, Sequence, Union, Optional
7
import torch
8
import numpy as np
9
import nibabel as nib
10
from tqdm import tqdm
11
from ....torchio import DATA, TypePath, TypeCallable
12
from ....data.io import read_image
13
from ....data.subject import Subject
14
from . import NormalizationTransform
15
16
# Lower and upper quantiles used by ``train`` and ``normalize`` when no
# cutoff is given
DEFAULT_CUTOFF = (0.01, 0.99)
# Range (s1, s2) onto which the first and last training percentiles of each
# image are mapped when averaging the landmarks
STANDARD_RANGE = (0, 100)
18
19
20
class HistogramStandardization(NormalizationTransform):
    """Perform histogram standardization of intensity values.

    See example in :py:func:`torchio.transforms.HistogramStandardization.train`.

    Args:
        landmarks_dict: Dictionary in which keys are image names in the sample
            and values are NumPy arrays defining the landmarks after training
            with :py:meth:`torchio.transforms.HistogramStandardization.train`.
        masking_method: See
            :py:class:`~torchio.transforms.preprocessing.normalization_transform.NormalizationTransform`.
        p: Probability that this transform will be applied.
    """
    def __init__(
            self,
            landmarks_dict: Dict[str, np.ndarray],
            masking_method: Union[str, TypeCallable, None] = None,
            p: float = 1,
            ):
        super().__init__(masking_method=masking_method, p=p)
        self.landmarks_dict = landmarks_dict

    def apply_normalization(
            self,
            sample: Subject,
            image_name: str,
            mask: torch.Tensor,
            ) -> None:
        """Standardize the histogram of one image of the sample in place.

        Args:
            sample: Sample containing the image to normalize.
            image_name: Key of the image in ``sample``. It must also be a key
                of ``self.landmarks_dict``.
            mask: Voxels used to compute the image percentiles.

        Raises:
            KeyError: If ``image_name`` is not a key of the landmarks
                dictionary.
        """
        if image_name not in self.landmarks_dict:
            keys = tuple(self.landmarks_dict.keys())
            message = (
                f'Image name "{image_name}" should be a key in the'
                f' landmarks dictionary, whose keys are {keys}'
            )
            raise KeyError(message)
        image_dict = sample[image_name]
        landmarks = self.landmarks_dict[image_name]
        # The image data in the sample is replaced by its standardized version
        image_dict[DATA] = normalize(
            image_dict[DATA],
            landmarks,
            mask=mask,
        )

    @classmethod
    def train(
            cls,
            images_paths: Sequence[TypePath],
            cutoff: Optional[Tuple[float, float]] = None,
            mask_path: Optional[TypePath] = None,
            masking_function: Optional[Callable] = None,
            output_path: Optional[TypePath] = None,
            ) -> np.ndarray:
        """Extract average histogram landmarks from images used for training.

        Args:
            images_paths: List of image paths used to train.
            cutoff: Optional minimum and maximum quantile values,
                respectively, that are used to select a range of intensity of
                interest. Equivalent to :math:`pc_1` and :math:`pc_2` in
                `Nyúl and Udupa's paper <http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.204.102&rep=rep1&type=pdf>`_.
            mask_path: Optional path to a mask image to extract voxels used for
                training. Voxels with values greater than zero are used. The
                same mask is used for all the training images.
            masking_function: Optional function used to extract voxels used for
                training. It receives the image data as a NumPy array. If it is
                given, ``mask_path`` is ignored.
            output_path: Optional file path with extension ``.txt`` or
                ``.npy``, where the landmarks will be saved. Other extensions
                are ignored and nothing is written.

        Returns:
            Array with one averaged landmark per training percentile.

        Raises:
            ValueError: If ``images_paths`` is empty.

        Example:

            >>> from pathlib import Path
            >>> import numpy as np
            >>> from torchio.transforms import HistogramStandardization
            >>>
            >>> t1_paths = ['subject_a_t1.nii', 'subject_b_t1.nii.gz']
            >>> t2_paths = ['subject_a_t2.nii', 'subject_b_t2.nii.gz']
            >>>
            >>> t1_landmarks_path = Path('t1_landmarks.npy')
            >>> t2_landmarks_path = Path('t2_landmarks.npy')
            >>>
            >>> t1_landmarks = (
            ...     np.load(t1_landmarks_path)
            ...     if t1_landmarks_path.is_file()
            ...     else HistogramStandardization.train(t1_paths)
            ... )
            >>> t2_landmarks = (
            ...     np.load(t2_landmarks_path)
            ...     if t2_landmarks_path.is_file()
            ...     else HistogramStandardization.train(t2_paths)
            ... )
            >>>
            >>> landmarks_dict = {
            ...     't1': t1_landmarks,
            ...     't2': t2_landmarks,
            ... }
            >>>
            >>> transform = HistogramStandardization(landmarks_dict)
        """
        quantiles_cutoff = DEFAULT_CUTOFF if cutoff is None else cutoff
        percentiles_cutoff = 100 * np.array(quantiles_cutoff)
        percentiles = _get_percentiles(percentiles_cutoff)

        # The mask file does not depend on the image, so it is read once
        # instead of once per training image
        shared_mask = None
        if masking_function is None and mask_path is not None:
            shared_mask = nib.load(str(mask_path)).get_fdata() > 0

        percentiles_database = []
        for image_file_path in tqdm(images_paths):
            tensor, _ = read_image(image_file_path)
            data = tensor.numpy()
            mask = cls._get_training_mask(data, masking_function, shared_mask)
            percentile_values = np.percentile(data[mask], percentiles)
            percentiles_database.append(percentile_values)
        if not percentiles_database:
            message = 'At least one image path is needed to train landmarks'
            raise ValueError(message)
        # One row per training image, one column per percentile
        mapping = _get_average_mapping(np.vstack(percentiles_database))

        if output_path is not None:
            cls._save_landmarks(mapping, output_path)
        return mapping

    @staticmethod
    def _get_training_mask(
            data: np.ndarray,
            masking_function: Optional[Callable],
            shared_mask: Optional[np.ndarray],
            ) -> np.ndarray:
        """Return the mask selecting the voxels of ``data`` used to train."""
        if masking_function is not None:
            return masking_function(data)
        if shared_mask is not None:
            return shared_mask
        # Without any mask, all voxels are used
        return np.ones_like(data, dtype=bool)

    @staticmethod
    def _save_landmarks(mapping: np.ndarray, output_path: TypePath) -> None:
        """Write the landmarks to a ``.txt`` or ``.npy`` file.

        NOTE(review): files with any other extension are silently not
        written; consider raising an error instead.
        """
        output_path = Path(output_path).expanduser()
        extension = output_path.suffix
        if extension == '.txt':
            modality = 'image'
            text = f'{modality} {" ".join(map(str, mapping))}'
            output_path.write_text(text)
        elif extension == '.npy':
            np.save(output_path, mapping)
147
148
149
def _standardize_cutoff(cutoff: np.ndarray) -> np.ndarray:
    """Return a standardized copy of the quantiles cutoff.

    Both values are first clipped to ``[0, 1]``. Then the lower value is
    capped at ``0.09`` and the upper value is raised to at least ``0.91``.
    As a result, the corresponding percentiles fall strictly outside the
    deciles range ``[10, 90]``. They therefore never coincide with the
    quartiles or deciles added by ``_get_percentiles``.

    Args:
        cutoff: Lower and upper quantiles.

    Returns:
        New float array with the standardized lower and upper quantiles. The
        input is not modified.
    """
    # Copy as float: np.asarray would return, and then mutate, the caller's
    # array when it is already a float NumPy array
    cutoff = np.array(cutoff, dtype=float)
    cutoff[0] = min(max(0., cutoff[0]), 0.09)
    cutoff[1] = max(min(1., cutoff[1]), 0.91)
    return cutoff
161
162
163
def _get_average_mapping(percentiles_database: np.ndarray) -> np.ndarray:
    """Average the training landmarks after mapping them to ``STANDARD_RANGE``.

    For each image, its percentiles are linearly rescaled so that the first
    one maps to ``s1`` and the last one to ``s2``. The result is the mean
    over all images of these rescaled percentiles.

    Args:
        percentiles_database: Percentiles over which to perform the
            averaging, with one row per image and one column per percentile.

    Returns:
        Array with one averaged landmark per column of the database.
    """
    lowest = percentiles_database[:, 0]
    highest = percentiles_database[:, -1]
    s1, s2 = STANDARD_RANGE
    # NOTE(review): an image whose first and last percentiles are equal gives
    # an infinite slope, which nan_to_num turns into a huge finite value
    scale = np.nan_to_num((s2 - s1) / (highest - lowest))
    offset = np.mean(s1 - scale * lowest)
    weighted_sum = scale.dot(percentiles_database)
    return weighted_sum / len(percentiles_database) + offset
180
181
182
def _get_percentiles(percentiles_cutoff: Tuple[float, float]) -> np.ndarray:
    """Return the sorted, unique percentiles used as histogram landmarks.

    They are the two cutoff percentiles plus the quartiles (25, 50, 75) and
    the deciles (10, 20, ..., 90).

    Args:
        percentiles_cutoff: Lower and upper cutoff expressed as percentiles,
            i.e. quantiles multiplied by 100.
    """
    landmarks = set(percentiles_cutoff)
    landmarks.update(range(25, 100, 25))
    landmarks.update(range(10, 100, 10))
    return np.array(sorted(landmarks))
188
189
190
def normalize(
        tensor: torch.Tensor,
        landmarks: np.ndarray,
        mask: Optional[np.ndarray],
        cutoff: Optional[Tuple[float, float]] = None,
        epsilon: float = 1e-5,
        ) -> torch.Tensor:
    """Map the image intensities onto trained histogram landmarks.

    The image percentiles, computed on the voxels selected by ``mask``, are
    matched to ``landmarks`` by a piecewise linear function.

    Args:
        tensor: Image to normalize. It is converted with ``tensor.numpy()``.
        landmarks: Landmarks obtained with
            :py:meth:`HistogramStandardization.train`. At least 13 values are
            needed, as indices up to 12 are used.
        mask: Voxels used to compute the image percentiles. It is cast to
            boolean. If ``None``, all voxels are used.
        cutoff: Lower and upper quantiles. If ``None``, ``DEFAULT_CUTOFF``
            is used.
        epsilon: Consecutive image percentiles closer than this value produce
            a flat (zero-slope) segment.

    Returns:
        Float32 tensor with the same shape as ``tensor``.
    """
    cutoff_ = DEFAULT_CUTOFF if cutoff is None else cutoff
    mapping = np.asarray(landmarks)

    array = tensor.numpy()
    shape = array.shape
    data = array.reshape(-1).astype(np.float32)

    if mask is None:
        mask = np.ones_like(data, bool)
    # np.bool no longer exists in NumPy >= 1.24. The explicit boolean cast
    # also ensures that integer or float masks select voxels instead of
    # being used as fancy indices
    mask = np.asarray(mask, dtype=bool).reshape(-1)

    # With a standardized cutoff, the 13 sorted percentiles are
    # [pc1, 10, 20, 25, 30, 40, 50, 60, 70, 75, 80, 90, pc2]; indices 3 and 9
    # (the 25th and 75th percentiles) are skipped.
    # NOTE(review): train() does not standardize its cutoff, so landmarks
    # trained with a non-default cutoff may not match these indices; confirm
    range_to_use = [0, 1, 2, 4, 5, 6, 7, 8, 10, 11, 12]

    quantiles_cutoff = _standardize_cutoff(cutoff_)
    percentiles_cutoff = 100 * np.array(quantiles_cutoff)
    percentiles = _get_percentiles(percentiles_cutoff)
    percentile_values = np.percentile(data[mask], percentiles)

    # Apply linear histogram standardization
    range_mapping = mapping[range_to_use]
    range_perc = percentile_values[range_to_use]
    diff_mapping = np.diff(range_mapping)
    diff_perc = np.diff(range_perc)

    # Handling the case where two landmarks are the same
    # for a given input image. This usually happens when
    # image background is not removed from the image.
    diff_perc[diff_perc < epsilon] = np.inf

    affine_map = np.zeros([2, len(range_to_use) - 1])

    # Compute slopes of the linear models
    affine_map[0] = diff_mapping / diff_perc

    # Compute intercepts of the linear models
    affine_map[1] = range_mapping[:-1] - affine_map[0] * range_perc[:-1]

    # Each voxel gets the index of the linear segment its intensity falls in
    bin_id = np.digitize(data, range_perc[1:-1], right=False)
    lin_img = affine_map[0, bin_id]
    aff_img = affine_map[1, bin_id]
    new_img = lin_img * data + aff_img
    new_img = new_img.reshape(shape)
    new_img = new_img.astype(np.float32)
    return torch.from_numpy(new_img)
243
244
245
# Module-level aliases of the class method, so that landmarks can be trained
# without referencing the class
train = train_histogram = HistogramStandardization.train
246