| Metric | Value |
| --- | --- |
| Conditions | 11 |
| Total Lines | 101 |
| Code Lines | 43 |
| Lines | 0 |
| Ratio | 0 % |
| Changes | 0 |
Small methods make your code easier to understand, especially when combined with a good name. Moreover, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a good sign that the commented part should be extracted into a new method, with the comment serving as a starting point for naming it.
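A minimal sketch of this comment-to-method refactoring (the order-report functions below are hypothetical, purely for illustration):

```python
# Before: a comment labels a step inside the method.
def total_before(orders):
    # filter out cancelled orders
    valid = [o for o in orders if o['status'] != 'cancelled']
    return sum(o['total'] for o in valid)


# After: the commented block is extracted, and the comment
# becomes the starting point for the new method's name.
def _without_cancelled(orders):
    return [o for o in orders if o['status'] != 'cancelled']


def total_after(orders):
    return sum(o['total'] for o in _without_cancelled(orders))
```

Both versions compute the same total, but the second one reads as a sentence and no longer needs the comment.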
Commonly applied refactorings include Extract Method. If many parameters or temporary variables are present, Replace Temp with Query and Introduce Parameter Object are typical preparatory steps.
Complex methods like torchio.transforms.preprocessing.intensity.histogram_standardization.HistogramStandardization.train() often do a lot of different things. To break such a method and its enclosing class down, we need to identify a cohesive component within the class. A common approach to finding such a component is to look for fields or methods that share the same prefixes or suffixes.

Once you have determined which fields belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
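A minimal sketch of Extract Class, assuming a hypothetical Customer whose billing_* fields share a prefix and therefore form a cohesive component:

```python
class Address:
    """The extracted component: the former billing_* fields."""

    def __init__(self, street, city):
        self.street = street
        self.city = city

    def label(self):
        return f'{self.street}, {self.city}'


class Customer:
    # Before the refactoring, Customer held billing_street and
    # billing_city directly; the shared prefix suggested Address.
    def __init__(self, name, billing_street, billing_city):
        self.name = name
        self.billing = Address(billing_street, billing_city)
```

Behavior that operated on the prefixed fields (like building an address label) now lives on the extracted class.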
```python
from pathlib import Path

# ...

@classmethod
def train(
        cls,
        images_paths: Sequence[TypePath],
        cutoff: Optional[Tuple[float, float]] = None,
        mask_path: Optional[Union[Sequence[TypePath], TypePath]] = None,
        masking_function: Optional[Callable] = None,
        output_path: Optional[TypePath] = None,
        ) -> np.ndarray:
    """Extract average histogram landmarks from images used for training.

    Args:
        images_paths: List of image paths used to train.
        cutoff: Optional minimum and maximum quantile values,
            respectively, that are used to select a range of intensity of
            interest. Equivalent to :math:`pc_1` and :math:`pc_2` in
            `Nyúl and Udupa's paper <http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.204.102&rep=rep1&type=pdf>`_.
        mask_path: Path (or list of paths) to a binary image that will be
            used to select the voxels used to compute the stats during
            histogram training. If ``None``, all voxels in the image will
            be used.
        masking_function: Function used to extract voxels used for
            histogram training.
        output_path: Optional file path with extension ``.txt`` or
            ``.npy``, where the landmarks will be saved.

    Example:

        >>> import torch
        >>> import numpy as np
        >>> from pathlib import Path
        >>> from torchio.transforms import HistogramStandardization
        >>>
        >>> t1_paths = ['subject_a_t1.nii', 'subject_b_t1.nii.gz']
        >>> t2_paths = ['subject_a_t2.nii', 'subject_b_t2.nii.gz']
        >>>
        >>> t1_landmarks_path = Path('t1_landmarks.npy')
        >>> t2_landmarks_path = Path('t2_landmarks.npy')
        >>>
        >>> t1_landmarks = (
        ...     t1_landmarks_path
        ...     if t1_landmarks_path.is_file()
        ...     else HistogramStandardization.train(t1_paths)
        ... )
        >>> torch.save(t1_landmarks, t1_landmarks_path)
        >>>
        >>> t2_landmarks = (
        ...     t2_landmarks_path
        ...     if t2_landmarks_path.is_file()
        ...     else HistogramStandardization.train(t2_paths)
        ... )
        >>> torch.save(t2_landmarks, t2_landmarks_path)
        >>>
        >>> landmarks_dict = {
        ...     't1': t1_landmarks,
        ...     't2': t2_landmarks,
        ... }
        >>>
        >>> transform = HistogramStandardization(landmarks_dict)
    """  # noqa: E501
    is_masks_list = isinstance(mask_path, Sequence)
    if is_masks_list and len(mask_path) != len(images_paths):
        message = (
            f'Different number of images ({len(images_paths)})'
            f' and mask ({len(mask_path)}) paths found'
        )
        raise ValueError(message)
    quantiles_cutoff = DEFAULT_CUTOFF if cutoff is None else cutoff
    percentiles_cutoff = 100 * np.array(quantiles_cutoff)
    percentiles_database = []
    percentiles = _get_percentiles(percentiles_cutoff)
    for i, image_file_path in enumerate(tqdm(images_paths)):
        tensor, _ = read_image(image_file_path)
        if masking_function is not None:
            mask = masking_function(tensor)
        else:
            if mask_path is None:
                mask = np.ones_like(tensor, dtype=bool)
            else:
                if is_masks_list:
                    path = mask_path[i]
                else:
                    path = mask_path
                mask, _ = read_image(path)
                mask = mask.numpy() > 0
        array = tensor.numpy()
        percentile_values = np.percentile(array[mask], percentiles)
        percentiles_database.append(percentile_values)
    percentiles_database = np.vstack(percentiles_database)
    mapping = _get_average_mapping(percentiles_database)

    if output_path is not None:
        output_path = Path(output_path).expanduser()
        extension = output_path.suffix
        if extension == '.txt':
            modality = 'image'
            text = f'{modality} {" ".join(map(str, mapping))}'
            output_path.write_text(text)
        elif extension == '.npy':
            np.save(output_path, mapping)
    return mapping
```