Passed
Pull Request — main (#1278)
by Fernando
01:31
created

torchio.datasets.rsna_spine_fracture.get_pandas()   A

Complexity

Conditions 2

Size

Total Lines 11
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 8
nop 0
dl 0
loc 11
rs 10
c 0
b 0
f 0
1
from __future__ import annotations
2
3
from pathlib import Path
4
from typing import Any
5
from typing import Union
6
7
from ..data import LabelMap
8
from ..data import ScalarImage
9
from ..data import Subject
10
from ..data import SubjectsDataset
11
from ..external.imports import get_pandas
12
from ..types import TypePath
13
from ..utils import normalize_path
14
15
# Bounding boxes of a single study: a list of records, each one mapping a
# column name of the bounding boxes CSV to that row's value.
TypeBoxes = list[dict[str, Union[str, float, int]]]
16
17
18
class RSNACervicalSpineFracture(SubjectsDataset):
    """RSNA 2022 Cervical Spine Fracture Detection dataset.

    This is a helper class for the dataset used in the
    `RSNA 2022 Cervical Spine Fracture Detection`_ hosted on
    `kaggle <https://www.kaggle.com/>`_. The dataset must be downloaded before
    instantiating this class.

    The following files are read from ``root_dir``:

    - ``train.csv``: one row per study, with a ``StudyInstanceUID`` column.
      All columns of a row are added to the corresponding subject.
    - ``train_images/``: one image directory per study, named after its UID.
    - ``segmentations/``: one file per segmented study, whose name without
      ``.nii``/``.gz`` suffixes is the study UID. Only read if
      ``add_segmentations`` is ``True``.
    - ``train_bounding_boxes.csv``: bounding boxes with a ``StudyInstanceUID``
      column. Only read if ``add_bounding_boxes`` is ``True``.

    Args:
        root_dir: Directory containing the downloaded competition data.
        add_segmentations: If ``True``, add a ``'seg'`` ``LabelMap`` to the
            subjects that have a segmentation file.
        add_bounding_boxes: If ``True``, add a ``'boxes'`` list (one dict per
            row of the bounding boxes CSV) to the subjects that have boxes.
        **kwargs: Additional keyword arguments passed to ``SubjectsDataset``.

    .. _RSNA 2022 Cervical Spine Fracture Detection: https://www.kaggle.com/competitions/rsna-2022-cervical-spine-fracture-detection/overview/evaluation
    """

    UID = 'StudyInstanceUID'

    def __init__(
        self,
        root_dir: TypePath,
        add_segmentations: bool = False,
        add_bounding_boxes: bool = False,
        **kwargs,
    ):
        self.root_dir = normalize_path(root_dir)
        subjects = self._get_subjects(
            add_segmentations,
            add_bounding_boxes,
        )
        super().__init__(subjects, **kwargs)

    @staticmethod
    def _strip_nifti_suffixes(name: str) -> str:
        """Return ``name`` without a trailing ``.gz`` and then ``.nii`` suffix.

        Only suffixes are removed, so occurrences of ``.nii`` or ``.gz`` in the
        middle of a name are preserved.
        """
        return name.removesuffix('.gz').removesuffix('.nii')

    @staticmethod
    def _get_image_dirs_dict(images_dir: Path) -> dict[str, Path]:
        """Map the name of each entry in ``images_dir`` to its path."""
        return {path.name: path for path in sorted(images_dir.iterdir())}

    @classmethod
    def _get_segs_paths_dict(cls, segs_dir: Path) -> dict[str, Path]:
        """Map each segmentation's study UID (its file name) to its path."""
        return {
            cls._strip_nifti_suffixes(path.name): path
            for path in sorted(segs_dir.iterdir())
        }

    @classmethod
    def _get_boxes_dict(cls, bboxes_path: Path) -> dict[str, TypeBoxes]:
        """Group the rows of the bounding boxes CSV by study UID.

        Returns:
            A dict mapping each UID present in the CSV to a list of dicts,
            one per CSV row (column name -> value).
        """
        pd = get_pandas()
        boxes_df = pd.read_csv(bboxes_path)
        return {
            uid: [dict(box_row) for _, box_row in group_df.iterrows()]
            for uid, group_df in boxes_df.groupby(cls.UID)
        }

    def _get_subjects(
        self,
        add_segmentations: bool,
        add_bounding_boxes: bool,
    ) -> list[Subject]:
        """Build one ``Subject`` per row of the training CSV."""
        pd = get_pandas()
        from tqdm.auto import tqdm

        split_name = 'train'
        images_dir = self.root_dir / f'{split_name}_images'
        image_dirs_dict = self._get_image_dirs_dict(images_dir)

        # Optional annotations are only read when requested, so that missing
        # segmentation/box files do not prevent loading the images.
        seg_paths_dict: dict[str, Path] = {}
        if add_segmentations:
            segmentations_dir = self.root_dir / 'segmentations'
            seg_paths_dict = self._get_segs_paths_dict(segmentations_dir)

        boxes_dict: dict[str, TypeBoxes] = {}
        if add_bounding_boxes:
            bboxes_path = self.root_dir / 'train_bounding_boxes.csv'
            boxes_dict = self._get_boxes_dict(bboxes_path)

        df = pd.read_csv(self.root_dir / f'{split_name}.csv')

        subjects = []
        for _, row in tqdm(df.iterrows(), total=len(df)):
            uid = row[self.UID]
            subject = self._get_subject(
                dict(row),
                image_dirs_dict[uid],
                seg_paths_dict.get(uid),  # None if no segmentation
                boxes_dict.get(uid, []),  # empty if no boxes
            )
            subjects.append(subject)
        return subjects

    @classmethod
    def _filter_list(cls, iterable: list[Path], target: str) -> Path | None:
        """Return the single path in ``iterable`` that matches ``target``.

        Directories match if their name equals ``target``; files match if
        their name without ``.nii``/``.gz`` suffixes equals ``target``.

        Returns:
            The matching path, or ``None`` if no path matches.

        Raises:
            ValueError: If more than one path matches.
        """

        def _matches(path: Path) -> bool:
            if path.is_dir():
                return path.name == target
            return cls._strip_nifti_suffixes(path.name) == target

        found = [path for path in iterable if _matches(path)]
        if not found:
            return None
        if len(found) > 1:
            raise ValueError(
                f'Expected at most one path matching "{target}", found {len(found)}',
            )
        return found[0]

    def _get_subject(
        self,
        csv_row_dict: dict[str, str | int],
        image_dir: Path,
        seg_path: Path | None,
        boxes: TypeBoxes,
    ) -> Subject:
        """Create a subject from a CSV row, its image and optional annotations.

        All CSV columns become subject attributes. The image is stored under
        ``'ct'``, the segmentation (if any) under ``'seg'`` and the boxes
        (if non-empty) under ``'boxes'``.
        """
        subject_dict: dict[str, Any] = dict(csv_row_dict)
        subject_dict['ct'] = ScalarImage(image_dir)
        if seg_path is not None:
            subject_dict['seg'] = LabelMap(seg_path)
        if boxes:
            subject_dict['boxes'] = boxes
        return Subject(**subject_dict)
137