from __future__ import annotations

import ast
import enum
import multiprocessing
import shutil
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Literal
from typing import Union

import numpy as np
import SimpleITK as sitk
from tqdm.auto import tqdm
from tqdm.contrib.concurrent import process_map
from tqdm.contrib.concurrent import thread_map

from ..data.dataset import SubjectsDataset
from ..data.image import ScalarImage
from ..data.subject import Subject
from ..external.imports import get_pandas
from ..types import TypePath

if TYPE_CHECKING:
    import pandas as pd


TypeSplit = Union[
    Literal['train'],
    Literal['valid'],
    Literal['validation'],
]

TypeParallelism = Literal['thread', 'process', None]


class MetadataIndexColumn(str, enum.Enum):
    SUBJECT_ID = 'subject_id'
    SCAN_ID = 'scan_id'
    RECONSTRUCTION_ID = 'reconstruction_id'


class CtRate(SubjectsDataset):
    """CT-RATE dataset.

    This class provides access to
    `CT-RATE <https://huggingface.co/datasets/ibrahimhamamci/CT-RATE>`_,
    which contains chest CT scans with associated radiology reports and
    abnormality labels.

    The dataset must have been downloaded previously.

    Args:
        root: Root directory where the dataset has been downloaded.
        split: Dataset split to use, either ``'train'`` or ``'validation'``.
        token: Hugging Face token for accessing gated repositories.
            Alternatively, log in with ``huggingface-cli login`` to cache the
            token.
        num_subjects: Optional limit on the number of subjects to load
            (useful for testing). If ``None``, all subjects in the split are
            loaded.
        report_key: Key under which the radiology report is stored in each
            image's metadata.
        sizes: List of image sizes (in pixels) to include. Default:
            ``[512, 768, 1024]``.
        **kwargs: Additional arguments for :class:`SubjectsDataset`.

    Examples:
        >>> dataset = CtRate('/path/to/data', split='train')
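        >>> # Indexing yields a Subject whose image keys follow the pattern
        >>> # 'scan_<scan_id>_reconstruction_<reconstruction_id>'
        >>> subject = dataset[0]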
    """

    _REPO_ID = 'ibrahimhamamci/CT-RATE'
    _FILENAME_KEY = 'VolumeName'
    _SIZES = [512, 768, 1024]
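    # Abnormality classes; these correspond to the label columns in the
    # predicted-labels CSVs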
    ABNORMALITIES = [
        'Medical material',
        'Arterial wall calcification',
        'Cardiomegaly',
        'Pericardial effusion',
        'Coronary artery wall calcification',
        'Hiatal hernia',
        'Lymphadenopathy',
        'Emphysema',
        'Atelectasis',
        'Lung nodule',
        'Lung opacity',
        'Pulmonary fibrotic sequela',
        'Pleural effusion',
        'Mosaic attenuation pattern',
        'Peribronchial thickening',
        'Consolidation',
        'Bronchiectasis',
        'Interlobular septal thickening',
    ]

    def __init__(
        self,
        root: TypePath,
        split: TypeSplit = 'train',
        *,
        token: str | None = None,
        num_subjects: int | None = None,
        report_key: str = 'report',
        sizes: list[int] | None = None,
        **kwargs,
    ):
        self._root_dir = Path(root)
        self._token = token
        self._num_subjects = num_subjects
        self._report_key = report_key
        self._sizes = self._SIZES if sizes is None else sizes

        self._split = self._parse_split(split)
        self.metadata = self._get_metadata()
        subjects_list = self._get_subjects_list(self.metadata)
        super().__init__(subjects_list, **kwargs)

    @staticmethod
    def _parse_split(split: str) -> str:
        """Normalize the split name.

        Converts ``'validation'`` to ``'valid'`` and validates that the split
        name is one of the allowed values.

        Args:
            split: The split name to parse (``'train'``, ``'valid'``, or
                ``'validation'``).

        Returns:
            str: Normalized split name (``'train'`` or ``'valid'``).

        Raises:
            ValueError: If the split name is not one of the allowed values.
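
        Example:
            >>> CtRate._parse_split('validation')
            'valid'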
        """
        if split in ('valid', 'validation'):
            return 'valid'
        if split != 'train':
            raise ValueError(f"Invalid split '{split}'. Use 'train' or 'valid'")
        return split

    def _get_csv(
        self,
        dirname: str,
        filename: str,
    ) -> pd.DataFrame:
        """Load a CSV file from the specified directory within the dataset.

        Args:
            dirname: Directory name within ``dataset/`` where the CSV is
                located.
            filename: Name of the CSV file to load.
        """
        subfolder = Path(f'dataset/{dirname}')
        path = Path(self._root_dir, subfolder, filename)
        pd = get_pandas()
        table = pd.read_csv(path)
        return table

    def _get_csv_prefix(self, expand_validation: bool = True) -> str:
        """Get the prefix for CSV filenames based on the current split.

        For the validation split, this returns either ``'valid'`` or
        ``'validation'`` depending on ``expand_validation``, because the CSV
        filenames are not consistent across dataset directories.

        Args:
            expand_validation: If ``True`` and the split is ``'valid'``,
                return ``'validation'``. Otherwise, return the split name
                as is.
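
        Example:
            With ``split='valid'``, the metadata CSV is named
            ``validation_metadata.csv``, but the predicted-labels CSV is
            named ``valid_predicted_labels.csv``.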
        """
        if expand_validation and self._split == 'valid':
            prefix = 'validation'
        else:
            prefix = self._split
        return prefix

    def _get_metadata(self) -> pd.DataFrame:
        """Load and process the dataset metadata.

        Loads metadata from the appropriate CSV file, filters images by size,
        extracts subject, scan, and reconstruction IDs from filenames, and
        merges in reports and abnormality labels.
        """
        dirname = 'metadata'
        prefix = self._get_csv_prefix()
        filename = f'{prefix}_metadata.csv'
        metadata = self._get_csv(dirname, filename)

        # Exclude images whose in-plane size is not in self._sizes; copy so
        # that assigning new columns below does not modify a slice view
        rows_int = metadata['Rows'].astype(int)
        metadata = metadata[rows_int.isin(self._sizes)].copy()

        index_columns = [
            MetadataIndexColumn.SUBJECT_ID.value,
            MetadataIndexColumn.SCAN_ID.value,
            MetadataIndexColumn.RECONSTRUCTION_ID.value,
        ]
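        # Filenames look like 'train_2_a_1.nii.gz'; the captured groups are
        # the subject ID ('2'), the scan ID ('a'), and the reconstruction
        # ID ('1')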
        pattern = r'\w+_(\d+)_(\w+)_(\d+)\.nii\.gz'
        metadata[index_columns] = metadata[self._FILENAME_KEY].str.extract(pattern)

        if self._num_subjects is not None:
            metadata = self._keep_n_subjects(metadata, self._num_subjects)

        # Add reports and abnormality labels to the metadata, keeping only
        # the rows for the images in the metadata table
        metadata = self._merge(metadata, self._get_reports())
        metadata = self._merge(metadata, self._get_labels())

        metadata.set_index(index_columns, inplace=True)
        return metadata

    def _merge(self, base_df: pd.DataFrame, new_df: pd.DataFrame) -> pd.DataFrame:
        """Merge a new dataframe into the base dataframe on the filename key.

        This method performs a left join between ``base_df`` and ``new_df``
        using the volume filename as the join key, ensuring that all records
        from ``base_df`` are preserved while matching data from ``new_df`` is
        added.

        Args:
            base_df: The primary dataframe to merge into.
            new_df: The dataframe containing additional data to be merged.

        Returns:
            pd.DataFrame: The merged dataframe with all rows from ``base_df``
            and matching columns from ``new_df``.
        """
        pd = get_pandas()
        return pd.merge(
            base_df,
            new_df,
            on=self._FILENAME_KEY,
            how='left',
        )

    def _keep_n_subjects(self, metadata: pd.DataFrame, n: int) -> pd.DataFrame:
        """Limit the metadata to the first ``n`` subjects.

        Args:
            metadata: The complete metadata dataframe.
            n: Maximum number of subjects to keep.
        """
        unique_subjects = metadata['subject_id'].unique()
        selected_subjects = unique_subjects[:n]
        return metadata[metadata['subject_id'].isin(selected_subjects)]

    def _get_reports(self) -> pd.DataFrame:
        """Load the radiology reports associated with the CT scans.

        Retrieves the CSV file containing radiology reports for the current
        split (train or validation).
        """
        dirname = 'radiology_text_reports'
        prefix = self._get_csv_prefix()
        filename = f'{prefix}_reports.csv'
        return self._get_csv(dirname, filename)

    def _get_labels(self) -> pd.DataFrame:
        """Load the abnormality labels for the CT scans.

        Retrieves the CSV file containing predicted abnormality labels for
        the current split.
        """
        dirname = 'multi_abnormality_labels'
        prefix = self._get_csv_prefix(expand_validation=False)
        filename = f'{prefix}_predicted_labels.csv'
        return self._get_csv(dirname, filename)

    def _get_subjects_list(self, metadata: pd.DataFrame) -> list[Subject]:
        """Create a list of Subject instances from the metadata.

        Processes the metadata to create Subject objects, each containing one
        or more CT images. Processing is performed in parallel.

        Note:
            This method uses parallelization to improve performance when
            creating multiple Subject instances.
        """
        df_no_index = metadata.reset_index()
        num_subjects = df_no_index['subject_id'].nunique()
        iterable = df_no_index.groupby('subject_id')
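        # Build one Subject per subject ID concurrently; threads suffice here
        # because images are loaded lazily, so no pixel data is read yet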
        subjects = thread_map(
            self._get_subject,
            iterable,
            max_workers=multiprocessing.cpu_count(),
            total=num_subjects,
        )
        return subjects

    def _get_subject(
        self,
        subject_id_and_metadata: tuple[str, pd.DataFrame],
    ) -> Subject:
        """Create a Subject instance for a specific subject.

        Processes all images belonging to a single subject and creates a
        Subject object containing those images.

        Args:
            subject_id_and_metadata: A tuple containing the subject ID
                (string) and a dataframe with metadata for all images
                associated with that subject.
        """
        subject_id, subject_df = subject_id_and_metadata
        subject_dict: dict[str, str | ScalarImage] = {'subject_id': subject_id}
        for _, image_row in subject_df.iterrows():
            image = self._instantiate_image(image_row)
            scan_id = image_row['scan_id']
            reconstruction_id = image_row['reconstruction_id']
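            # e.g., 'scan_a_reconstruction_1' for volume 'train_2_a_1.nii.gz'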
            image_key = f'scan_{scan_id}_reconstruction_{reconstruction_id}'
            subject_dict[image_key] = image
        return Subject(**subject_dict)  # type: ignore[arg-type]

    def _instantiate_image(self, image_row: pd.Series) -> ScalarImage:
        """Create a ScalarImage object for a specific image.

        Processes a row from the metadata dataframe to create a ScalarImage
        object whose metadata includes the radiology report.

        Args:
            image_row: A pandas Series representing a row from the metadata
                dataframe, containing information about a single image.
        """
        image_dict = image_row.to_dict()
        filename = image_dict[self._FILENAME_KEY]
        image_path = self._root_dir / self._get_image_path(filename)
        report_dict = self._extract_report_dict(image_dict)
        image_dict[self._report_key] = report_dict
        image = ScalarImage(image_path, **image_dict)
        return image

    def _extract_report_dict(self, subject_dict: dict[str, str]) -> dict[str, str]:
        """Extract radiology report information from the subject dictionary.

        Extracts the English radiology report components (clinical
        information, findings, impressions, and technique) from the subject
        dictionary and removes these keys from the original dictionary.

        Args:
            subject_dict: Image metadata including report fields.

        Note:
            This method modifies the input ``subject_dict`` by removing the
            report keys.
        """
        report_keys = [
            'ClinicalInformation_EN',
            'Findings_EN',
            'Impressions_EN',
            'Technique_EN',
        ]
        report_dict = {}
        for key in report_keys:
            report_dict[key] = subject_dict.pop(key)
        return report_dict

    @staticmethod
    def _get_image_path(filename: str) -> Path:
        """Construct the relative path to an image file within the dataset.

        Parses the filename to determine the hierarchical directory structure
        where the image is stored in the CT-RATE dataset.

        Args:
            filename: The name of the image file (e.g.,
                ``'train_2_a_1.nii.gz'``).

        Returns:
            Path: The relative path to the image file within the dataset
            directory.

        Example:
            >>> path = CtRate._get_image_path('train_2_a_1.nii.gz')
            # Returns Path('dataset/train/train_2/train_2_a/train_2_a_1.nii.gz')
        """
        parts = filename.split('_')
        base_dir = 'dataset'
        split_dir = parts[0]
        level1 = f'{parts[0]}_{parts[1]}'
        level2 = f'{level1}_{parts[2]}'
        return Path(base_dir, split_dir, level1, level2, filename)

    @staticmethod
    def _fix_image(image: ScalarImage, out_path: Path, *, force: bool = False) -> None:
        """Fix the spatial metadata of a CT-RATE image file.

        The original NIfTI files in the CT-RATE dataset have incorrect
        spatial metadata. This method reads the image, fixes the spacing,
        origin, and orientation based on the metadata provided in the CSV,
        and applies the correct rescaling to convert to Hounsfield units.

        Args:
            image: The image to fix.
            out_path: The path where the fixed image will be saved.
            force: If ``True``, fix and overwrite an existing ``out_path``.

        Note:
            The fixed image is saved to ``out_path`` as INT16 with proper HU
            values. Existing output files are skipped unless ``force`` is
            ``True``.
        """
        # Adapted from https://huggingface.co/datasets/ibrahimhamamci/CT-RATE/blob/main/download_scripts/fix_metadata.py
        if not force and out_path.exists():
            return
        spacing_x, spacing_y = map(float, ast.literal_eval(image['XYSpacing']))
        spacing_z = image['ZSpacing']
        image_sitk = sitk.ReadImage(str(image.path))
        image_sitk.SetSpacing((spacing_x, spacing_y, spacing_z))

        image_sitk.SetOrigin(ast.literal_eval(image['ImagePositionPatient']))

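        # Rebuild the direction matrix from the DICOM orientation tags: the
        # slice normal is the cross product of the row and column cosines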
        orientation = ast.literal_eval(image['ImageOrientationPatient'])
        row_cosine, col_cosine = orientation[:3], orientation[3:6]
        z_cosine = np.cross(row_cosine, col_cosine).tolist()
        image_sitk.SetDirection(row_cosine + col_cosine + z_cosine)

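        # Apply the DICOM rescale to get Hounsfield units:
        # HU = RescaleSlope * stored_value + RescaleIntercept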
        rescale_intercept = image['RescaleIntercept']
        rescale_slope = image['RescaleSlope']
        adjusted_hu = image_sitk * rescale_slope + rescale_intercept
        cast_int16 = sitk.Cast(adjusted_hu, sitk.sitkInt16)

        out_path.parent.mkdir(parents=True, exist_ok=True)
        sitk.WriteImage(cast_int16, str(out_path))

    def _copy_not_images(self, out_dir: Path) -> None:
        """Copy all files from the root directory except the images."""
        for path in self._root_dir.iterdir():
            if path.name == 'dataset':
                for subdirectory in path.iterdir():
                    if subdirectory.name in ['train', 'valid']:
                        continue
                    target = out_dir / subdirectory.relative_to(self._root_dir)
                    print(f'Copying {subdirectory} to {target}')
                    shutil.copytree(subdirectory, target, dirs_exist_ok=True)
            elif path.name.startswith('.'):
                continue
            elif path.is_dir():
                print(f'Copying {path} to {out_dir / path.name}')
                shutil.copytree(
                    path,
                    out_dir / path.name,
                    dirs_exist_ok=True,
                )
            else:
                print(f'Copying {path} to {out_dir / path.name}')
                shutil.copy(path, out_dir / path.name)

    def fix_metadata(
        self,
        out_dir: str | Path,
        parallelism: TypeParallelism = None,
    ) -> CtRate:
        """Fix the metadata of all images in the dataset.

        Reads each image, applies the correct spatial metadata, and saves the
        fixed image to the specified output directory.

        Args:
            out_dir: The directory where the fixed images will be saved.
            parallelism: Whether to fix images using multiple threads
                (``'thread'``), multiple processes (``'process'``), or
                sequentially (``None``).

        Returns:
            CtRate: A new dataset instance rooted at ``out_dir``.
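
        Example:
            >>> dataset = CtRate('/path/to/data', split='valid')
            >>> fixed = dataset.fix_metadata('/tmp/fixed', parallelism='process')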
        """
        out_dir = Path(out_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        # self._copy_not_images(out_dir)
        images = []
        out_paths = []
        for subject in self.dry_iter():
            for image in subject.get_images():
                out_path = out_dir / image.path.relative_to(self._root_dir)
                images.append(image)
                out_paths.append(out_path)
        if parallelism == 'thread':
            thread_map(
                self._fix_image,
                images,
                out_paths,
                max_workers=multiprocessing.cpu_count(),
                desc='Fixing metadata',
            )
        elif parallelism == 'process':
            process_map(
                self._fix_image,
                images,
                out_paths,
                max_workers=multiprocessing.cpu_count(),
                desc='Fixing metadata',
            )
        else:
            zipped = zip(images, out_paths)
            with tqdm(total=len(images), desc='Fixing metadata') as pbar:
                for image, out_path in zipped:
                    pbar.set_description(f'Fixing {image.path.name}')
                    self._fix_image(image, out_path)
                    pbar.update(1)
        new_dataset = CtRate(
            out_dir,
            split=self._split,
            token=self._token,
            num_subjects=self._num_subjects,
            report_key=self._report_key,
            sizes=self._sizes,
        )
        return new_dataset