|
1
|
|
|
from __future__ import annotations |
|
2
|
|
|
|
|
3
|
|
|
from pathlib import Path |
|
4
|
|
|
from typing import Any |
|
5
|
|
|
from typing import Union |
|
6
|
|
|
|
|
7
|
|
|
from ..data import LabelMap |
|
8
|
|
|
from ..data import ScalarImage |
|
9
|
|
|
from ..data import Subject |
|
10
|
|
|
from ..data import SubjectsDataset |
|
11
|
|
|
from ..external.imports import get_pandas |
|
12
|
|
|
from ..types import TypePath |
|
13
|
|
|
from ..utils import normalize_path |
|
14
|
|
|
|
|
15
|
|
|
# Bounding boxes of one study: one dict per row of the bounding-boxes CSV,
# mapping each column name to that row's value.
TypeBoxes = list[dict[str, Union[str, float, int]]]
|
16
|
|
|
|
|
17
|
|
|
|
|
18
|
|
|
class RSNACervicalSpineFracture(SubjectsDataset):
    """RSNA 2022 Cervical Spine Fracture Detection dataset.

    This is a helper class for the dataset used in the
    `RSNA 2022 Cervical Spine Fracture Detection`_ hosted on
    `kaggle <https://www.kaggle.com/>`_. The dataset must be downloaded before
    instantiating this class.

    Args:
        root_dir: Directory containing the downloaded dataset. It must contain
            ``train.csv`` and ``train_images``. The ``segmentations``
            directory and ``train_bounding_boxes.csv`` are only read when the
            corresponding option below is enabled.
        add_segmentations: If ``True``, add a ``'seg'`` label map to every
            subject for which a file is found in ``segmentations``.
        add_bounding_boxes: If ``True``, add a ``'boxes'`` list to every
            subject that has rows in ``train_bounding_boxes.csv``.
        **kwargs: Forwarded to the parent dataset constructor.

    .. _RSNA 2022 Cervical Spine Fracture Detection: https://www.kaggle.com/competitions/rsna-2022-cervical-spine-fracture-detection/overview/evaluation
    """

    UID = 'StudyInstanceUID'

    def __init__(
        self,
        root_dir: TypePath,
        add_segmentations: bool = False,
        add_bounding_boxes: bool = False,
        **kwargs,
    ):
        self.root_dir = normalize_path(root_dir)
        subjects = self._get_subjects(
            add_segmentations,
            add_bounding_boxes,
        )
        super().__init__(subjects, **kwargs)

    @staticmethod
    def _strip_nifti_suffixes(name: str) -> str:
        """Remove a trailing ``.gz`` and then a trailing ``.nii`` from a name.

        Only suffixes are removed, so identifiers that happen to contain
        ``.nii`` or ``.gz`` elsewhere in the name are left intact.
        """
        return name.removesuffix('.gz').removesuffix('.nii')

    @staticmethod
    def _get_image_dirs_dict(images_dir: Path) -> dict[str, Path]:
        """Map each entry name in ``images_dir`` to its path, sorted by path."""
        return {entry.name: entry for entry in sorted(images_dir.iterdir())}

    @staticmethod
    def _get_segs_paths_dict(segs_dir: Path) -> dict[str, Path]:
        """Map each segmentation file name (NIfTI suffixes removed) to its path."""
        strip = RSNACervicalSpineFracture._strip_nifti_suffixes
        return {strip(path.name): path for path in sorted(segs_dir.iterdir())}

    def _get_subjects(
        self,
        add_segmentations: bool,
        add_bounding_boxes: bool,
    ) -> list[Subject]:
        """Build one subject per row of ``train.csv``.

        Args:
            add_segmentations: Whether to look up a segmentation for each
                study in the ``segmentations`` directory.
            add_bounding_boxes: Whether to attach the rows of
                ``train_bounding_boxes.csv`` belonging to each study.

        Returns:
            The subjects, in the order of the rows of ``train.csv``.

        Raises:
            KeyError: If a study listed in ``train.csv`` has no entry in the
                ``train_images`` directory.
        """
        pd = get_pandas()
        from tqdm.auto import tqdm

        split_name = 'train'
        images_dir = self.root_dir / f'{split_name}_images'
        image_dirs_dict = self._get_image_dirs_dict(images_dir)

        # Optional inputs are only touched when requested, so that a partial
        # download (e.g., images and labels only) can still be used.
        seg_paths_dict: dict[str, Path] = {}
        if add_segmentations:
            segmentations_dir = self.root_dir / 'segmentations'
            seg_paths_dict = self._get_segs_paths_dict(segmentations_dir)

        grouped_boxes = None
        if add_bounding_boxes:
            bboxes_path = self.root_dir / 'train_bounding_boxes.csv'
            grouped_boxes = pd.read_csv(bboxes_path).groupby(self.UID)

        df = pd.read_csv(self.root_dir / f'{split_name}.csv')

        subjects = []
        # Stream the rows; ``total`` keeps the progress bar informative
        # without materializing them all in a list first.
        for _, row in tqdm(df.iterrows(), total=len(df)):
            uid = row[self.UID]
            image_dir = image_dirs_dict[uid]
            # Empty dict when segmentations were not requested -> None.
            seg_path = seg_paths_dict.get(uid)
            boxes: TypeBoxes = []
            if grouped_boxes is not None:
                try:
                    boxes_df = grouped_boxes.get_group(uid)
                except KeyError:
                    # No boxes listed for this study; keep the empty list.
                    pass
                else:
                    boxes = [
                        dict(box_row) for _, box_row in boxes_df.iterrows()
                    ]
            subject = self._get_subject(
                dict(row),
                image_dir,
                seg_path,
                boxes,
            )
            subjects.append(subject)
        return subjects

    @staticmethod
    def _filter_list(iterable: list[Path], target: str) -> Path | None:
        """Return the single path in ``iterable`` whose name matches ``target``.

        Directories are matched by their exact name; files are matched by
        their name with the NIfTI suffixes removed.

        Returns:
            The matching path, or ``None`` if nothing matches.

        Raises:
            ValueError: If more than one path matches.
        """
        strip = RSNACervicalSpineFracture._strip_nifti_suffixes

        def _matches(path: Path) -> bool:
            if path.is_dir():
                return target == path.name
            return target == strip(path.name)

        found = [path for path in iterable if _matches(path)]
        if not found:
            return None
        if len(found) > 1:
            message = f'Expected at most one match for "{target}", found {len(found)}'
            raise ValueError(message)
        return found[0]

    def _get_subject(
        self,
        csv_row_dict: dict[str, str | int],
        image_dir: Path,
        seg_path: Path | None,
        boxes: TypeBoxes,
    ) -> Subject:
        """Assemble a subject from a CSV row and its associated data.

        Args:
            csv_row_dict: Columns of the study's row in ``train.csv``; all of
                them are copied into the subject.
            image_dir: Path passed to the ``'ct'`` scalar image.
            seg_path: Path to the segmentation, or ``None`` to omit ``'seg'``.
            boxes: Bounding boxes of the study; ``'boxes'`` is omitted when
                the list is empty.
        """
        subject_dict: dict[str, Any] = {}
        subject_dict.update(csv_row_dict)
        subject_dict['ct'] = ScalarImage(image_dir)
        if seg_path is not None:
            subject_dict['seg'] = LabelMap(seg_path)
        if boxes:
            subject_dict['boxes'] = boxes
        return Subject(**subject_dict)
|
137
|
|
|
|