Passed
Push — main ( aa4098...a77a3c )
by Fernando
01:29
created

torchio.transforms.preprocessing.intensity.pca   A

Complexity

Total Complexity 12

Size/Duplication

Total Lines 151
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 12
eloc 96
dl 0
loc 151
rs 10
c 0
b 0
f 0

1 Function

Rating   Name   Duplication   Size   Complexity  
B _compute_pca() 0 47 8

2 Methods

Rating   Name   Duplication   Size   Complexity  
A PCA.apply_transform() 0 16 3
A PCA.__init__() 0 31 1
1
from __future__ import annotations
2
3
from typing import Any
4
5
import numpy as np
6
from einops import rearrange
7
8
from ....data.image import ScalarImage
9
from ....data.subject import Subject
10
from ....external.imports import get_sklearn
11
from ...intensity_transform import IntensityTransform
12
13
14
class PCA(IntensityTransform):
    """Compute principal component analysis (PCA) of an image.

    PCA can be useful to visualize embeddings generated by a neural network.
    See for example Figure 8 in `Cluster and Predict Latent Patches for
    Improved Masked Image Modeling <https://arxiv.org/abs/2502.08769>`_.

    Args:
        num_components: Number of components to compute.
        keep_components: Number of components to keep in the output image.
            If ``None``, all components are kept.
        whiten: If ``True``, the components are normalized to have unit variance.
        normalize: If ``True``, all components are divided by the standard
            deviation of the first component.
        make_skewness_positive: If ``True``, the skewness of each component is
            made positive by multiplying the component by -1 if its skewness is
            negative.
        values_range: If not ``None``, these values are linearly mapped to
            :math:`[0, 1]`. If ``None``, the minimum and maximum of the
            projected data are used instead.
        clip: If ``True``, the output values are clipped to :math:`[0, 1]`.
        pca_kwargs: Additional keyword arguments to pass to
            :class:`sklearn.decomposition.PCA`.
        **kwargs: Keyword arguments forwarded to the parent transform class.

    Example:

    >>> import torchio as tio
    >>> from torchio.visualization import build_image_from_reference
    >>> ct = my_preprocessed_ct_image  # Assume this is a preprocessed CT image
    >>> ct
    ScalarImage(shape: (1, 240, 480, 480); spacing: (1.50, 0.75, 0.75); orientation: SLP+; dtype: torch.FloatTensor; memory: 210.9 MiB)
    >>> embedding_tensor = model(ct.data[None])[0]  # `model` is some pre-trained neural network
    >>> embedding_image = ToReferenceSpace(ct)(embedding_tensor)
    >>> embedding_image
    ScalarImage(shape: (512, 24, 24, 24); spacing: (15.00, 15.00, 15.00); orientation: SLP+; dtype: torch.FloatTensor; memory: 27.0 MiB)
    >>> pca = tio.PCA()(embedding_image)
    >>> pca
    ScalarImage(shape: (3, 24, 24, 24); spacing: (15.00, 15.00, 15.00); orientation: SLP+; dtype: torch.FloatTensor; memory: 162.0 KiB)
    """

    def __init__(
        self,
        num_components: int = 6,
        *,
        keep_components: int | None = 3,
        whiten: bool = True,
        normalize: bool = True,
        make_skewness_positive: bool = True,
        values_range: tuple[float, float] | None = (-2.3, 2.3),
        clip: bool = True,
        pca_kwargs: dict[str, Any] | None = None,
        **kwargs,
    ):
        """Store the PCA configuration; see the class docstring for details."""
        super().__init__(**kwargs)
        self.num_components = num_components
        self.keep_components = keep_components
        self.whiten = whiten
        self.normalize = normalize
        self.make_skewness_positive = make_skewness_positive
        self.values_range = values_range
        self.clip = clip
        self.pca_kwargs = pca_kwargs
        # Names of the constructor arguments stored as attributes above.
        self.args_names = [
            'num_components',
            'keep_components',
            'whiten',
            'normalize',
            'make_skewness_positive',
            'values_range',
            'clip',
            'pca_kwargs',
        ]

    def apply_transform(self, subject: Subject) -> Subject:
        """Replace the data of each selected image with its PCA projection.

        Args:
            subject: Subject whose images (as returned by ``get_images``)
                are transformed in place.

        Returns:
            The same ``subject`` instance, with the image data replaced.
        """
        # The extra sklearn arguments do not depend on the image, so they
        # are resolved once instead of on every iteration.
        pca_kwargs = {} if self.pca_kwargs is None else self.pca_kwargs
        for image in self.get_images(subject):
            pca_image = _compute_pca(
                image,
                num_components=self.num_components,
                keep_components=self.keep_components,
                whiten=self.whiten,
                normalize=self.normalize,
                make_skewness_positive=self.make_skewness_positive,
                values_range=self.values_range,
                clip=self.clip,
                **pca_kwargs,
            )
            image.set_data(pca_image.data)
        return subject
