|
1
|
|
|
""" |
|
2
|
|
|
Load unpaired data. |
|
3
|
|
|
Supported formats: h5 and Nifti. |
|
4
|
|
|
Image data can be labeled or unlabeled. |
|
5
|
|
|
""" |
|
6
|
|
|
import random |
|
7
|
|
|
from typing import List, Tuple, Union |
|
8
|
|
|
|
|
9
|
|
|
from deepreg.dataset.loader.interface import ( |
|
10
|
|
|
AbstractUnpairedDataLoader, |
|
11
|
|
|
GeneratorDataLoader, |
|
12
|
|
|
) |
|
13
|
|
|
from deepreg.dataset.util import check_difference_between_two_lists |
|
14
|
|
|
from deepreg.registry import REGISTRY |
|
15
|
|
|
|
|
16
|
|
|
|
|
17
|
|
|
@REGISTRY.register_data_loader(name="unpaired")
class UnpairedDataLoader(AbstractUnpairedDataLoader, GeneratorDataLoader):
    """
    Load unpaired data using given file loader. Handles both labeled
    and unlabeled cases.

    All images found under ``data_dir_paths`` form a single pool. Each sample
    is built from two entries of a shuffled index list over that pool: one is
    used as the moving image and the next one as the fixed image.

    The function sample_index_generator needs to be defined for the
    GeneratorDataLoader class.
    """

    def __init__(
        self,
        file_loader,
        data_dir_paths: List[str],
        labeled: bool,
        sample_label: str,
        seed: int,
        image_shape: Union[Tuple[int, ...], List[int]],
    ):
        """
        Load data which are unpaired, labeled or unlabeled.

        :param file_loader: file loader class (e.g. for h5 or Nifti data); it is
            called as ``file_loader(dir_paths=..., name=..., grouped=False)`` to
            build one loader for images and, if labeled, one for labels.
        :param data_dir_paths: paths of the directories storing data,
            the data are saved under two different sub-directories: images, labels
        :param labeled: whether the data is labeled.
        :param sample_label: passed unchanged to the parent class
            (presumably the label sampling strategy -- confirm in the interface).
        :param seed: seed for the random shuffle applied to the image indices
            before they are paired.
        :param image_shape: (width, height, depth)
        :raises AssertionError: if data_dir_paths is not a list.
        """
        super().__init__(
            image_shape=image_shape,
            labeled=labeled,
            sample_label=sample_label,
            seed=seed,
        )
        assert isinstance(
            data_dir_paths, list
        ), f"data_dir_paths must be list of strings, got {data_dir_paths}"

        # Moving and fixed images are drawn from the same pool, so a single
        # loader instance is shared by both roles.
        loader_image = file_loader(
            dir_paths=data_dir_paths, name="images", grouped=False
        )
        self.loader_moving_image = loader_image
        self.loader_fixed_image = loader_image
        if self.labeled:
            loader_label = file_loader(
                dir_paths=data_dir_paths, name="labels", grouped=False
            )
            self.loader_moving_label = loader_label
            self.loader_fixed_label = loader_label
        self.validate_data_files()

        self.num_images = self.loader_moving_image.get_num_images()
        # Each sample consumes two images; with an odd image count one
        # (shuffled) image is left out of the pairing.
        self._num_samples = self.num_images // 2

    def validate_data_files(self):
        """
        Verify all loader have the same files.
        Since fixed and moving loaders come from the same file_loader,
        there is no need to check both (avoid duplicate).

        Only runs a check when the data is labeled: the image ids and the
        label ids must match.
        """
        if self.labeled:
            image_ids = self.loader_moving_image.get_data_ids()
            label_ids = self.loader_moving_label.get_data_ids()
            check_difference_between_two_lists(
                list1=image_ids,
                list2=label_ids,
                name="images and labels in unpaired loader",
            )

    def sample_index_generator(self):
        """
        Generates sample indexes to load data using the
        GeneratorDataLoader class.

        The image indices are shuffled with ``random.Random(self.seed)`` and
        then consumed as consecutive pairs.

        :return: generator yielding
            ``(moving_index, fixed_index, [moving_index, fixed_index])``.
        """
        image_indices = list(range(self.num_images))
        # A fresh RNG is created on every call. With a fixed (non-None) seed,
        # every call therefore yields the same pairing.
        random.Random(self.seed).shuffle(image_indices)
        for sample_index in range(self.num_samples):
            moving_index = image_indices[2 * sample_index]
            fixed_index = image_indices[2 * sample_index + 1]
            yield moving_index, fixed_index, [moving_index, fixed_index]

    def close(self):
        """
        Close the moving files opened by the file_loaders.

        The fixed loaders are the same objects as the moving loaders (see
        __init__), so closing the moving ones is sufficient.
        """
        self.loader_moving_image.close()
        if self.labeled:
            self.loader_moving_label.close()
|
107
|
|
|
|