Completed
Push — main ( 0c57ec...f6b5bf )
by Yunguan
18s queued 13s
created

unpaired_mr_brain.demo_data   A

Complexity

Total Complexity 0

Size/Duplication

Total Lines 196
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 0
eloc 139
dl 0
loc 196
rs 10
c 0
b 0
f 0
# Standard library
import os
import shutil
import tarfile
import zipfile
from os import listdir, makedirs, remove
from os.path import exists, join

# Third-party
import nibabel as nib
import numpy as np
from tensorflow.keras.utils import get_file
##############
# Parameters #
##############

data_splits = ["train", "test"]
num_labels = 3

# Main project directory (the script is expected to be launched from the
# repository root; everything below is resolved relative to it).
main_path = os.getcwd()

# Demo directory: run the rest of the script from inside the demo folder.
project_dir = "demos/unpaired_mr_brain"
os.chdir(join(main_path, project_dir))

# Data storage directory: start from a clean slate on every run so stale
# files from a previous download cannot leak into the new dataset.
data_folder_name = "dataset"
path_to_data_folder = join(main_path, project_dir, data_folder_name)
if exists(path_to_data_folder):
    shutil.rmtree(path_to_data_folder)
makedirs(path_to_data_folder)

# Pretrained model storage directory
model_folder_name = join(project_dir, data_folder_name, "pretrained")
path_to_model_folder = join(main_path, model_folder_name)
#################
# Download data #
#################
# Data
FILENAME = "data_mr_brain"
ORIGIN = "https://github.com/acasamitjana/Data/raw/master/L2R_Task4_HippocampusMRI.tar"
TAR_FILE = FILENAME + ".tar"

# Download the archive into the current (demo) directory.
get_file(os.path.abspath(TAR_FILE), ORIGIN)

if not exists(path_to_data_folder):
    makedirs(path_to_data_folder)

# NOTE(review): extractall trusts member paths inside a downloaded archive;
# on Python >= 3.12, pass filter="data" to harden against path traversal.
with tarfile.open(join(main_path, project_dir, TAR_FILE), "r") as tar_ref:
    tar_ref.extractall(data_folder_name)

remove(TAR_FILE)
print("Files unzipped successfully")

# Model
PRETRAINED_MODEL = "unpaired_mr_brain.zip"
URL_MODEL = (
    "https://github.com/DeepRegNet/deepreg-model-zoo/raw/master/" + PRETRAINED_MODEL
)

get_file(os.path.abspath(PRETRAINED_MODEL), URL_MODEL)

if not exists(path_to_model_folder):
    makedirs(path_to_model_folder)

with zipfile.ZipFile(join(main_path, project_dir, PRETRAINED_MODEL), "r") as zip_ref:
    zip_ref.extractall(path_to_model_folder)

remove(PRETRAINED_MODEL)
print("The file ", PRETRAINED_MODEL, " has successfully been downloaded!")
##################
# Create dataset #
##################
path_to_init_img = join(path_to_data_folder, "Training", "img")
path_to_init_label = join(path_to_data_folder, "Training", "label")

path_to_train = join(path_to_data_folder, "train")
path_to_test = join(path_to_data_folder, "test")

# Subjects with an id below this threshold go to train, the rest to test.
FIRST_TEST_SUBJECT = 311

# Recreate empty train/test trees (images/labels/masks) on every run.
# BUGFIX: the test directory was previously only (re)built inside the
# "does not exist" branch, so when it already existed nothing was cleaned
# and stale files from an earlier run survived.
for split_path in (path_to_train, path_to_test):
    if exists(split_path):
        shutil.rmtree(split_path)
    makedirs(join(split_path, "images"))
    makedirs(join(split_path, "labels"))
    makedirs(join(split_path, "masks"))

# Split images by the subject id encoded in the filename ("<prefix>_<id>.<ext>").
for f in listdir(path_to_init_img):
    num_subject = int(f.split("_")[1].split(".")[0])
    if num_subject < FIRST_TEST_SUBJECT:
        shutil.copy(join(path_to_init_img, f), join(path_to_train, "images"))
    else:
        shutil.copy(join(path_to_init_img, f), join(path_to_test, "images"))

