Passed
Pull Request — main (#1278)
by
unknown
02:48 queued 01:23
created

torchio.datasets.ct_rate   A

Complexity

Total Complexity 39

Size/Duplication

Total Lines 494
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 266
dl 0
loc 494
rs 9.28
c 0
b 0
f 0
wmc 39

17 Methods

Rating   Name   Duplication   Size   Complexity  
A CtRate._parse_split() 0 21 3
A CtRate._get_image_path() 0 23 1
A CtRate._get_subjects_list() 0 20 1
A CtRate._get_csv() 0 16 1
A CtRate._keep_n_subjects() 0 10 1
A CtRate.__init__() 0 21 2
A CtRate._get_metadata() 0 34 2
A CtRate._extract_report_dict() 0 23 2
A CtRate._instantiate_image() 0 16 1
A CtRate._get_labels() 0 10 1
A CtRate._merge() 0 21 1
A CtRate._get_subject() 0 22 2
A CtRate._get_csv_prefix() 0 16 3
A CtRate._get_reports() 0 10 1
A CtRate._fix_image() 0 40 3
B CtRate._copy_not_images() 0 27 7
B CtRate.fix_metadata() 0 55 7
1
from __future__ import annotations
2
3
import ast
4
import enum
5
import multiprocessing
6
import shutil
7
from pathlib import Path
8
from typing import TYPE_CHECKING
9
from typing import Literal
10
from typing import Union
11
12
import numpy as np
13
import SimpleITK as sitk
14
from tqdm.auto import tqdm
15
from tqdm.contrib.concurrent import process_map
16
from tqdm.contrib.concurrent import thread_map
17
18
from ..data.dataset import SubjectsDataset
19
from ..data.image import ScalarImage
20
from ..data.subject import Subject
21
from ..external.imports import get_pandas
22
from ..types import TypePath
23
24
if TYPE_CHECKING:
25
    import pandas as pd
26
27
28
# Names accepted for the ``split`` argument of ``CtRate``. ``'valid'`` and
# ``'validation'`` are both normalized to ``'valid'`` by ``CtRate._parse_split``.
TypeSplit = Union[
    Literal['train'],
    Literal['valid'],
    Literal['validation'],
]

# Values accepted by ``CtRate.fix_metadata`` for its ``parallelism`` argument:
# ``'thread'`` uses ``thread_map``, ``'process'`` uses ``process_map`` and ``None``
# runs a plain sequential loop.
TypeParallelism = Literal['thread', 'process', None]
35
36
37
class MetadataIndexColumn(str, enum.Enum):
    """Names of the columns that form the index of the metadata table.

    The three values are extracted from each volume filename by
    ``CtRate._get_metadata`` and then used, in this order, as the index.
    """

    # First numeric group of the volume filename.
    SUBJECT_ID = 'subject_id'
    # Word group that follows the subject number in the volume filename.
    SCAN_ID = 'scan_id'
    # Last numeric group of the volume filename, before ``.nii.gz``.
    RECONSTRUCTION_ID = 'reconstruction_id'
41
42
43
class CtRate(SubjectsDataset):
    """CT-RATE dataset.

    This class provides access to
    `CT-RATE <https://huggingface.co/datasets/ibrahimhamamci/CT-RATE>`_,
    which contains chest CT scans with associated radiology reports and
    abnormality labels.

    The dataset must have been downloaded previously. The metadata, report and
    label tables are read from CSV files under ``<root>/dataset/``.

    Args:
        root: Root directory where the dataset has been downloaded.
        split: Dataset split to use: ``'train'``, ``'valid'`` or
            ``'validation'`` (the last two are equivalent).
        token: Hugging Face token for accessing gated repositories. Alternatively,
            login using `huggingface-cli login` to cache the token.
        num_subjects: Optional limit on the number of subjects to load (useful for
            testing). If ``None``, all subjects in the split are loaded.
        report_key: Key under which the dictionary of report fields is attached
            to each image.
        sizes: Values of the ``Rows`` metadata column to keep; images with any
            other value are excluded. Default: [512, 768, 1024].
        **kwargs: Additional arguments for SubjectsDataset.

    Examples:
        >>> dataset = CtRate('/path/to/data', split='train')
    """

    # Hugging Face repository that hosts the dataset.
    _REPO_ID = 'ibrahimhamamci/CT-RATE'
    # Column holding the image filename; also the join key between the tables.
    _FILENAME_KEY = 'VolumeName'
    # Default value for the ``sizes`` argument.
    _SIZES = [512, 768, 1024]
    # Names of the abnormalities listed for this dataset.
    ABNORMALITIES = [
        'Medical material',
        'Arterial wall calcification',
        'Cardiomegaly',
        'Pericardial effusion',
        'Coronary artery wall calcification',
        'Hiatal hernia',
        'Lymphadenopathy',
        'Emphysema',
        'Atelectasis',
        'Lung nodule',
        'Lung opacity',
        'Pulmonary fibrotic sequela',
        'Pleural effusion',
        'Mosaic attenuation pattern',
        'Peribronchial thickening',
        'Consolidation',
        'Bronchiectasis',
        'Interlobular septal thickening',
    ]
91
92
    def __init__(
        self,
        root: TypePath,
        split: TypeSplit = 'train',
        *,
        token: str | None = None,
        num_subjects: int | None = None,
        report_key: str = 'report',
        sizes: list[int] | None = None,
        **kwargs,
    ):
        """Load the metadata tables and build one subject per subject ID.

        Args:
            root: Root directory where the dataset has been downloaded.
            split: ``'train'``, ``'valid'`` or ``'validation'``.
            token: Hugging Face token; stored on the instance.
            num_subjects: If not ``None``, keep only the first ``num_subjects``
                subjects found in the metadata.
            report_key: Key under which the report fields are attached to each
                image.
            sizes: Values of the ``Rows`` column to keep. If ``None``, the
                class default ``_SIZES`` is used.
            **kwargs: Additional arguments for SubjectsDataset.

        Raises:
            ValueError: If ``split`` is not a valid split name.
        """
        self._root_dir = Path(root)
        self._token = token
        self._num_subjects = num_subjects
        self._report_key = report_key
        # Store a private copy so that later changes to this instance's sizes
        # can never leak into the class-level default or the caller's list
        # (and vice versa).
        self._sizes = list(self._SIZES if sizes is None else sizes)

        self._split = self._parse_split(split)
        self.metadata = self._get_metadata()
        subjects_list = self._get_subjects_list(self.metadata)
        super().__init__(subjects_list, **kwargs)
113
114
    @staticmethod
115
    def _parse_split(split: str) -> str:
116
        """Normalize the split name.
117
118
        Converts 'validation' to 'valid' and validates that the split name
119
        is one of the allowed values.
120
121
        Args:
122
            split: The split name to parse ('train', 'valid', or 'validation').
123
124
        Returns:
125
            str: Normalized split name ('train' or 'valid').
126
127
        Raises:
128
            ValueError: If the split name is not one of the allowed values.
129
        """
130
        if split in ['valid', 'validation']:
131
            return 'valid'
132
        if split not in ['train', 'valid']:
133
            raise ValueError(f"Invalid split '{split}'. Use 'train' or 'valid'")
134
        return split
135
136
    def _get_csv(
        self,
        dirname: str,
        filename: str,
    ) -> pd.DataFrame:
        """Read one of the dataset's CSV tables into a DataFrame.

        The file is looked up at ``<root>/dataset/<dirname>/<filename>``.

        Args:
            dirname: Directory name within 'dataset/' where the CSV is located.
            filename: Name of the CSV file to load.

        Returns:
            pd.DataFrame: The table as returned by ``pandas.read_csv``.
        """
        csv_path = self._root_dir / f'dataset/{dirname}' / filename
        return get_pandas().read_csv(csv_path)
152
153
    def _get_csv_prefix(self, expand_validation: bool = True) -> str:
154
        """Get the prefix for CSV filenames based on the current split.
155
156
        Returns the appropriate prefix for CSV filenames based on the current split.
157
        For the validation split, can either return 'valid' or 'validation' depending
158
        on the expand_validation parameter.
159
160
        Args:
161
            expand_validation: If ``True`` and split is ``'valid'``, return
162
                ``'validation'``. Otherwise, return the split name as is.
163
        """
164
        if expand_validation and self._split == 'valid':
165
            prefix = 'validation'
166
        else:
167
            prefix = self._split
168
        return prefix
169
170
    def _get_metadata(self) -> pd.DataFrame:
        """Load and process the dataset metadata.

        Loads the metadata CSV of the current split, keeps only the images
        whose ``Rows`` value is in ``self._sizes``, extracts subject, scan and
        reconstruction IDs from the filenames, optionally keeps only the first
        ``self._num_subjects`` subjects, and merges in the reports and the
        abnormality labels.

        Returns:
            pd.DataFrame: The merged table, indexed by subject, scan and
            reconstruction ID.
        """
        prefix = self._get_csv_prefix()
        metadata = self._get_csv('metadata', f'{prefix}_metadata.csv')

        # Exclude images with size not in self._sizes. The explicit copy makes
        # the filtered table independent of the one read from disk, so adding
        # the ID columns below does not assign into a slice of another frame
        # (which pandas reports with SettingWithCopyWarning).
        rows_int = metadata['Rows'].astype(int)
        metadata = metadata[rows_int.isin(self._sizes)].copy()

        index_columns = [
            MetadataIndexColumn.SUBJECT_ID.value,
            MetadataIndexColumn.SCAN_ID.value,
            MetadataIndexColumn.RECONSTRUCTION_ID.value,
        ]
        # One capture group per index column, in the same order
        pattern = r'\w+_(\d+)_(\w+)_(\d+)\.nii\.gz'
        metadata[index_columns] = metadata[self._FILENAME_KEY].str.extract(pattern)

        if self._num_subjects is not None:
            metadata = self._keep_n_subjects(metadata, self._num_subjects)

        # Add reports and abnormality labels to metadata, keeping only the rows for the
        # images in the metadata table
        metadata = self._merge(metadata, self._get_reports())
        metadata = self._merge(metadata, self._get_labels())

        return metadata.set_index(index_columns)
204
205
    def _merge(self, base_df: pd.DataFrame, new_df: pd.DataFrame) -> pd.DataFrame:
        """Left-join ``new_df`` onto ``base_df`` using the volume filename.

        Every row of ``base_df`` is kept; the columns of ``new_df`` are added
        for the rows whose filename matches.

        Args:
            base_df: The primary dataframe to merge into.
            new_df: The dataframe containing additional data to be merged.

        Returns:
            pd.DataFrame: The merged dataframe with all rows from base_df and
            matching columns from new_df.
        """
        join_key = self._FILENAME_KEY
        return base_df.merge(new_df, on=join_key, how='left')
227
228
    def _keep_n_subjects(self, metadata: pd.DataFrame, n: int) -> pd.DataFrame:
0 ignored issues
show
introduced by
The variable pd does not seem to be defined in case TYPE_CHECKING on line 24 is False. Are you sure this can never be the case?
Loading history...
229
        """Limit the metadata to the first ``n`` subjects.
230
231
        Args:
232
            metadata: The complete metadata dataframe.
233
            n: Maximum number of subjects to keep.
234
        """
235
        unique_subjects = metadata['subject_id'].unique()
236
        selected_subjects = unique_subjects[:n]
237
        return metadata[metadata['subject_id'].isin(selected_subjects)]
238
239
    def _get_reports(self) -> pd.DataFrame:
240
        """Load the radiology reports associated with the CT scans.
241
242
        Retrieves the CSV file containing radiology reports for the current split
243
        (train or validation).
244
        """
245
        dirname = 'radiology_text_reports'
246
        prefix = self._get_csv_prefix()
247
        filename = f'{prefix}_reports.csv'
248
        return self._get_csv(dirname, filename)
249
250
    def _get_labels(self) -> pd.DataFrame:
251
        """Load the abnormality labels for the CT scans.
252
253
        Retrieves the CSV file containing predicted abnormality labels for the
254
        current split.
255
        """
256
        dirname = 'multi_abnormality_labels'
257
        prefix = self._get_csv_prefix(expand_validation=False)
258
        filename = f'{prefix}_predicted_labels.csv'
259
        return self._get_csv(dirname, filename)
260
261
    def _get_subjects_list(self, metadata: pd.DataFrame) -> list[Subject]:
        """Build one Subject per subject ID found in the metadata.

        The rows of ``metadata`` are grouped by ``subject_id`` and each group is
        turned into a Subject by ``_get_subject``. Groups are processed with a
        thread pool that has one worker per CPU reported by ``multiprocessing``.

        Args:
            metadata: The metadata table, indexed as returned by
                ``_get_metadata``.

        Returns:
            list[Subject]: One Subject per group.
        """
        flat = metadata.reset_index()
        subject_groups = flat.groupby('subject_id')
        return thread_map(
            self._get_subject,
            subject_groups,
            total=flat['subject_id'].nunique(),
            max_workers=multiprocessing.cpu_count(),
        )
281
282
    def _get_subject(
        self,
        subject_id_and_metadata: tuple[str, pd.DataFrame],
    ) -> Subject:
        """Build the Subject that holds all images of a single subject.

        Each row of the subject's table becomes one image, stored under the key
        ``scan_<scan_id>_reconstruction_<reconstruction_id>``.

        Args:
            subject_id_and_metadata: A tuple containing the subject ID (string) and a
                DataFrame containing metadata for all images associated to that subject.

        Returns:
            Subject: A subject with a ``subject_id`` entry plus one entry per image.
        """
        subject_id, subject_df = subject_id_and_metadata
        images: dict[str, ScalarImage] = {}
        for _, row in subject_df.iterrows():
            image = self._instantiate_image(row)
            scan = row['scan_id']
            reconstruction = row['reconstruction_id']
            images[f'scan_{scan}_reconstruction_{reconstruction}'] = image
        return Subject(subject_id=subject_id, **images)  # type: ignore[arg-type]
304
305
    def _instantiate_image(self, image_row: pd.Series) -> ScalarImage:
        """Create the ScalarImage described by one row of the metadata table.

        The row is converted to a dictionary whose entries are passed to the
        image as keyword arguments. The report fields are first moved into a
        nested dictionary stored under ``self._report_key``.

        Args:
            image_row: A pandas Series representing a row from the metadata DataFrame,
                containing information about a single image.

        Returns:
            ScalarImage: An image pointing at the file inside the root directory.
        """
        fields = image_row.to_dict()
        relative_path = self._get_image_path(fields[self._FILENAME_KEY])
        # _extract_report_dict removes the report entries from ``fields``
        fields[self._report_key] = self._extract_report_dict(fields)
        return ScalarImage(self._root_dir / relative_path, **fields)
321
322
    def _extract_report_dict(self, subject_dict: dict[str, str]) -> dict[str, str]:
323
        """Extract radiology report information from the subject dictionary.
324
325
        Extracts the English radiology report components (clinical information,
326
        findings, impressions, and technique) from the subject dictionary and
327
        removes these keys from the original dictionary.
328
329
        Args:
330
            subject_dict: Image metadata including report fields.
331
332
        Note:
333
            This method modifies the input subject_dict by removing the report keys.
334
        """
335
        report_keys = [
336
            'ClinicalInformation_EN',
337
            'Findings_EN',
338
            'Impressions_EN',
339
            'Technique_EN',
340
        ]
341
        report_dict = {}
342
        for key in report_keys:
343
            report_dict[key] = subject_dict.pop(key)
344
        return report_dict
345
346
    @staticmethod
347
    def _get_image_path(filename: str) -> Path:
348
        """Construct the relative path to an image file within the dataset structure.
349
350
        Parses the filename to determine the hierarchical directory structure
351
        where the image is stored in the CT-RATE dataset.
352
353
        Args:
354
            filename: The name of the image file (e.g., 'train_2_a_1.nii.gz').
355
356
        Returns:
357
            Path: The relative path to the image file within the dataset directory.
358
359
        Example:
360
            >>> path = CtRate._get_image_path('train_2_a_1.nii.gz')
361
            # Returns Path('dataset/train/train_2/train_2_a/train_2_a_1.nii.gz')
362
        """
363
        parts = filename.split('_')
364
        base_dir = 'dataset'
365
        split_dir = parts[0]
366
        level1 = f'{parts[0]}_{parts[1]}'
367
        level2 = f'{level1}_{parts[2]}'
368
        return Path(base_dir, split_dir, level1, level2, filename)
369
370
    @staticmethod
    def _fix_image(image: ScalarImage, out_path: Path, *, force: bool = False) -> None:
        """Write a copy of a CT-RATE image with corrected spatial metadata.

        The original NIfTI files in the CT-RATE dataset have incorrect spatial
        metadata. This method reads the file at ``image.path``, sets the spacing,
        origin and direction from the metadata attached to ``image`` (taken from
        the CSV), applies the rescale slope and intercept to convert to
        Hounsfield units, and writes the result to ``out_path``.

        Args:
            image: The image to fix. Its ``XYSpacing``, ``ZSpacing``,
                ``ImagePositionPatient``, ``ImageOrientationPatient``,
                ``RescaleIntercept`` and ``RescaleSlope`` entries are read.
            out_path: The path where the fixed image will be saved. Missing
                parent directories are created.
            force: If ``False`` (default) and ``out_path`` already exists, do
                nothing. If ``True``, recompute and overwrite it.

        Note:
            The file at ``image.path`` is only read; it is modified only if
            ``out_path`` points to that same file.
            The fixed image is stored as INT16 with proper HU values.
            Nothing is returned, as declared: the parallel helpers used by
            ``fix_metadata`` collect every return value, so returning the fixed
            image would keep all of them alive until the whole run finishes.
        """
        # Adapted from https://huggingface.co/datasets/ibrahimhamamci/CT-RATE/blob/main/download_scripts/fix_metadata.py
        if not force and out_path.exists():
            return
        spacing_x, spacing_y = map(float, ast.literal_eval(image['XYSpacing']))
        # Coerce like the in-plane spacing so all three components are floats
        spacing_z = float(image['ZSpacing'])
        image_sitk = sitk.ReadImage(str(image.path))
        image_sitk.SetSpacing((spacing_x, spacing_y, spacing_z))

        image_sitk.SetOrigin(ast.literal_eval(image['ImagePositionPatient']))

        orientation = ast.literal_eval(image['ImageOrientationPatient'])
        # Force lists so the concatenation below works whether the literal was
        # written as a list or as a tuple
        row_cosine = list(orientation[:3])
        col_cosine = list(orientation[3:6])
        z_cosine = np.cross(row_cosine, col_cosine).tolist()
        image_sitk.SetDirection(row_cosine + col_cosine + z_cosine)

        rescale_intercept = image['RescaleIntercept']
        rescale_slope = image['RescaleSlope']
        adjusted_hu = image_sitk * rescale_slope + rescale_intercept
        cast_int16 = sitk.Cast(adjusted_hu, sitk.sitkInt16)

        out_path.parent.mkdir(parents=True, exist_ok=True)
        sitk.WriteImage(cast_int16, str(out_path))
410
411
    def _copy_not_images(self, out_dir: Path) -> None:
412
        """Copy all files from the root directory except the images."""
413
        for path in self._root_dir.iterdir():
414
            if path.name == 'dataset':
415
                for subdirectory in path.iterdir():
416
                    if subdirectory.name in ['train', 'valid']:
417
                        continue
418
                    print(
419
                        f'Copying {subdirectory} to {out_dir / subdirectory.relative_to(self._root_dir)}'
420
                    )
421
                    shutil.copytree(
422
                        subdirectory,
423
                        out_dir / subdirectory.relative_to(self._root_dir),
424
                        dirs_exist_ok=True,
425
                    )
426
            elif path.name.startswith('.'):
427
                continue
428
            elif path.is_dir():
429
                print(f'Copying {path} to {out_dir / path.name}')
430
                shutil.copytree(
431
                    path,
432
                    out_dir / path.name,
433
                    dirs_exist_ok=True,
434
                )
435
            else:
436
                print(f'Copying {path} to {out_dir / path.name}')
437
                shutil.copy(path, out_dir / path.name)
438
439
    def fix_metadata(
        self,
        out_dir: str | Path,
        parallelism: TypeParallelism = None,
    ) -> CtRate:
        """Write fixed copies of all images and return a dataset built on them.

        Each image of the dataset is passed to ``_fix_image``, which saves the
        corrected file under ``out_dir`` at the same path the image has relative
        to the root directory. Output files that already exist are left untouched.

        Args:
            out_dir: The directory where the fixed images will be saved. It is
                created if it does not exist.
            parallelism: ``'thread'`` to use a thread pool, ``'process'`` to use a
                process pool (both with one worker per CPU), or any other value
                (default ``None``) to process the images sequentially.

        Returns:
            CtRate: A new dataset rooted at ``out_dir``, created with the same
            split, token, number of subjects, report key and sizes as this one.
        """
        out_dir = Path(out_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        # NOTE(review): copying the non-image files is disabled, yet the dataset
        # returned below reads its CSV tables from ``out_dir``; presumably they
        # must already be there. Confirm before re-enabling or removing.
        # self._copy_not_images(out_dir)
        images = [
            image for subject in self.dry_iter() for image in subject.get_images()
        ]
        out_paths = [
            out_dir / image.path.relative_to(self._root_dir) for image in images
        ]

        if parallelism == 'thread':
            parallel_map = thread_map
        elif parallelism == 'process':
            parallel_map = process_map
        else:
            parallel_map = None

        if parallel_map is not None:
            parallel_map(
                self._fix_image,
                images,
                out_paths,
                max_workers=multiprocessing.cpu_count(),
                desc='Fixing metadata',
            )
        else:
            with tqdm(total=len(images), desc='Fixing metadata') as pbar:
                for image, out_path in zip(images, out_paths):
                    pbar.set_description(f'Fixing {image.path.name}')
                    self._fix_image(image, out_path)
                    pbar.update(1)

        return CtRate(
            out_dir,
            split=self._split,
            token=self._token,
            num_subjects=self._num_subjects,
            report_key=self._report_key,
            sizes=self._sizes,
        )
494