102
103
104
def _compute_pca(
    embeddings: ScalarImage,
    num_components: int,
    keep_components: int | None,
    whiten: bool,
    normalize: bool,
    make_skewness_positive: bool,
    values_range: tuple[float, float] | None,
    clip: bool,
    **pca_kwargs,
) -> ScalarImage:
    """Project the channels of an image onto their principal components.

    Every voxel is treated as one sample whose features are the channel
    values. The projected components are optionally rescaled, sign-flipped,
    mapped to :math:`[0, 1]`, clipped and truncated, in that order.

    Args:
        embeddings: Image whose array has layout ``(channels, x, y, z)``.
        num_components: Number of principal components to compute.
        keep_components: Number of leading components kept in the output.
            If ``None``, all computed components are kept.
        whiten: Forwarded to :class:`sklearn.decomposition.PCA`.
        normalize: If ``True``, divide all components by the standard
            deviation of the first component.
        make_skewness_positive: If ``True``, negate every component whose
            skewness is negative.
        values_range: ``(vmin, vmax)`` mapped linearly to ``(0, 1)``. If
            ``None``, the minimum and maximum of all computed components
            are used.
        clip: If ``True``, clip the rescaled values to :math:`[0, 1]`.
        **pca_kwargs: Extra keyword arguments for
            :class:`sklearn.decomposition.PCA`.

    Returns:
        A new ``ScalarImage`` holding the components, with the affine of
        ``embeddings``.

    Raises:
        ValueError: If ``values_range`` is given and both of its values
            are equal, as the linear mapping would be undefined.
    """
    # Adapted from https://github.com/facebookresearch/capi/blob/main/eval_visualizations.py
    # The transform's default range of ±2.3 corresponds to ±2.3σ for a
    # standard-normal variable, so about 98% of such values map inside [0, 1].
    sklearn = get_sklearn()
    PCA = sklearn.decomposition.PCA

    data = embeddings.numpy()
    _, size_x, size_y, size_z = data.shape
    # One row per voxel, one column per channel, as expected by sklearn.
    X = rearrange(data, 'c x y z -> (x y z) c')
    pca = PCA(n_components=num_components, whiten=whiten, **pca_kwargs)
    projected: np.ndarray = pca.fit_transform(X).T
    if normalize:
        first_std = projected[0].std()
        # A constant first component has zero spread; dividing by it would
        # fill the output with inf/NaN, so the scaling is skipped instead.
        if first_std > 0:
            projected /= first_std
    if make_skewness_positive:
        for component in projected:
            # The skewness is the third moment divided by a non-negative
            # power of the second moment, so both share the same sign.
            # Testing the third moment alone avoids a 0/0 division for
            # constant components.
            third_moment = np.mean(component**3)
            if third_moment < 0:
                component *= -1  # in place: ``component`` is a view
    grid: np.ndarray = rearrange(
        projected,
        'c (x y z) -> c x y z',
        x=size_x,
        y=size_y,
        z=size_z,
    )
    if values_range is not None:
        vmin, vmax = values_range
    else:
        vmin, vmax = grid.min(), grid.max()
    span = vmax - vmin
    if span == 0:
        if values_range is not None:
            message = (
                'The two values in "values_range" must be different,'
                f' but got {values_range}'
            )
            raise ValueError(message)
        # Constant data has no dynamic range: map everything to zero
        # rather than producing NaNs from 0 / 0.
        grid = np.zeros_like(grid)
    else:
        grid = (grid - vmin) / span
    if clip:
        grid = np.clip(grid, 0, 1)
    if keep_components is not None:
        grid = grid[:keep_components]
    return ScalarImage(tensor=grid, affine=embeddings.affine)
151