Passed
Pull Request — main (#1278)
by Fernando
01:27
created

torchio.datasets.ct_rate   A

Complexity

Total Complexity 23

Size/Duplication

Total Lines 371
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 182
dl 0
loc 371
rs 10
c 0
b 0
f 0
wmc 23

14 Methods

Rating   Name   Duplication   Size   Complexity  
A CtRate._parse_split() 0 21 3
A CtRate._get_image_path() 0 25 2
A CtRate._get_subjects_list() 0 20 1
A CtRate._get_csv() 0 16 1
A CtRate._keep_n_subjects() 0 10 1
A CtRate.__init__() 0 21 2
A CtRate._get_metadata() 0 34 2
A CtRate._extract_report_dict() 0 17 2
A CtRate._instantiate_image() 0 20 1
A CtRate._get_labels() 0 10 1
A CtRate._get_subject() 0 22 2
A CtRate._merge() 0 21 1
A CtRate._get_csv_prefix() 0 16 3
A CtRate._get_reports() 0 10 1
1
from __future__ import annotations
2
3
import enum
4
import multiprocessing
5
from pathlib import Path
6
from typing import TYPE_CHECKING
7
from typing import Literal
8
from typing import Union
9
10
from tqdm.contrib.concurrent import thread_map
11
12
from ..data.dataset import SubjectsDataset
13
from ..data.image import ScalarImage
14
from ..data.subject import Subject
15
from ..external.imports import get_pandas
16
from ..types import TypePath
17
18
if TYPE_CHECKING:
19
    import pandas as pd
20
21
22
# Accepted values for the ``split`` argument; 'validation' is normalized
# to 'valid' by CtRate._parse_split.
TypeSplit = Union[
    Literal['train'],
    Literal['valid'],
    Literal['validation'],
]

# Parallelism strategy names. NOTE(review): not referenced in this chunk;
# presumably consumed by callers elsewhere — confirm before removing.
TypeParallelism = Literal['thread', 'process', None]
29
30
31
class MetadataIndexColumn(str, enum.Enum):
    """Column names used for the multi-index of the metadata DataFrame.

    The values are the columns extracted from the volume filenames and used
    by :meth:`CtRate._get_metadata` when calling ``set_index``.
    """

    SUBJECT_ID = 'subject_id'
    SCAN_ID = 'scan_id'
    RECONSTRUCTION_ID = 'reconstruction_id'
35
36
37
class CtRate(SubjectsDataset):
    """CT-RATE dataset.

    This class helps loading the `CT-RATE dataset
    <https://huggingface.co/datasets/ibrahimhamamci/CT-RATE>`_,
    which contains chest CT scans with associated radiology reports and
    abnormality labels.

    The dataset must have been downloaded previously.

    Args:
        root: Root directory where the dataset has been downloaded.
        split: Dataset split to use, either ``'train'`` or ``'validation'``.
        num_subjects: Optional limit on the number of subjects to load (useful for
            debugging). If ``None``, all subjects in the split are loaded.
        report_key: Key to use for storing radiology reports in the Subject metadata.
        sizes: List of image sizes (in-plane, in voxels) to include.
        load_fixed: If ``True``, load the files with fixed spatial metadata
            added in `this pull request
            <https://huggingface.co/datasets/ibrahimhamamci/CT-RATE/discussions/85>`_.
            Otherwise, load the original files with incorrect spatial metadata.
        **kwargs: Additional arguments for SubjectsDataset.

    Examples:
        >>> dataset = CtRate('/path/to/data', split='train')
    """

    # Hugging Face repository ID. NOTE(review): not referenced in this
    # chunk; presumably used by download helpers elsewhere — confirm.
    _REPO_ID = 'ibrahimhamamci/CT-RATE'
    # Name of the CSV column holding the volume filename; used as the join
    # key in _merge and to locate files in _instantiate_image.
    _FILENAME_KEY = 'VolumeName'
    # Default in-plane image sizes (in voxels) kept by _get_metadata when
    # ``sizes`` is not given.
    _SIZES = [512, 768, 1024]
    # Abnormality names. NOTE(review): not referenced in this chunk;
    # presumably the column names of the predicted-labels CSV — confirm.
    ABNORMALITIES = [
        'Medical material',
        'Arterial wall calcification',
        'Cardiomegaly',
        'Pericardial effusion',
        'Coronary artery wall calcification',
        'Hiatal hernia',
        'Lymphadenopathy',
        'Emphysema',
        'Atelectasis',
        'Lung nodule',
        'Lung opacity',
        'Pulmonary fibrotic sequela',
        'Pleural effusion',
        'Mosaic attenuation pattern',
        'Peribronchial thickening',
        'Consolidation',
        'Bronchiectasis',
        'Interlobular septal thickening',
    ]
    # English report fields popped from the metadata by _extract_report_dict
    # and grouped under ``report_key``.
    REPORT_KEYS = [
        'ClinicalInformation_EN',
        'Findings_EN',
        'Impressions_EN',
        'Technique_EN',
    ]
93
94
    def __init__(
        self,
        root: TypePath,
        split: TypeSplit = 'train',
        *,
        num_subjects: int | None = None,
        report_key: str = 'report',
        sizes: list[int] | None = None,
        load_fixed: bool = True,
        **kwargs,
    ):
        """Store the configuration, build the metadata table and the subjects.

        See the class docstring for the meaning of each argument.
        """
        self._root_dir = Path(root)
        self._num_subjects = num_subjects
        self._report_key = report_key
        self._load_fixed = load_fixed
        self._sizes = sizes if sizes is not None else self._SIZES

        self._split = self._parse_split(split)
        self.metadata = self._get_metadata()
        subjects = self._get_subjects_list(self.metadata)
        super().__init__(subjects, **kwargs)
115
116
    @staticmethod
117
    def _parse_split(split: str) -> str:
118
        """Normalize the split name.
119
120
        Converts 'validation' to 'valid' and validates that the split name
121
        is one of the allowed values.
122
123
        Args:
124
            split: The split name to parse ('train', 'valid', or 'validation').
125
126
        Returns:
127
            str: Normalized split name ('train' or 'valid').
128
129
        Raises:
130
            ValueError: If the split name is not one of the allowed values.
131
        """
132
        if split in ['valid', 'validation']:
133
            return 'valid'
134
        if split not in ['train', 'valid']:
135
            raise ValueError(f"Invalid split '{split}'. Use 'train' or 'valid'")
136
        return split
137
138
    def _get_csv(
        self,
        dirname: str,
        filename: str,
    ) -> pd.DataFrame:
        """Read a CSV file stored under ``dataset/<dirname>/`` in the root dir.

        Args:
            dirname: Directory name within 'dataset/' where the CSV is located.
            filename: Name of the CSV file to load.
        """
        pandas = get_pandas()
        csv_path = self._root_dir / 'dataset' / dirname / filename
        return pandas.read_csv(csv_path)
154
155
    def _get_csv_prefix(self, expand_validation: bool = True) -> str:
156
        """Get the prefix for CSV filenames based on the current split.
157
158
        Returns the appropriate prefix for CSV filenames based on the current split.
159
        For the validation split, can either return 'valid' or 'validation' depending
160
        on the expand_validation parameter.
161
162
        Args:
163
            expand_validation: If ``True`` and split is ``'valid'``, return
164
                ``'validation'``. Otherwise, return the split name as is.
165
        """
166
        if expand_validation and self._split == 'valid':
167
            prefix = 'validation'
168
        else:
169
            prefix = self._split
170
        return prefix
171
172
    def _get_metadata(self) -> pd.DataFrame:
        """Load and process the dataset metadata.

        Loads metadata from the appropriate CSV file, filters images by size,
        extracts subject, scan, and reconstruction IDs from filenames, and
        merges in reports and abnormality labels.

        Returns:
            pd.DataFrame: Metadata indexed by (subject_id, scan_id,
            reconstruction_id), with the report and label columns merged in.
        """
        dirname = 'metadata'
        prefix = self._get_csv_prefix()
        filename = f'{prefix}_metadata.csv'
        metadata = self._get_csv(dirname, filename)

        # Exclude images with size not in self._sizes
        rows_int = metadata['Rows'].astype(int)
        metadata = metadata[rows_int.isin(self._sizes)]

        index_columns = [
            MetadataIndexColumn.SUBJECT_ID.value,
            MetadataIndexColumn.SCAN_ID.value,
            MetadataIndexColumn.RECONSTRUCTION_ID.value,
        ]
        # Filenames look like '<split>_<subject>_<scan>_<reconstruction>.nii.gz'
        # (e.g. 'train_2_a_1.nii.gz'); capture the three trailing IDs.
        pattern = r'\w+_(\d+)_(\w+)_(\d+)\.nii\.gz'
        metadata[index_columns] = metadata[self._FILENAME_KEY].str.extract(pattern)

        # Subsetting before merging keeps the merges cheap when debugging.
        if self._num_subjects is not None:
            metadata = self._keep_n_subjects(metadata, self._num_subjects)

        # Add reports and abnormality labels to metadata, keeping only the rows for the
        # images in the metadata table
        metadata = self._merge(metadata, self._get_reports())
        metadata = self._merge(metadata, self._get_labels())

        metadata.set_index(index_columns, inplace=True)
        return metadata
206
207
    def _merge(self, base_df: pd.DataFrame, new_df: pd.DataFrame) -> pd.DataFrame:
        """Left-join ``new_df`` onto ``base_df`` keyed on the volume filename.

        All rows of ``base_df`` are preserved; the matching columns of
        ``new_df`` are attached where the volume filenames agree.

        Args:
            base_df: The primary dataframe to merge into.
            new_df: The dataframe containing additional data to be merged.

        Returns:
            pd.DataFrame: The merged dataframe with all rows from base_df and
            matching columns from new_df.
        """
        pandas = get_pandas()
        merged = pandas.merge(
            base_df,
            new_df,
            how='left',
            on=self._FILENAME_KEY,
        )
        return merged
229
230
    def _keep_n_subjects(self, metadata: pd.DataFrame, n: int) -> pd.DataFrame:
0 ignored issues
show
introduced by
The variable pd does not seem to be defined in case TYPE_CHECKING on line 18 is False. Are you sure this can never be the case?
Loading history...
231
        """Limit the metadata to the first ``n`` subjects.
232
233
        Args:
234
            metadata: The complete metadata dataframe.
235
            n: Maximum number of subjects to keep.
236
        """
237
        unique_subjects = metadata['subject_id'].unique()
238
        selected_subjects = unique_subjects[:n]
239
        return metadata[metadata['subject_id'].isin(selected_subjects)]
240
241
    def _get_reports(self) -> pd.DataFrame:
        """Load the radiology reports associated with the CT scans.

        Reads the reports CSV for the current split (train or validation).
        """
        prefix = self._get_csv_prefix()
        return self._get_csv('radiology_text_reports', f'{prefix}_reports.csv')
251
252
    def _get_labels(self) -> pd.DataFrame:
        """Load the predicted abnormality labels for the CT scans.

        Reads the labels CSV for the current split. The labels files use the
        short ``'valid'`` prefix, hence ``expand_validation=False``.
        """
        prefix = self._get_csv_prefix(expand_validation=False)
        return self._get_csv(
            'multi_abnormality_labels',
            f'{prefix}_predicted_labels.csv',
        )
262
263
    def _get_subjects_list(self, metadata: pd.DataFrame) -> list[Subject]:
        """Create a list of Subject instances from the metadata.

        Groups the metadata rows by subject ID and builds one Subject per
        group, using a thread pool to parallelize the work.
        """
        flat = metadata.reset_index()
        grouped = flat.groupby('subject_id')
        subjects = thread_map(
            self._get_subject,
            grouped,
            total=flat['subject_id'].nunique(),
            max_workers=multiprocessing.cpu_count(),
        )
        return subjects
283
284
    def _get_subject(
        self,
        subject_id_and_metadata: tuple[str, pd.DataFrame],
    ) -> Subject:
        """Create a Subject instance holding all images of one subject.

        Args:
            subject_id_and_metadata: Tuple of the subject ID and a DataFrame
                with the metadata of every image belonging to that subject.
        """
        subject_id, images_df = subject_id_and_metadata
        fields: dict[str, str | ScalarImage] = {'subject_id': subject_id}
        for _, row in images_df.iterrows():
            scan = row['scan_id']
            reconstruction = row['reconstruction_id']
            key = f'scan_{scan}_reconstruction_{reconstruction}'
            fields[key] = self._instantiate_image(row)
        return Subject(**fields)  # type: ignore[arg-type]
306
307
    def _instantiate_image(self, image_row: pd.Series) -> ScalarImage:
        """Create a ScalarImage from a single metadata row.

        The row's metadata is attached to the image, with the report fields
        grouped into one dictionary under ``self._report_key``.

        Args:
            image_row: A pandas Series representing a row from the metadata
                DataFrame, containing information about a single image.
        """
        metadata = image_row.to_dict()
        relative_path = self._get_image_path(
            metadata[self._FILENAME_KEY],
            load_fixed=self._load_fixed,
        )
        # _extract_report_dict removes the report keys from ``metadata``.
        metadata[self._report_key] = self._extract_report_dict(metadata)
        return ScalarImage(self._root_dir / relative_path, **metadata)
327
328
    def _extract_report_dict(self, subject_dict: dict[str, str]) -> dict[str, str]:
329
        """Extract radiology report information from the subject dictionary.
330
331
        Extracts the English radiology report components (clinical information,
332
        findings, impressions, and technique) from the subject dictionary and
333
        removes these keys from the original dictionary.
334
335
        Args:
336
            subject_dict: Image metadata including report fields.
337
338
        Note:
339
            This method modifies the input subject_dict by removing the report keys.
340
        """
341
        report_dict = {}
342
        for key in self.REPORT_KEYS:
343
            report_dict[key] = subject_dict.pop(key)
344
        return report_dict
345
346
    @staticmethod
347
    def _get_image_path(filename: str, load_fixed: bool) -> Path:
348
        """Construct the relative path to an image file within the dataset structure.
349
350
        Parses the filename to determine the hierarchical directory structure
351
        where the image is stored in the CT-RATE dataset.
352
353
        Args:
354
            filename: The name of the image file (e.g., 'train_2_a_1.nii.gz').
355
356
        Returns:
357
            Path: The relative path to the image file within the dataset directory.
358
359
        Example:
360
            >>> path = CtRate._get_image_path('train_2_a_1.nii.gz')
361
            # Returns Path('dataset/train/train_2/train_2_a/train_2_a_1.nii.gz')
362
        """
363
        parts = filename.split('_')
364
        base_dir = 'dataset'
365
        split_dir = parts[0]
366
        if load_fixed:
367
            split_dir = f'{split_dir}_fixed'
368
        level1 = f'{parts[0]}_{parts[1]}'
369
        level2 = f'{level1}_{parts[2]}'
370
        return Path(base_dir, split_dir, level1, level2, filename)
371