Passed
Pull Request — main (#1278)
by Fernando
01:31
created

torchio.datasets.ct_rate   A

Complexity

Total Complexity 29

Size/Duplication

Total Lines 470
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 219
dl 0
loc 470
rs 10
c 0
b 0
f 0
wmc 29

17 Methods

Rating   Name   Duplication   Size   Complexity  
A CtRate._parse_split() 0 21 3
A CtRate._get_image_path() 0 23 1
A CtRate._get_subjects_list() 0 20 1
A CtRate._download_file() 0 36 2
A CtRate._get_csv() 0 21 2
A CtRate._keep_n_subjects() 0 10 1
A CtRate.__init__() 0 23 2
A CtRate._download_file_if_needed() 0 17 2
A CtRate._get_metadata() 0 34 2
A CtRate._extract_report_dict() 0 23 2
A CtRate._instantiate_image() 0 24 2
A CtRate._get_labels() 0 10 1
A CtRate._merge() 0 21 1
A CtRate._get_subject() 0 22 2
A CtRate._get_csv_prefix() 0 16 3
A CtRate._get_reports() 0 10 1
A CtRate._fix_image() 0 38 1
1
from __future__ import annotations
2
3
import ast
4
import multiprocessing
5
from pathlib import Path
6
from typing import TYPE_CHECKING
7
from typing import Literal
8
from typing import Optional
9
from typing import Union
10
11
import numpy as np
12
import SimpleITK as sitk
13
from tqdm.contrib.concurrent import thread_map
14
15
from ..data.dataset import SubjectsDataset
16
from ..data.image import ScalarImage
17
from ..data.subject import Subject
18
from ..external.imports import get_huggingface_hub
19
from ..external.imports import get_pandas
20
from ..types import TypePath
21
22
if TYPE_CHECKING:
23
    import pandas as pd
24
25
26
# Accepted spellings of the dataset split name.
TypeSplit = Literal['train', 'valid', 'validation']
31
32
33
class CtRate(SubjectsDataset):
34
    """CT-RATE dataset.
35
36
    This class provides access to
37
    `CT-RATE <https://huggingface.co/datasets/ibrahimhamamci/CT-RATE>`_,
38
    which contains chest CT scans with associated radiology reports and
39
    abnormality labels. The dataset can be automatically downloaded from
40
    Hugging Face if needed.
41
42
    Args:
43
        root: Root directory where the dataset is stored or will be downloaded to.
44
        split: Dataset split to use: ``'train'`` or ``'validation'``
            (``'valid'`` is accepted as an alias of ``'validation'``).
45
        token: Hugging Face token for accessing gated repositories. Alternatively,
46
            login using `huggingface-cli login` to cache the token.
47
        download: If True, download the dataset if files are not found locally.
48
        num_subjects: Optional limit on the number of subjects to load (useful for
49
            testing). If ``None``, all subjects in the split are loaded.
50
        report_key: Key to use for storing radiology reports in the Subject metadata.
51
        sizes: List of image sizes (in pixels) to include. Default: [512, 768, 1024].
52
        **kwargs: Additional arguments for SubjectsDataset.
53
54
    Examples:
55
        >>> dataset = CtRate('/path/to/data', split='train', download=True)
56
    """
57
58
    _REPO_ID = 'ibrahimhamamci/CT-RATE'
59
    _FILENAME_KEY = 'VolumeName'
60
    _SIZES = [512, 768, 1024]
61
    ABNORMALITIES = [
62
        'Medical material',
63
        'Arterial wall calcification',
64
        'Cardiomegaly',
65
        'Pericardial effusion',
66
        'Coronary artery wall calcification',
67
        'Hiatal hernia',
68
        'Lymphadenopathy',
69
        'Emphysema',
70
        'Atelectasis',
71
        'Lung nodule',
72
        'Lung opacity',
73
        'Pulmonary fibrotic sequela',
74
        'Pleural effusion',
75
        'Mosaic attenuation pattern',
76
        'Peribronchial thickening',
77
        'Consolidation',
78
        'Bronchiectasis',
79
        'Interlobular septal thickening',
80
    ]
81
82
    def __init__(
        self,
        root: TypePath,
        split: TypeSplit = 'train',
        *,
        token: str | None = None,
        download: bool = False,
        num_subjects: int | None = None,
        report_key: str = 'report',
        sizes: list[int] | None = None,
        **kwargs,
    ):
        # Store the configuration first: the helpers called below read it
        self._root_dir = Path(root)
        self._token = token
        self._download = download
        self._num_subjects = num_subjects
        self._report_key = report_key
        # Copy the list so that this instance never aliases the mutable class-level
        # default (or the caller's list): mutating one must not affect the other
        self._sizes = list(self._SIZES if sizes is None else sizes)

        # Validate the split before any file is read or downloaded
        self._split = self._parse_split(split)
        self.metadata = self._get_metadata()
        subjects_list = self._get_subjects_list(self.metadata)
        super().__init__(subjects_list, **kwargs)
105
106
    @staticmethod
107
    def _parse_split(split: str) -> str:
108
        """Normalize the split name.
109
110
        Converts 'validation' to 'valid' and validates that the split name
111
        is one of the allowed values.
112
113
        Args:
114
            split: The split name to parse ('train', 'valid', or 'validation').
115
116
        Returns:
117
            str: Normalized split name ('train' or 'valid').
118
119
        Raises:
120
            ValueError: If the split name is not one of the allowed values.
121
        """
122
        if split in ['valid', 'validation']:
123
            return 'valid'
124
        if split not in ['train', 'valid']:
125
            raise ValueError(f"Invalid split '{split}'. Use 'train' or 'valid'")
126
        return split
127
128
    def _get_csv(
        self,
        dirname: str,
        filename: str,
    ) -> pd.DataFrame:
        """Return the contents of one of the dataset CSV files as a table.

        The file is looked up at ``<root>/dataset/<dirname>/<filename>``. If it is
        not there, ``_download_file_if_needed`` is asked to provide it before the
        file is read.

        Args:
            dirname: Directory name within 'dataset/' where the CSV is located.
            filename: Name of the CSV file to load.

        Returns:
            The table parsed from the CSV file.
        """
        csv_path = self._root_dir / f'dataset/{dirname}' / filename
        if not csv_path.exists():
            self._download_file_if_needed(csv_path)
        return get_pandas().read_csv(csv_path)
149
150
    def _get_csv_prefix(self, expand_validation: bool = True) -> str:
151
        """Get the prefix for CSV filenames based on the current split.
152
153
        Returns the appropriate prefix for CSV filenames based on the current split.
154
        For the validation split, can either return 'valid' or 'validation' depending
155
        on the expand_validation parameter.
156
157
        Args:
158
            expand_validation: If ``True`` and split is ``'valid'``, return
159
                ``'validation'``. Otherwise, return the split name as is.
160
        """
161
        if expand_validation and self._split == 'valid':
162
            prefix = 'validation'
163
        else:
164
            prefix = self._split
165
        return prefix
166
167
    def _get_metadata(self) -> pd.DataFrame:
        """Load and process the dataset metadata.

        Loads metadata from the appropriate CSV file, filters images by size,
        extracts subject, scan, and reconstruction IDs from filenames, and
        merges in reports and abnormality labels.

        Returns:
            The metadata table, indexed by subject, scan and reconstruction IDs.
        """
        filename = f'{self._get_csv_prefix()}_metadata.csv'
        metadata = self._get_csv('metadata', filename)

        # Exclude images with size not in self._sizes. The filtered rows are copied
        # so that the column assignment below writes to an independent table and
        # not to a slice of the loaded one (which raises SettingWithCopyWarning)
        rows_int = metadata['Rows'].astype(int)
        metadata = metadata[rows_int.isin(self._sizes)].copy()

        index_columns = [
            'subject_id',
            'scan_id',
            'reconstruction_id',
        ]
        # One capture group per index column, in the same order
        pattern = r'\w+_(\d+)_(\w+)_(\d+)\.nii\.gz'
        metadata[index_columns] = metadata[self._FILENAME_KEY].str.extract(pattern)

        if self._num_subjects is not None:
            metadata = self._keep_n_subjects(metadata, self._num_subjects)

        # Add reports and abnormality labels to metadata, keeping only the rows for the
        # images in the metadata table
        metadata = self._merge(metadata, self._get_reports())
        metadata = self._merge(metadata, self._get_labels())

        return metadata.set_index(index_columns)
201
202
    def _merge(self, base_df: pd.DataFrame, new_df: pd.DataFrame) -> pd.DataFrame:
        """Attach the columns of ``new_df`` to ``base_df``, matching rows by filename.

        The two tables are left-joined on the volume filename column, so every row
        of ``base_df`` is kept and rows found only in ``new_df`` are discarded.

        Args:
            base_df: The primary dataframe to merge into.
            new_df: The dataframe containing additional data to be merged.

        Returns:
            pd.DataFrame: The merged dataframe with all rows from base_df and
            matching columns from new_df.
        """
        pandas = get_pandas()
        merged = pandas.merge(
            left=base_df,
            right=new_df,
            how='left',
            on=self._FILENAME_KEY,
        )
        return merged
224
225
    def _keep_n_subjects(self, metadata: pd.DataFrame, n: int) -> pd.DataFrame:
0 ignored issues
show
introduced by
The variable pd does not seem to be defined in case TYPE_CHECKING on line 22 is False. Are you sure this can never be the case?
Loading history...
226
        """Limit the metadata to the first ``n`` subjects.
227
228
        Args:
229
            metadata: The complete metadata dataframe.
230
            n: Maximum number of subjects to keep.
231
        """
232
        unique_subjects = metadata['subject_id'].unique()
233
        selected_subjects = unique_subjects[:n]
234
        return metadata[metadata['subject_id'].isin(selected_subjects)]
235
236
    def _get_reports(self) -> pd.DataFrame:
        """Load the table of radiology reports for the current split.

        The file is ``<prefix>_reports.csv`` in the ``radiology_text_reports``
        directory, where the prefix is the expanded split name.
        """
        csv_name = f'{self._get_csv_prefix()}_reports.csv'
        return self._get_csv('radiology_text_reports', csv_name)
246
247
    def _get_labels(self) -> pd.DataFrame:
        """Load the table of predicted abnormality labels for the current split.

        The file is ``<prefix>_predicted_labels.csv`` in the
        ``multi_abnormality_labels`` directory. Unlike the other CSV files, the
        prefix here is the non-expanded split name.
        """
        short_prefix = self._get_csv_prefix(expand_validation=False)
        return self._get_csv(
            'multi_abnormality_labels',
            f'{short_prefix}_predicted_labels.csv',
        )
257
258
    def _download_file_if_needed(self, path: Path) -> None:
259
        """Download a file if it does not exist locally and ``download`` is enabled.
260
261
        Checks if the specified file exists at the given path. If not, and ``download``
262
        is enabled, it downloads the file; otherwise, it raises an error.
263
264
        Args:
265
            path: The local file path to check and potentially download to.
266
267
        Raises:
268
            FileNotFoundError: If the file doesn't exist and download=False.
269
        """
270
        if self._download:
271
            self._download_file(path)
272
        else:
273
            raise FileNotFoundError(
274
                f'File "{path}" not found.'
275
                " Set 'download=True' to download the dataset"
276
            )
277
278
    def _download_file(self, path: Path) -> None:
        """Fetch one file of the CT-RATE repository from Hugging Face.

        The location of ``path`` relative to the root directory is used as the
        location of the file inside the repository, and the root directory is
        passed as the local download directory.

        Args:
            path: The destination path where the file will be saved.

        Raises:
            RuntimeError: If the repository is gated and no valid token is provided.

        Note:
            This method requires access to the CT-RATE repository. If the repository
            is gated, a valid Hugging Face token with appropriate permissions must
            be provided, or the user must log in using the Hugging Face CLI.
        """
        path_in_repo = path.relative_to(self._root_dir)
        hub = get_huggingface_hub()
        repo_subfolder = str(path_in_repo.parent)
        try:
            hub.hf_hub_download(
                repo_id=self._REPO_ID,
                repo_type='dataset',
                token=self._token,
                subfolder=repo_subfolder,
                filename=path_in_repo.name,
                local_dir=self._root_dir,
            )
        except hub.errors.GatedRepoError as gated_error:
            # Turn the hub error into an actionable message for the user
            instructions = (
                f'The dataset "{self._REPO_ID}" is gated. Visit'
                f' https://huggingface.co/datasets/{self._REPO_ID}, accept the'
                ' terms and conditions, and log in or create and pass a token to'
                ' the `token` argument'
            )
            raise RuntimeError(instructions) from gated_error
314
315
    def _get_subjects_list(self, metadata: pd.DataFrame) -> list[Subject]:
        """Build one Subject instance per subject found in the metadata.

        The rows of the metadata table are grouped by ``subject_id`` and each
        group is turned into a Subject by ``_get_subject``.

        Note:
            The groups are processed with a pool of threads (one worker per CPU
            reported by ``multiprocessing.cpu_count``) and a progress bar.
        """
        flat_metadata = metadata.reset_index()
        rows_per_subject = flat_metadata.groupby('subject_id')
        return thread_map(
            self._get_subject,
            rows_per_subject,
            max_workers=multiprocessing.cpu_count(),
            total=flat_metadata['subject_id'].nunique(),
        )
335
336
    def _get_subject(
        self,
        subject_id_and_metadata: tuple[str, pd.DataFrame],
    ) -> Subject:
        """Build the Subject that holds all the images of one subject.

        Each row of the subject's table becomes one image, stored under the key
        ``scan_<scan_id>_reconstruction_<reconstruction_id>``.

        Args:
            subject_id_and_metadata: A tuple containing the subject ID (string) and a
                DataFrame containing metadata for all images associated to that subject.
        """
        subject_id, images_df = subject_id_and_metadata
        images: dict[str, ScalarImage] = {}
        for _, row in images_df.iterrows():
            image = self._instantiate_image(row)
            key = f'scan_{row["scan_id"]}_reconstruction_{row["reconstruction_id"]}'
            images[key] = image
        return Subject(subject_id=subject_id, **images)  # type: ignore[arg-type]
358
359
    def _instantiate_image(self, image_row: pd.Series) -> ScalarImage:
        """Build the ScalarImage described by one row of the metadata table.

        The fields of the row become metadata of the image, with the report
        fields grouped in a single dictionary stored under the report key.

        Args:
            image_row: A pandas Series representing a row from the metadata DataFrame,
                containing information about a single image.

        Note:
            If the image file doesn't exist locally, it is requested through
            ``_download_file_if_needed`` and its metadata is then fixed on disk.
            Files that are already present are used as they are.
        """
        fields = image_row.to_dict()
        image_path = self._root_dir / self._get_image_path(fields[self._FILENAME_KEY])
        if not image_path.exists():
            self._download_file_if_needed(image_path)
            self._fix_image(image_path, fields)
        # The report fields are moved out of the flat metadata into their own dict
        fields[self._report_key] = self._extract_report_dict(fields)
        return ScalarImage(image_path, **fields)
383
384
    def _extract_report_dict(self, subject_dict: dict[str, str]) -> dict[str, str]:
385
        """Extract radiology report information from the subject dictionary.
386
387
        Extracts the English radiology report components (clinical information,
388
        findings, impressions, and technique) from the subject dictionary and
389
        removes these keys from the original dictionary.
390
391
        Args:
392
            subject_dict: Image metadata including report fields.
393
394
        Note:
395
            This method modifies the input subject_dict by removing the report keys.
396
        """
397
        report_keys = [
398
            'ClinicalInformation_EN',
399
            'Findings_EN',
400
            'Impressions_EN',
401
            'Technique_EN',
402
        ]
403
        report_dict = {}
404
        for key in report_keys:
405
            report_dict[key] = subject_dict.pop(key)
406
        return report_dict
407
408
    @staticmethod
409
    def _get_image_path(filename: str) -> Path:
410
        """Construct the relative path to an image file within the dataset structure.
411
412
        Parses the filename to determine the hierarchical directory structure
413
        where the image is stored in the CT-RATE dataset.
414
415
        Args:
416
            filename: The name of the image file (e.g., 'train_2_a_1.nii.gz').
417
418
        Returns:
419
            Path: The relative path to the image file within the dataset directory.
420
421
        Example:
422
            >>> path = CtRate._get_image_path('train_2_a_1.nii.gz')
423
            # Returns Path('dataset/train/train_2/train_2_a/train_2_a_1.nii.gz')
424
        """
425
        parts = filename.split('_')
426
        base_dir = 'dataset'
427
        split_dir = parts[0]
428
        level1 = f'{parts[0]}_{parts[1]}'
429
        level2 = f'{level1}_{parts[2]}'
430
        return Path(base_dir, split_dir, level1, level2, filename)
431
432
    @staticmethod
    def _fix_image(path: Path, metadata: dict[str, str]) -> None:
        """Fix the metadata of a downloaded image file.

        The original NIfTI files in the CT-RATE dataset have incorrect spatial
        metadata. This method reads the image, fixes the spacing, origin, and
        orientation based on the metadata provided in the CSV, and applies the correct
        rescaling to convert to Hounsfield units.

        Args:
            path: The path to the image file to fix.
            metadata: A dictionary containing image metadata including spacing,
                orientation, and rescale parameters.

        Note:
            This method overwrites the original file with the fixed version.
            The fixed image is stored as INT16 with proper HU values.
        """
        # Adapted from https://huggingface.co/datasets/ibrahimhamamci/CT-RATE/blob/main/download_scripts/fix_metadata.py
        image = sitk.ReadImage(str(path))

        # All three spacings are coerced to float, whatever type the table gave them
        spacing_x, spacing_y = map(float, ast.literal_eval(metadata['XYSpacing']))
        spacing_z = float(metadata['ZSpacing'])
        image.SetSpacing((spacing_x, spacing_y, spacing_z))

        image.SetOrigin(ast.literal_eval(metadata['ImagePositionPatient']))

        # Force a list so that the slices can be concatenated with the list of
        # z cosines below even if the literal is written as a tuple
        orientation = list(ast.literal_eval(metadata['ImageOrientationPatient']))
        row_cosine, col_cosine = orientation[:3], orientation[3:6]
        # The third axis is the cross product of the row and column cosines
        z_cosine = np.cross(row_cosine, col_cosine).tolist()
        image.SetDirection(row_cosine + col_cosine + z_cosine)

        rescale_slope = float(metadata['RescaleSlope'])
        rescale_intercept = float(metadata['RescaleIntercept'])
        adjusted_hu = image * rescale_slope + rescale_intercept
        cast_int16 = sitk.Cast(adjusted_hu, sitk.sitkInt16)

        sitk.WriteImage(cast_int16, str(path))
470