Passed
Pull Request — main (#1278)
by Fernando
02:51 queued 01:03
created

torchio.datasets.ct_rate   A

Complexity

Total Complexity 29

Size/Duplication

Total Lines 469
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 29
eloc 218
dl 0
loc 469
rs 10
c 0
b 0
f 0

17 Methods

Rating   Name   Duplication   Size   Complexity  
A CtRate._parse_split() 0 21 3
A CtRate._get_image_path() 0 23 1
A CtRate._get_subjects_list() 0 20 1
A CtRate._download_file() 0 36 2
A CtRate._get_csv() 0 21 2
A CtRate._keep_n_subjects() 0 10 1
A CtRate.__init__() 0 23 2
A CtRate._download_file_if_needed() 0 17 2
A CtRate._get_metadata() 0 34 2
A CtRate._extract_report_dict() 0 23 2
A CtRate._instantiate_image() 0 24 2
A CtRate._get_labels() 0 10 1
A CtRate._merge() 0 21 1
A CtRate._get_subject() 0 22 2
A CtRate._get_csv_prefix() 0 16 3
A CtRate._get_reports() 0 10 1
A CtRate._fix_image() 0 38 1
1
from __future__ import annotations
2
3
import ast
4
import multiprocessing
5
from pathlib import Path
6
from typing import TYPE_CHECKING
7
from typing import Literal
8
from typing import Union
9
10
import numpy as np
11
import SimpleITK as sitk
12
from tqdm.contrib.concurrent import thread_map
13
14
from ..data.dataset import SubjectsDataset
15
from ..data.image import ScalarImage
16
from ..data.subject import Subject
17
from ..external.imports import get_huggingface_hub
18
from ..external.imports import get_pandas
19
from ..types import TypePath
20
21
if TYPE_CHECKING:
22
    import pandas as pd
23
24
25
# Accepted values for the ``split`` argument. Per the typing spec,
# ``Literal['a', 'b']`` is equivalent to ``Union[Literal['a'], Literal['b']]``,
# so the flattened form replaces the redundant Union of single-value Literals.
# Note that 'validation' is normalized to 'valid' by ``CtRate._parse_split``.
TypeSplit = Literal['train', 'valid', 'validation']
30
31
32
class CtRate(SubjectsDataset):
    """CT-RATE dataset.

    This class provides access to
    `CT-RATE <https://huggingface.co/datasets/ibrahimhamamci/CT-RATE>`_,
    which contains chest CT scans with associated radiology reports and
    abnormality labels. The dataset can be automatically downloaded from
    Hugging Face if needed.

    Args:
        root: Root directory where the dataset is stored or will be downloaded to.
        split: Dataset split to use, either ``'train'`` or ``'validation'``.
        token: Hugging Face token for accessing gated repositories. Alternatively,
            login using `huggingface-cli login` to cache the token.
        download: If True, download the dataset if files are not found locally.
        num_subjects: Optional limit on the number of subjects to load (useful for
            testing). If ``None``, all subjects in the split are loaded.
        report_key: Key to use for storing radiology reports in the Subject metadata.
        sizes: List of image sizes (in pixels) to include. Default: [512, 768, 1024].
        **kwargs: Additional arguments for SubjectsDataset.

    Examples:
        >>> dataset = CtRate('/path/to/data', split='train', download=True)
    """

    _REPO_ID = 'ibrahimhamamci/CT-RATE'
    # CSV column holding the NIfTI filename; used as the join key everywhere
    _FILENAME_KEY = 'VolumeName'
    # In-plane sizes (in pixels) present in the dataset
    _SIZES = [512, 768, 1024]
    ABNORMALITIES = [
        'Medical material',
        'Arterial wall calcification',
        'Cardiomegaly',
        'Pericardial effusion',
        'Coronary artery wall calcification',
        'Hiatal hernia',
        'Lymphadenopathy',
        'Emphysema',
        'Atelectasis',
        'Lung nodule',
        'Lung opacity',
        'Pulmonary fibrotic sequela',
        'Pleural effusion',
        'Mosaic attenuation pattern',
        'Peribronchial thickening',
        'Consolidation',
        'Bronchiectasis',
        'Interlobular septal thickening',
    ]

    def __init__(
        self,
        root: TypePath,
        split: TypeSplit = 'train',
        *,
        token: str | None = None,
        download: bool = False,
        num_subjects: int | None = None,
        report_key: str = 'report',
        sizes: list[int] | None = None,
        **kwargs,
    ):
        self._root_dir = Path(root)
        self._token = token
        self._download = download
        self._num_subjects = num_subjects
        self._report_key = report_key
        self._sizes = self._SIZES if sizes is None else sizes

        self._split = self._parse_split(split)
        self.metadata = self._get_metadata()
        subjects_list = self._get_subjects_list(self.metadata)
        super().__init__(subjects_list, **kwargs)

    @staticmethod
    def _parse_split(split: str) -> str:
        """Normalize the split name.

        Converts 'validation' to 'valid' and validates that the split name
        is one of the allowed values.

        Args:
            split: The split name to parse ('train', 'valid', or 'validation').

        Returns:
            str: Normalized split name ('train' or 'valid').

        Raises:
            ValueError: If the split name is not one of the allowed values.
        """
        if split in ('valid', 'validation'):
            return 'valid'
        # 'valid'/'validation' returned above, so only 'train' remains valid here
        if split != 'train':
            raise ValueError(f"Invalid split '{split}'. Use 'train' or 'valid'")
        return split

    def _get_csv(
        self,
        dirname: str,
        filename: str,
    ) -> pd.DataFrame:
        """Download (if needed) and load a CSV file from the dataset.

        Load a CSV file from the specified directory within the dataset.
        If the file doesn't exist and download is enabled, download it.

        Args:
            dirname: Directory name within 'dataset/' where the CSV is located.
            filename: Name of the CSV file to load.
        """
        subfolder = Path(f'dataset/{dirname}')
        path = Path(self._root_dir, subfolder, filename)
        if not path.exists():
            self._download_file_if_needed(path)
        pd = get_pandas()
        table = pd.read_csv(path)
        return table

    def _get_csv_prefix(self, expand_validation: bool = True) -> str:
        """Get the prefix for CSV filenames based on the current split.

        Returns the appropriate prefix for CSV filenames based on the current split.
        For the validation split, can either return 'valid' or 'validation' depending
        on the expand_validation parameter.

        Args:
            expand_validation: If ``True`` and split is ``'valid'``, return
                ``'validation'``. Otherwise, return the split name as is.
        """
        if expand_validation and self._split == 'valid':
            prefix = 'validation'
        else:
            prefix = self._split
        return prefix

    def _get_metadata(self) -> pd.DataFrame:
        """Load and process the dataset metadata.

        Loads metadata from the appropriate CSV file, filters images by size,
        extracts subject, scan, and reconstruction IDs from filenames, and
        merges in reports and abnormality labels.
        """
        dirname = 'metadata'
        prefix = self._get_csv_prefix()
        filename = f'{prefix}_metadata.csv'
        metadata = self._get_csv(dirname, filename)

        # Exclude images with size not in self._sizes
        rows_int = metadata['Rows'].astype(int)
        metadata = metadata[rows_int.isin(self._sizes)]

        index_columns = [
            'subject_id',
            'scan_id',
            'reconstruction_id',
        ]
        # Filenames look like 'train_2_a_1.nii.gz': split_subject_scan_recon
        pattern = r'\w+_(\d+)_(\w+)_(\d+)\.nii\.gz'
        metadata[index_columns] = metadata[self._FILENAME_KEY].str.extract(pattern)

        if self._num_subjects is not None:
            metadata = self._keep_n_subjects(metadata, self._num_subjects)

        # Add reports and abnormality labels to metadata, keeping only the rows for the
        # images in the metadata table
        metadata = self._merge(metadata, self._get_reports())
        metadata = self._merge(metadata, self._get_labels())

        metadata.set_index(index_columns, inplace=True)
        return metadata

    def _merge(self, base_df: pd.DataFrame, new_df: pd.DataFrame) -> pd.DataFrame:
        """Merge a new dataframe into the base dataframe using the filename as the key.

        This method performs a left join between ``base_df`` and ``new_df`` using the
        volume filename as the join key, ensuring that all records from ``base_df`` are
        preserved while matching data from ``new_df`` is added.

        Args:
            base_df: The primary dataframe to merge into.
            new_df: The dataframe containing additional data to be merged.

        Returns:
            pd.DataFrame: The merged dataframe with all rows from base_df and
            matching columns from new_df.
        """
        pd = get_pandas()
        return pd.merge(
            base_df,
            new_df,
            on=self._FILENAME_KEY,
            how='left',
        )

    def _keep_n_subjects(self, metadata: pd.DataFrame, n: int) -> pd.DataFrame:
        """Limit the metadata to the first ``n`` subjects.

        Args:
            metadata: The complete metadata dataframe.
            n: Maximum number of subjects to keep.
        """
        unique_subjects = metadata['subject_id'].unique()
        selected_subjects = unique_subjects[:n]
        return metadata[metadata['subject_id'].isin(selected_subjects)]

    def _get_reports(self) -> pd.DataFrame:
        """Load the radiology reports associated with the CT scans.

        Retrieves the CSV file containing radiology reports for the current split
        (train or validation).
        """
        dirname = 'radiology_text_reports'
        prefix = self._get_csv_prefix()
        filename = f'{prefix}_reports.csv'
        return self._get_csv(dirname, filename)

    def _get_labels(self) -> pd.DataFrame:
        """Load the abnormality labels for the CT scans.

        Retrieves the CSV file containing predicted abnormality labels for the
        current split.
        """
        dirname = 'multi_abnormality_labels'
        # The labels CSV uses the 'valid' prefix, not 'validation'
        prefix = self._get_csv_prefix(expand_validation=False)
        filename = f'{prefix}_predicted_labels.csv'
        return self._get_csv(dirname, filename)

    def _download_file_if_needed(self, path: Path) -> None:
        """Download a missing file, or raise if downloading is disabled.

        Callers invoke this after determining that ``path`` does not exist
        locally. If ``download`` was enabled at construction time, the file is
        fetched from the Hugging Face repository; otherwise an error is raised.

        Args:
            path: The local file path to download to.

        Raises:
            FileNotFoundError: If download=False.
        """
        if self._download:
            self._download_file(path)
        else:
            raise FileNotFoundError(
                f'File "{path}" not found.'
                " Set 'download=True' to download the dataset"
            )

    def _download_file(self, path: Path) -> None:
        """Download a file from the Hugging Face repository to the specified path.

        Downloads a single file from the CT-RATE dataset repository on Hugging Face,
        preserving the directory structure.

        Args:
            path: The destination path where the file will be saved.

        Raises:
            RuntimeError: If the repository is gated and no valid token is provided.

        Note:
            This method requires access to the CT-RATE repository. If the repository
            is gated, a valid Hugging Face token with appropriate permissions must
            be provided, or the user must log in using the Hugging Face CLI.
        """
        relative_path = path.relative_to(self._root_dir)
        huggingface_hub = get_huggingface_hub()
        try:
            huggingface_hub.hf_hub_download(
                repo_id=self._REPO_ID,
                repo_type='dataset',
                token=self._token,
                subfolder=str(relative_path.parent),
                filename=relative_path.name,
                local_dir=self._root_dir,
            )
        except huggingface_hub.errors.GatedRepoError as e:
            message = (
                f'The dataset "{self._REPO_ID}" is gated. Visit'
                f' https://huggingface.co/datasets/{self._REPO_ID}, accept the'
                ' terms and conditions, and log in or create and pass a token to'
                ' the `token` argument'
            )
            raise RuntimeError(message) from e

    def _get_subjects_list(self, metadata: pd.DataFrame) -> list[Subject]:
        """Create a list of Subject instances from the metadata.

        Processes the metadata to create Subject objects, each containing one or more
        CT images. Processing is performed in parallel.

        Note:
            This method uses parallelization to improve performance when creating
            multiple Subject instances.
        """
        df_no_index = metadata.reset_index()
        num_subjects = df_no_index['subject_id'].nunique()
        iterable = df_no_index.groupby('subject_id')
        subjects = thread_map(
            self._get_subject,
            iterable,
            max_workers=multiprocessing.cpu_count(),
            total=num_subjects,
        )
        return subjects

    def _get_subject(
        self,
        subject_id_and_metadata: tuple[str, pd.DataFrame],
    ) -> Subject:
        """Create a Subject instance for a specific subject.

        Processes all images belonging to a single subject and creates a Subject
        object containing those images.

        Args:
            subject_id_and_metadata: A tuple containing the subject ID (string) and a
                DataFrame containing metadata for all images associated to that subject.
        """
        subject_id, subject_df = subject_id_and_metadata
        subject_dict: dict[str, str | ScalarImage] = {'subject_id': subject_id}
        for _, image_row in subject_df.iterrows():
            image = self._instantiate_image(image_row)
            scan_id = image_row['scan_id']
            reconstruction_id = image_row['reconstruction_id']
            image_key = f'scan_{scan_id}_reconstruction_{reconstruction_id}'
            subject_dict[image_key] = image
        return Subject(**subject_dict)  # type: ignore[arg-type]

    def _instantiate_image(self, image_row: pd.Series) -> ScalarImage:
        """Create a ScalarImage object for a specific image.

        Processes a row from the metadata DataFrame to create a ScalarImage object,
        downloading the image if necessary and extracting the radiology report.

        Args:
            image_row: A pandas Series representing a row from the metadata DataFrame,
                containing information about a single image.

        Note:
            If the image file doesn't exist locally and download is enabled, this
            method will download the file and fix the metadata.
        """
        image_dict = image_row.to_dict()
        filename = image_dict[self._FILENAME_KEY]
        image_path = self._root_dir / self._get_image_path(filename)
        if not image_path.exists():
            self._download_file_if_needed(image_path)
            # Freshly downloaded files have incorrect spatial metadata
            self._fix_image(image_path, image_dict)
        report_dict = self._extract_report_dict(image_dict)
        image_dict[self._report_key] = report_dict
        image = ScalarImage(image_path, **image_dict)
        return image

    def _extract_report_dict(self, subject_dict: dict[str, str]) -> dict[str, str]:
        """Extract radiology report information from the subject dictionary.

        Extracts the English radiology report components (clinical information,
        findings, impressions, and technique) from the subject dictionary and
        removes these keys from the original dictionary.

        Args:
            subject_dict: Image metadata including report fields.

        Note:
            This method modifies the input subject_dict by removing the report keys.
        """
        report_keys = [
            'ClinicalInformation_EN',
            'Findings_EN',
            'Impressions_EN',
            'Technique_EN',
        ]
        # pop() both collects the report fields and strips them from the metadata
        return {key: subject_dict.pop(key) for key in report_keys}

    @staticmethod
    def _get_image_path(filename: str) -> Path:
        """Construct the relative path to an image file within the dataset structure.

        Parses the filename to determine the hierarchical directory structure
        where the image is stored in the CT-RATE dataset.

        Args:
            filename: The name of the image file (e.g., 'train_2_a_1.nii.gz').

        Returns:
            Path: The relative path to the image file within the dataset directory.

        Example:
            >>> path = CtRate._get_image_path('train_2_a_1.nii.gz')
            # Returns Path('dataset/train/train_2/train_2_a/train_2_a_1.nii.gz')
        """
        parts = filename.split('_')
        base_dir = 'dataset'
        split_dir = parts[0]
        level1 = f'{parts[0]}_{parts[1]}'
        level2 = f'{level1}_{parts[2]}'
        return Path(base_dir, split_dir, level1, level2, filename)

    @staticmethod
    def _fix_image(path: Path, metadata: dict[str, str]) -> None:
        """Fix the metadata of a downloaded image file.

        The original NIfTI files in the CT-RATE dataset have incorrect spatial
        metadata. This method reads the image, fixes the spacing, origin, and
        orientation based on the metadata provided in the CSV, and applies the correct
        rescaling to convert to Hounsfield units.

        Args:
            path: The path to the image file to fix.
            metadata: A dictionary containing image metadata including spacing,
                orientation, and rescale parameters.

        Note:
            This method overwrites the original file with the fixed version.
            The fixed image is stored as INT16 with proper HU values.
        """
        # Adapted from https://huggingface.co/datasets/ibrahimhamamci/CT-RATE/blob/main/download_scripts/fix_metadata.py
        image = sitk.ReadImage(str(path))

        spacing_x, spacing_y = map(float, ast.literal_eval(metadata['XYSpacing']))
        # Cast explicitly: values read from the CSV row may be strings (as
        # annotated) or numpy scalars; SetSpacing needs plain floats
        spacing_z = float(metadata['ZSpacing'])
        image.SetSpacing((spacing_x, spacing_y, spacing_z))

        image.SetOrigin(ast.literal_eval(metadata['ImagePositionPatient']))

        orientation = ast.literal_eval(metadata['ImageOrientationPatient'])
        row_cosine, col_cosine = orientation[:3], orientation[3:6]
        # The third direction cosine is the cross product of the first two
        z_cosine = np.cross(row_cosine, col_cosine).tolist()
        image.SetDirection(row_cosine + col_cosine + z_cosine)

        # Apply the DICOM linear rescale to convert stored values to
        # Hounsfield units; cast for the same reason as the spacing above
        rescale_intercept = float(metadata['RescaleIntercept'])
        rescale_slope = float(metadata['RescaleSlope'])
        adjusted_hu = image * rescale_slope + rescale_intercept
        cast_int16 = sitk.Cast(adjusted_hu, sitk.sitkInt16)

        sitk.WriteImage(cast_int16, str(path))