Passed
Push — master (2f1e12...1916ef) by Fernando, created 02:30

torchio.utils — Rating: B

Complexity

Total Complexity 46

Size/Duplication

Total Lines 283
Duplicated Lines 0%

Importance

Changes 0
Metric   Value
wmc      46
eloc     187
dl       0
loc      283
rs       8.72
c        0
b        0
f        0

15 Functions

Rating   Name   Duplication   Size   Complexity  
A to_tuple() 0 19 2
B create_dummy_dataset() 0 59 7
A get_stem() 0 11 2
A get_first_item() 0 2 1
A get_major_sitk_version() 0 6 2
A get_torchio_cache_dir() 0 2 1
A compress() 0 4 3
B guess_type() 0 21 6
A get_subjects_from_batch() 0 24 4
A check_sequence() 0 6 2
A get_subclasses() 0 4 1
A get_batch_images_and_size() 0 19 4
A history_collate() 0 14 4
A apply_transform_to_file() 0 14 3
A add_images_from_batch() 0 31 4

How to fix: Complexity

Complex modules like torchio.utils often do many different things. To break such a module down, we need to identify a cohesive component within it. A common way to find such a component is to look for functions that share the same prefixes or suffixes, or that operate on the same kind of data.

Once you have determined which functions belong together, you can apply the Extract Class refactoring, or simply move the group into its own module. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
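For example, in the listing below the batch-oriented helpers (history_collate(), get_batch_images_and_size(), get_subjects_from_batch() and add_images_from_batch()) all revolve around a batch dictionary collated by a torch.utils.data.DataLoader. A minimal sketch of what such an extraction could look like here (the SubjectsBatch name and the batch.py module are hypothetical, not part of torchio):

# batch.py: hypothetical module split out of torchio.utils (sketch only).
from typing import Dict, List, Tuple

# The helpers would move here; for brevity this sketch reuses the existing ones.
from torchio.utils import get_batch_images_and_size, get_subjects_from_batch


class SubjectsBatch:
    """Wraps a batch collated from a torchio.SubjectsDataset by a DataLoader."""

    def __init__(self, batch: Dict):
        # The shared ``batch`` argument of the helpers becomes a field.
        self.batch = batch

    def images_and_size(self) -> Tuple[List[str], int]:
        return get_batch_images_and_size(self.batch)

    def to_subjects(self) -> List:
        return get_subjects_from_batch(self.batch)

Moving this group (together with add_images_from_batch() and history_collate()) out of torchio.utils would lower the module's reported complexity without changing what each helper does.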

import ast
import gzip
import shutil
import tempfile
from pathlib import Path
from typing import Union, Iterable, Tuple, Any, Optional, List, Sequence, Dict

import torch
from torch.utils.data._utils.collate import default_collate
import numpy as np
import nibabel as nib
import SimpleITK as sitk
from tqdm import trange

from . import constants
from .typing import TypeNumber, TypePath


def to_tuple(
        value: Union[TypeNumber, Iterable[TypeNumber]],
        length: int = 1,
        ) -> Tuple[TypeNumber, ...]:
    """
    to_tuple(1, length=1) -> (1,)
    to_tuple(1, length=3) -> (1, 1, 1)

    If value is an iterable, length is ignored and tuple(value) is returned
    to_tuple((1,), length=1) -> (1,)
    to_tuple((1, 2), length=1) -> (1, 2)
    to_tuple([1, 2], length=3) -> (1, 2)
    """
    try:
        iter(value)
        value = tuple(value)
    except TypeError:
        value = length * (value,)
    return value


def get_stem(
        path: Union[TypePath, List[TypePath]]
        ) -> Union[str, List[str]]:
    """
    '/home/user/image.nii.gz' -> 'image'
    """
    def _get_stem(path_string):
        return Path(path_string).name.split('.')[0]
    if isinstance(path, (str, Path)):
        return _get_stem(path)
    return [_get_stem(p) for p in path]


def create_dummy_dataset(
        num_images: int,
        size_range: Tuple[int, int],
        directory: Optional[TypePath] = None,
        suffix: str = '.nii.gz',
        force: bool = False,
        verbose: bool = False,
        ):
    from .data import ScalarImage, LabelMap, Subject
    output_dir = tempfile.gettempdir() if directory is None else directory
    output_dir = Path(output_dir)
    images_dir = output_dir / 'dummy_images'
    labels_dir = output_dir / 'dummy_labels'

    if force:
        shutil.rmtree(images_dir)
        shutil.rmtree(labels_dir)

    subjects: List[Subject] = []
    if images_dir.is_dir():
        for i in trange(num_images):
            image_path = images_dir / f'image_{i}{suffix}'
            label_path = labels_dir / f'label_{i}{suffix}'
            subject = Subject(
                one_modality=ScalarImage(image_path),
                segmentation=LabelMap(label_path),
            )
            subjects.append(subject)
    else:
        images_dir.mkdir(exist_ok=True, parents=True)
        labels_dir.mkdir(exist_ok=True, parents=True)
        if verbose:
            print('Creating dummy dataset...')  # noqa: T001
            iterable = trange(num_images)
        else:
            iterable = range(num_images)
        for i in iterable:
            shape = np.random.randint(*size_range, size=3)
            affine = np.eye(4)
            image = np.random.rand(*shape)
            label = np.ones_like(image)
            label[image < 0.33] = 0
            label[image > 0.66] = 2
            image *= 255

            image_path = images_dir / f'image_{i}{suffix}'
            nii = nib.Nifti1Image(image.astype(np.uint8), affine)
            nii.to_filename(str(image_path))

            label_path = labels_dir / f'label_{i}{suffix}'
            nii = nib.Nifti1Image(label.astype(np.uint8), affine)
            nii.to_filename(str(label_path))

            subject = Subject(
                one_modality=ScalarImage(image_path),
                segmentation=LabelMap(label_path),
            )
            subjects.append(subject)
    return subjects


def apply_transform_to_file(
        input_path: TypePath,
        transform,  # : Transform seems to create a circular import
        output_path: TypePath,
        class_: str = 'ScalarImage',
        verbose: bool = False,
        ):
    from . import data
    image = getattr(data, class_)(input_path)
    subject = data.Subject(image=image)
    transformed = transform(subject)
    transformed.image.save(output_path)
    if verbose and transformed.history:
        print('Applied transform:', transformed.history[0])  # noqa: T001


def guess_type(string: str) -> Any:
    # Adapted from
    # https://www.reddit.com/r/learnpython/comments/4599hl/module_to_guess_type_from_a_string/czw3f5s
    string = string.replace(' ', '')
    try:
        value = ast.literal_eval(string)
    except ValueError:
        result_type = str
    else:
        result_type = type(value)
    if result_type in (list, tuple):
        string = string[1:-1]  # remove brackets
        split = string.split(',')
        list_result = [guess_type(n) for n in split]
        value = tuple(list_result) if result_type is tuple else list_result
        return value
    try:
        value = result_type(string)
    except TypeError:
        value = None
    return value


def get_torchio_cache_dir() -> Path:
    return Path('~/.cache/torchio').expanduser()


def compress(input_path: TypePath, output_path: TypePath) -> None:
    with open(input_path, 'rb') as f_in:
        with gzip.open(output_path, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)


def check_sequence(sequence: Sequence, name: str) -> None:
    try:
        iter(sequence)
    except TypeError:
        message = f'"{name}" must be a sequence, not {type(sequence)}'
        raise TypeError(message)


def get_major_sitk_version() -> int:
    # This attribute was added in version 2
    # https://github.com/SimpleITK/SimpleITK/pull/1171
    version = getattr(sitk, '__version__', None)
    major_version = 1 if version is None else 2
    return major_version


def history_collate(batch: Sequence, collate_transforms=True) -> Dict:
    attr = constants.HISTORY if collate_transforms else 'applied_transforms'
    # Adapted from
    # https://github.com/romainVala/torchQC/blob/master/segmentation/collate_functions.py
    from .data import Subject
    first_element = batch[0]
    if isinstance(first_element, Subject):
        dictionary = {
            key: default_collate([d[key] for d in batch])
            for key in first_element
        }
        if hasattr(first_element, attr):
            dictionary.update({attr: [getattr(d, attr) for d in batch]})
        return dictionary


def get_subclasses(target_class: type) -> List[type]:
    subclasses = target_class.__subclasses__()
    subclasses += sum((get_subclasses(cls) for cls in subclasses), [])
    # Analysis note: 'The variable cls does not seem to be defined'; this is a
    # false positive, since cls is bound by the generator expression above.
    return subclasses


def get_first_item(data_loader: torch.utils.data.DataLoader):
    return next(iter(data_loader))


def get_batch_images_and_size(batch: Dict) -> Tuple[List[str], int]:
    """Get the number of images and the image names in a batch.

    Args:
        batch: Dictionary generated by a :class:`torch.utils.data.DataLoader`
        extracting data from a :class:`torchio.SubjectsDataset`.

    Raises:
        RuntimeError: If the batch does not contain any dictionaries that
        seem to represent a :class:`torchio.Image`.
    """
    names = []
    for image_name, image_dict in batch.items():
        if constants.DATA in image_dict:  # assume it is a TorchIO Image
            size = len(image_dict[constants.DATA])
            names.append(image_name)
    if not names:
        raise RuntimeError('The batch does not seem to contain any images')
    # Analysis note: 'The variable size does not seem to be defined for all
    # execution paths'; if no image is found, the RuntimeError above is raised
    # before this return is reached.
    return names, size


def get_subjects_from_batch(batch: Dict) -> List:
    """Get list of subjects from collated batch.

    Args:
        batch: Dictionary generated by a :class:`torch.utils.data.DataLoader`
        extracting data from a :class:`torchio.SubjectsDataset`.
    """
    from .data import ScalarImage, LabelMap, Subject
    subjects = []
    image_names, batch_size = get_batch_images_and_size(batch)
    for i in range(batch_size):
        subject_dict = {}
        for image_name in image_names:
            image_dict = batch[image_name]
            data = image_dict[constants.DATA][i]
            affine = image_dict[constants.AFFINE][i]
            path = Path(image_dict[constants.PATH][i])
            # Index the collated sequence of types; comparing the whole
            # sequence with constants.LABEL would never match.
            is_label = image_dict[constants.TYPE][i] == constants.LABEL
            klass = LabelMap if is_label else ScalarImage
            image = klass(tensor=data, affine=affine, filename=path.name)
            subject_dict[image_name] = image
        subject = Subject(subject_dict)
        subjects.append(subject)
    return subjects


def add_images_from_batch(
        subjects: List,
        tensor: torch.Tensor,
        class_=None,
        name='prediction',
        ) -> None:
    """Add images to subjects in a list, typically from a network prediction.

    The spatial metadata (affine matrices) will be extracted from one of the
    images of each subject.

    Args:
        subjects: List of instances of :class:`torchio.Subject` to which images
            will be added.
        tensor: PyTorch tensor of shape :math:`(B, C, W, H, D)`, where
            :math:`B` is the batch size.
        class_: Class used to instantiate the images,
            e.g., :class:`torchio.LabelMap`.
            If ``None``, :class:`torchio.ScalarImage` will be used.
        name: Name of the images added to the subjects.
    """
    if class_ is None:
        from . import ScalarImage
        class_ = ScalarImage
    for subject, data in zip(subjects, tensor):
        one_image = subject.get_first_image()
        kwargs = {'tensor': data, 'affine': one_image.affine}
        if 'filename' in one_image:
            kwargs['filename'] = one_image['filename']
        image = class_(**kwargs)
        subject.add_image(image, name)
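
The batch helpers above are meant to be used together around a torch.utils.data.DataLoader. A short usage sketch follows, assuming the DataLoader workflow described in the docstrings; the toy subjects, tensor shapes and variable names are made up for illustration:

import torch
from torch.utils.data import DataLoader

import torchio as tio
from torchio.utils import (
    add_images_from_batch,
    get_first_item,
    get_subjects_from_batch,
)

# Hypothetical toy data: four subjects, each with one random image.
subjects = [
    tio.Subject(t1=tio.ScalarImage(tensor=torch.rand(1, 16, 16, 16)))
    for _ in range(4)
]
dataset = tio.SubjectsDataset(subjects)
loader = DataLoader(dataset, batch_size=2)

batch = get_first_item(loader)            # first collated batch (a dict)
decoded = get_subjects_from_batch(batch)  # back to a list of Subjects
fake_output = torch.rand(len(decoded), 1, 16, 16, 16)  # stands in for a network prediction
add_images_from_batch(decoded, fake_output)  # each subject gains a 'prediction' image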