Passed
Pull Request — main (#1278)
by Fernando
01:32
created

torchio.datasets.ct_rate   A

Complexity

Total Complexity 23

Size/Duplication

Total Lines 396
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 190
dl 0
loc 396
rs 10
c 0
b 0
f 0
wmc 23

15 Methods

Rating   Name   Duplication   Size   Complexity  
A CtRate._parse_split() 0 21 3
A CtRate._get_image_path() 0 23 1
A CtRate._get_subjects_list() 0 20 1
A CtRate._get_csv() 0 16 1
A CtRate._keep_n_subjects() 0 10 1
A CtRate.__init__() 0 21 2
A CtRate._get_metadata() 0 34 2
A CtRate._extract_report_dict() 0 23 2
A CtRate._instantiate_image() 0 16 1
A CtRate._get_labels() 0 10 1
A CtRate._merge() 0 21 1
A CtRate._get_subject() 0 22 2
A CtRate._get_csv_prefix() 0 16 3
A CtRate._get_reports() 0 10 1
A CtRate._fix_image() 0 38 1
1
from __future__ import annotations
2
3
import ast
4
import multiprocessing
5
from pathlib import Path
6
from typing import TYPE_CHECKING
7
from typing import Literal
8
from typing import Union
9
10
import numpy as np
11
import SimpleITK as sitk
12
from tqdm.contrib.concurrent import thread_map
13
14
from ..data.dataset import SubjectsDataset
15
from ..data.image import ScalarImage
16
from ..data.subject import Subject
17
from ..external.imports import get_pandas
18
from ..types import TypePath
19
20
if TYPE_CHECKING:
21
    import pandas as pd
22
23
24
# Accepted values for the ``split`` argument. ``'valid'`` and ``'validation'``
# are synonyms; the short form is produced by normalization at runtime.
# A single multi-value Literal is the idiomatic (PEP 586) spelling of a
# union of single-value Literals and is equivalent for type checkers.
TypeSplit = Literal['train', 'valid', 'validation']
29
30
31
class CtRate(SubjectsDataset):
    """CT-RATE dataset.

    This class provides access to
    `CT-RATE <https://huggingface.co/datasets/ibrahimhamamci/CT-RATE>`_,
    which contains chest CT scans with associated radiology reports and
    abnormality labels.

    The dataset must have been downloaded previously.

    Args:
        root: Root directory where the dataset has been downloaded.
        split: Dataset split to use, either ``'train'`` or ``'validation'``.
        token: Hugging Face token for accessing gated repositories. Alternatively,
            login using `huggingface-cli login` to cache the token.
        num_subjects: Optional limit on the number of subjects to load (useful for
            testing). If ``None``, all subjects in the split are loaded.
        report_key: Key to use for storing radiology reports in the Subject metadata.
        sizes: List of image sizes (in pixels) to include. Default: [512, 768, 1024].
        **kwargs: Additional arguments for SubjectsDataset.

    Examples:
        >>> dataset = CtRate('/path/to/data', split='train')
    """

    # Hugging Face repository identifier the dataset originates from.
    _REPO_ID = 'ibrahimhamamci/CT-RATE'
    # CSV column holding the NIfTI filename; used as the join key when
    # merging the metadata, reports, and labels tables.
    _FILENAME_KEY = 'VolumeName'
    # Default in-plane image sizes (pixels) kept when filtering metadata.
    _SIZES = [512, 768, 1024]
    # Abnormality label columns provided by the dataset's predicted-labels CSV.
    ABNORMALITIES = [
        'Medical material',
        'Arterial wall calcification',
        'Cardiomegaly',
        'Pericardial effusion',
        'Coronary artery wall calcification',
        'Hiatal hernia',
        'Lymphadenopathy',
        'Emphysema',
        'Atelectasis',
        'Lung nodule',
        'Lung opacity',
        'Pulmonary fibrotic sequela',
        'Pleural effusion',
        'Mosaic attenuation pattern',
        'Peribronchial thickening',
        'Consolidation',
        'Bronchiectasis',
        'Interlobular septal thickening',
    ]
79
80
    def __init__(
        self,
        root: TypePath,
        split: TypeSplit = 'train',
        *,
        token: str | None = None,
        num_subjects: int | None = None,
        report_key: str = 'report',
        sizes: list[int] | None = None,
        **kwargs,
    ) -> None:
        """Store configuration, load metadata and build the subjects list.

        See the class docstring for a description of the arguments.
        """
        self._root_dir = Path(root)
        # NOTE(review): the token is stored but not read anywhere in this
        # class — presumably consumed by download helpers elsewhere; confirm.
        self._token = token
        self._num_subjects = num_subjects
        self._report_key = report_key
        self._sizes = self._SIZES if sizes is None else sizes

        self._split = self._parse_split(split)
        # Reads several CSV files from disk and groups them by subject.
        self.metadata = self._get_metadata()
        subjects_list = self._get_subjects_list(self.metadata)
        super().__init__(subjects_list, **kwargs)
101
102
    @staticmethod
103
    def _parse_split(split: str) -> str:
104
        """Normalize the split name.
105
106
        Converts 'validation' to 'valid' and validates that the split name
107
        is one of the allowed values.
108
109
        Args:
110
            split: The split name to parse ('train', 'valid', or 'validation').
111
112
        Returns:
113
            str: Normalized split name ('train' or 'valid').
114
115
        Raises:
116
            ValueError: If the split name is not one of the allowed values.
117
        """
118
        if split in ['valid', 'validation']:
119
            return 'valid'
120
        if split not in ['train', 'valid']:
121
            raise ValueError(f"Invalid split '{split}'. Use 'train' or 'valid'")
122
        return split
123
124
    def _get_csv(
        self,
        dirname: str,
        filename: str,
    ) -> pd.DataFrame:
        """Read a CSV file stored under ``dataset/<dirname>`` in the root dir.

        Args:
            dirname: Subdirectory of ``dataset/`` containing the CSV file.
            filename: Name of the CSV file to read.
        """
        pd = get_pandas()
        csv_path = self._root_dir / 'dataset' / dirname / filename
        return pd.read_csv(csv_path)
140
141
    def _get_csv_prefix(self, expand_validation: bool = True) -> str:
142
        """Get the prefix for CSV filenames based on the current split.
143
144
        Returns the appropriate prefix for CSV filenames based on the current split.
145
        For the validation split, can either return 'valid' or 'validation' depending
146
        on the expand_validation parameter.
147
148
        Args:
149
            expand_validation: If ``True`` and split is ``'valid'``, return
150
                ``'validation'``. Otherwise, return the split name as is.
151
        """
152
        if expand_validation and self._split == 'valid':
153
            prefix = 'validation'
154
        else:
155
            prefix = self._split
156
        return prefix
157
158
    def _get_metadata(self) -> pd.DataFrame:
        """Load and process the dataset metadata.

        Loads metadata from the appropriate CSV file, filters images by size,
        extracts subject, scan, and reconstruction IDs from filenames, and
        merges in reports and abnormality labels.

        Returns:
            pd.DataFrame: Metadata indexed by (subject_id, scan_id,
            reconstruction_id).
        """
        dirname = 'metadata'
        prefix = self._get_csv_prefix()
        filename = f'{prefix}_metadata.csv'
        metadata = self._get_csv(dirname, filename)

        # Exclude images with size not in self._sizes
        rows_int = metadata['Rows'].astype(int)
        metadata = metadata[rows_int.isin(self._sizes)]

        index_columns = [
            'subject_id',
            'scan_id',
            'reconstruction_id',
        ]
        # Filenames look like e.g. 'train_2_a_1.nii.gz'; the three captured
        # groups become subject_id, scan_id, reconstruction_id (as strings).
        pattern = r'\w+_(\d+)_(\w+)_(\d+)\.nii\.gz'
        metadata[index_columns] = metadata[self._FILENAME_KEY].str.extract(pattern)

        # Subsetting must happen after the ID columns exist but before the
        # merges, so only the kept subjects' rows are joined.
        if self._num_subjects is not None:
            metadata = self._keep_n_subjects(metadata, self._num_subjects)

        # Add reports and abnormality labels to metadata, keeping only the rows for the
        # images in the metadata table
        metadata = self._merge(metadata, self._get_reports())
        metadata = self._merge(metadata, self._get_labels())

        metadata.set_index(index_columns, inplace=True)
        return metadata
192
193
    def _merge(self, base_df: pd.DataFrame, new_df: pd.DataFrame) -> pd.DataFrame:
        """Left-join ``new_df`` onto ``base_df`` using the volume filename.

        Every row of ``base_df`` is preserved; columns from ``new_df`` are
        attached wherever the filename key matches.

        Args:
            base_df: The primary dataframe to merge into.
            new_df: The dataframe with additional columns to attach.

        Returns:
            pd.DataFrame: The merged dataframe.
        """
        pd = get_pandas()
        merged = pd.merge(
            base_df,
            new_df,
            on=self._FILENAME_KEY,
            how='left',
        )
        return merged
215
216
    def _keep_n_subjects(self, metadata: pd.DataFrame, n: int) -> pd.DataFrame:
0 ignored issues
show
introduced by
The variable pd does not seem to be defined in case TYPE_CHECKING on line 20 is False. Are you sure this can never be the case?
Loading history...
217
        """Limit the metadata to the first ``n`` subjects.
218
219
        Args:
220
            metadata: The complete metadata dataframe.
221
            n: Maximum number of subjects to keep.
222
        """
223
        unique_subjects = metadata['subject_id'].unique()
224
        selected_subjects = unique_subjects[:n]
225
        return metadata[metadata['subject_id'].isin(selected_subjects)]
226
227
    def _get_reports(self) -> pd.DataFrame:
        """Load the radiology reports CSV for the current split."""
        prefix = self._get_csv_prefix()
        return self._get_csv(
            'radiology_text_reports',
            f'{prefix}_reports.csv',
        )
237
238
    def _get_labels(self) -> pd.DataFrame:
        """Load the predicted abnormality labels CSV for the current split."""
        # The labels file uses the short 'valid' prefix, unlike the
        # metadata and reports files, which use 'validation'.
        prefix = self._get_csv_prefix(expand_validation=False)
        return self._get_csv(
            'multi_abnormality_labels',
            f'{prefix}_predicted_labels.csv',
        )
248
249
    def _get_subjects_list(self, metadata: pd.DataFrame) -> list[Subject]:
        """Build one Subject per subject ID found in ``metadata``.

        Subject instantiation is fanned out over a thread pool (one task per
        subject) with a progress bar.
        """
        flat = metadata.reset_index()
        groups = flat.groupby('subject_id')
        subjects = thread_map(
            self._get_subject,
            groups,
            max_workers=multiprocessing.cpu_count(),
            total=flat['subject_id'].nunique(),
        )
        return subjects
269
270
    def _get_subject(
        self,
        subject_id_and_metadata: tuple[str, pd.DataFrame],
    ) -> Subject:
        """Build a Subject holding all images belonging to one subject.

        Args:
            subject_id_and_metadata: Tuple of the subject ID and a dataframe
                with one row per image associated with that subject.
        """
        subject_id, images_df = subject_id_and_metadata
        fields: dict[str, str | ScalarImage] = {'subject_id': subject_id}
        for _, row in images_df.iterrows():
            key = f"scan_{row['scan_id']}_reconstruction_{row['reconstruction_id']}"
            fields[key] = self._instantiate_image(row)
        return Subject(**fields)  # type: ignore[arg-type]
292
293
    def _instantiate_image(self, image_row: pd.Series) -> ScalarImage:
        """Create a ScalarImage from one metadata row.

        The English report fields are popped out of the row metadata and
        stored as a nested dict under ``self._report_key``.

        Args:
            image_row: Metadata row describing a single image.
        """
        metadata = image_row.to_dict()
        relative_path = self._get_image_path(metadata[self._FILENAME_KEY])
        metadata[self._report_key] = self._extract_report_dict(metadata)
        return ScalarImage(self._root_dir / relative_path, **metadata)
309
310
    def _extract_report_dict(self, subject_dict: dict[str, str]) -> dict[str, str]:
311
        """Extract radiology report information from the subject dictionary.
312
313
        Extracts the English radiology report components (clinical information,
314
        findings, impressions, and technique) from the subject dictionary and
315
        removes these keys from the original dictionary.
316
317
        Args:
318
            subject_dict: Image metadata including report fields.
319
320
        Note:
321
            This method modifies the input subject_dict by removing the report keys.
322
        """
323
        report_keys = [
324
            'ClinicalInformation_EN',
325
            'Findings_EN',
326
            'Impressions_EN',
327
            'Technique_EN',
328
        ]
329
        report_dict = {}
330
        for key in report_keys:
331
            report_dict[key] = subject_dict.pop(key)
332
        return report_dict
333
334
    @staticmethod
335
    def _get_image_path(filename: str) -> Path:
336
        """Construct the relative path to an image file within the dataset structure.
337
338
        Parses the filename to determine the hierarchical directory structure
339
        where the image is stored in the CT-RATE dataset.
340
341
        Args:
342
            filename: The name of the image file (e.g., 'train_2_a_1.nii.gz').
343
344
        Returns:
345
            Path: The relative path to the image file within the dataset directory.
346
347
        Example:
348
            >>> path = CtRate._get_image_path('train_2_a_1.nii.gz')
349
            # Returns Path('dataset/train/train_2/train_2_a/train_2_a_1.nii.gz')
350
        """
351
        parts = filename.split('_')
352
        base_dir = 'dataset'
353
        split_dir = parts[0]
354
        level1 = f'{parts[0]}_{parts[1]}'
355
        level2 = f'{level1}_{parts[2]}'
356
        return Path(base_dir, split_dir, level1, level2, filename)
357
358
    @staticmethod
    def _fix_image(path: Path, metadata: dict[str, str]) -> None:
        """Fix the spatial metadata of a CT-RATE image file.

        The original NIfTI files in the CT-RATE dataset have incorrect spatial
        metadata. This method reads the image, fixes the spacing, origin, and
        orientation based on the metadata provided in the CSV, and applies the
        correct rescaling to convert to Hounsfield units.

        Args:
            path: The path to the image file to fix.
            metadata: A dictionary containing image metadata including spacing,
                orientation, and rescale parameters. Values may be strings (as
                read from the CSV); numeric fields are converted here.

        Note:
            This method overwrites the original file with the fixed version.
            The fixed image is stored as INT16 with proper HU values.
        """
        # Adapted from https://huggingface.co/datasets/ibrahimhamamci/CT-RATE/blob/main/download_scripts/fix_metadata.py
        image = sitk.ReadImage(str(path))

        # XYSpacing is serialized as a list literal, e.g. '[0.75, 0.75]'.
        spacing_x, spacing_y = map(float, ast.literal_eval(metadata['XYSpacing']))
        # Convert explicitly: the declared value type is str, and
        # SetSpacing requires a tuple of floats.
        spacing_z = float(metadata['ZSpacing'])
        image.SetSpacing((spacing_x, spacing_y, spacing_z))

        image.SetOrigin(ast.literal_eval(metadata['ImagePositionPatient']))

        # Direction matrix: DICOM row/column cosines, with the slice (z)
        # direction computed as their cross product.
        orientation = ast.literal_eval(metadata['ImageOrientationPatient'])
        row_cosine, col_cosine = orientation[:3], orientation[3:6]
        z_cosine = np.cross(row_cosine, col_cosine).tolist()
        image.SetDirection(row_cosine + col_cosine + z_cosine)

        # Apply the DICOM rescale to obtain Hounsfield units; convert the
        # CSV string values to numbers before the arithmetic.
        rescale_intercept = float(metadata['RescaleIntercept'])
        rescale_slope = float(metadata['RescaleSlope'])
        adjusted_hu = image * rescale_slope + rescale_intercept
        cast_int16 = sitk.Cast(adjusted_hu, sitk.sitkInt16)

        sitk.WriteImage(cast_int16, str(path))
396