# Split labels with the same rule so images and labels stay paired.
for f in listdir(path_to_init_label):
    num_subject = int(f.split("_")[1].split(".")[0])
    if num_subject < FIRST_TEST_SUBJECT:
        shutil.copy(join(path_to_init_label, f), join(path_to_train, "labels"))
    else:
        shutil.copy(join(path_to_init_label, f), join(path_to_test, "labels"))

shutil.rmtree(join(path_to_data_folder, "Training"))
print("Files successfully copied to " + path_to_train + " and " + path_to_test)
#################
# Preprocessing #
#################
# For every image: build a rectangular field-of-view mask by scanning each
# axis from both ends for the first voxel line that differs from the border,
# save the mask, then zero out everything outside it and rescale intensities
# into [0, 255] when needed.
for split in data_splits:
    img_dir = join(path_to_data_folder, split, "images")
    for fname in listdir(img_dir):
        nii = nib.load(join(img_dir, fname))
        volume = np.asarray(nii.dataobj)
        fov_mask = np.zeros_like(volume)
        mid = [int(s / 2) for s in volume.shape]
        bounds = []
        perm = [2, 0, 1]
        for dim_idx, dim_len in enumerate(volume.shape):
            # Rotate the permutation so the current axis comes first.
            perm = [np.mod(p + 1, 3) for p in perm]
            rolled = np.transpose(volume, axes=perm)

            # Scan forward from the first slice until the central line of
            # voxels changes with respect to the border slice.
            lo = 0
            ref_lo = rolled[lo, mid[dim_idx]]
            while True:
                lo += 1
                if np.sum((rolled[lo, mid[dim_idx]] - ref_lo) ** 2) > 0:
                    break

            # Scan backward from the last slice; keep the last index that
            # still matches the border, then step back inside the FOV.
            # NOTE(review): the threshold here is > 1 while the forward scan
            # uses > 0 — looks asymmetric, confirm it is intentional.
            # NOTE(review): a perfectly uniform volume would run these loops
            # off the array end (IndexError) — assumed not to occur in this
            # dataset.
            hi = dim_len - 1
            ref_hi = rolled[hi, mid[dim_idx]]
            while True:
                hi -= 1
                if np.sum((rolled[hi, mid[dim_idx]] - ref_hi) ** 2) > 1:
                    hi += 1
                    break

            bounds.append((lo, hi))

        # Mark the detected box as the valid field of view.
        fov_mask[
            bounds[0][0] : bounds[0][1],
            bounds[1][0] : bounds[1][1],
            bounds[2][0] : bounds[2][1],
        ] = 1
        nib.save(
            nib.Nifti1Image(fov_mask, affine=nii.affine),
            join(path_to_data_folder, split, "masks", fname),
        )

        # Zero everything outside the FOV, then rescale to [0, 255] only
        # when the dynamic range actually exceeds 255.
        volume = volume * fov_mask
        vmax = np.max(volume)
        vmin = np.min(volume)
        if vmax > 255:
            volume = (volume - vmin) / (vmax - vmin) * 255.0
        nib.save(nib.Nifti1Image(volume, affine=nii.affine), join(img_dir, fname))

print("Images have been correctly normalized between [0, 255]")
# One-hot encode the label maps in place.
# NOTE(review): range(1, num_labels) gives one channel per foreground label
# 1 .. num_labels-1 — the background label 0 gets no channel. Presumably
# intentional (background excluded); confirm against the training config.
for ds in data_splits:
    labels_dir = join(path_to_data_folder, ds, "labels")
    for fname in listdir(labels_dir):
        proxy = nib.load(join(labels_dir, fname))
        labels = np.asarray(proxy.dataobj)
        # One binary channel per foreground label value, stacked last.
        channels = [
            (labels == lab).astype(labels.dtype) for lab in range(1, num_labels)
        ]
        one_hot = np.stack(channels, axis=-1)
        nib.save(nib.Nifti1Image(one_hot, proxy.affine), join(labels_dir, fname))

print(
    "Labels have been one-hot encoded using a total of " + str(num_labels) + " labels."
)