Passed
Pull Request — main (#1349)
by Fernando
01:33
created

torchio.transforms.preprocessing.intensity.pca   A

Complexity

Total Complexity 12

Size/Duplication

Total Lines 149
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 95
dl 0
loc 149
rs 10
c 0
b 0
f 0
wmc 12

1 Function

Rating   Name   Duplication   Size   Complexity  
B _compute_pca() 0 47 8

2 Methods

Rating   Name   Duplication   Size   Complexity  
A PCA.apply_transform() 0 16 3
A PCA.__init__() 0 31 1
1
from typing import Any
2
3
import numpy as np
4
from einops import rearrange
5
6
from ....data.image import ScalarImage
7
from ....data.subject import Subject
8
from ....external.imports import get_sklearn
9
from ...intensity_transform import IntensityTransform
10
11
12
class PCA(IntensityTransform):
    """Compute principal component analysis (PCA) of an image.

    PCA can be useful to visualize embeddings generated by a neural network.
    See for example Figure 8 in `Cluster and Predict Latent Patches for
    Improved Masked Image Modeling <https://arxiv.org/abs/2502.08769>`_.

    The channels of each image are treated as features and every voxel as a
    sample. The image data is replaced by the projected components.

    Args:
        num_components: Number of components to compute.
        keep_components: Number of components to keep in the output image.
            If ``None``, all components are kept. The selection is applied
            after all the other steps.
        whiten: If ``True``, the components are normalized to have unit variance.
        normalize: If ``True``, all components are divided by the standard
            deviation of the first component.
        make_skewness_positive: If ``True``, the skewness of each component is
            made positive by multiplying the component by -1 if its skewness is
            negative.
        values_range: If not ``None``, these values are linearly mapped to
            :math:`[0, 1]`. If ``None``, the minimum and maximum over all the
            computed components are mapped to :math:`[0, 1]` instead.
        clip: If ``True``, the output values are clipped to :math:`[0, 1]`.
        pca_kwargs: Additional keyword arguments to pass to
            :class:`sklearn.decomposition.PCA`.
        **kwargs: Forwarded to the constructor of the parent class.

    Example:

    >>> import torchio as tio
    >>> from torchio.visualization import build_image_from_reference
    >>> ct = my_preprocessed_ct_image  # Assume this is a preprocessed CT image
    >>> ct
    ScalarImage(shape: (1, 240, 480, 480); spacing: (1.50, 0.75, 0.75); orientation: SLP+; dtype: torch.FloatTensor; memory: 210.9 MiB)
    >>> embedding_tensor = model(ct.data[None])[0]  # `model` is some pre-trained neural network
    >>> embedding_image = ToReferenceSpace(ct)(embedding_tensor)
    >>> embedding_image
    ScalarImage(shape: (512, 24, 24, 24); spacing: (15.00, 15.00, 15.00); orientation: SLP+; dtype: torch.FloatTensor; memory: 27.0 MiB)
    >>> pca = tio.PCA()(embedding_image)
    >>> pca
    ScalarImage(shape: (3, 24, 24, 24); spacing: (15.00, 15.00, 15.00); orientation: SLP+; dtype: torch.FloatTensor; memory: 162.0 KiB)
    """

    def __init__(
        self,
        num_components: int = 6,
        *,
        keep_components: int | None = 3,
        whiten: bool = True,
        normalize: bool = True,
        make_skewness_positive: bool = True,
        values_range: tuple[float, float] | None = (-2.3, 2.3),
        clip: bool = True,
        pca_kwargs: dict[str, Any] | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_components = num_components
        self.keep_components = keep_components
        self.whiten = whiten
        self.normalize = normalize
        self.make_skewness_positive = make_skewness_positive
        self.values_range = values_range
        self.clip = clip
        self.pca_kwargs = pca_kwargs
        # Names of the constructor arguments stored on the instance.
        self.args_names = [
            'num_components',
            'keep_components',
            'whiten',
            'normalize',
            'make_skewness_positive',
            'values_range',
            'clip',
            'pca_kwargs',
        ]

    def apply_transform(self, subject: Subject) -> Subject:
        """Replace the data of each selected image with its PCA projection.

        Args:
            subject: Subject whose images (as returned by ``get_images``)
                are transformed in place.

        Returns:
            The same subject instance, with the image data replaced.
        """
        # The extra sklearn arguments do not depend on the image, so resolve
        # the ``None`` default once instead of on every iteration.
        pca_kwargs = {} if self.pca_kwargs is None else self.pca_kwargs
        for image in self.get_images(subject):
            pca_image = _compute_pca(
                image,
                num_components=self.num_components,
                keep_components=self.keep_components,
                whiten=self.whiten,
                normalize=self.normalize,
                make_skewness_positive=self.make_skewness_positive,
                values_range=self.values_range,
                clip=self.clip,
                **pca_kwargs,
            )
            image.set_data(pca_image.data)
        return subject
100
101
102
def _compute_pca(
    embeddings: ScalarImage,
    num_components: int,
    keep_components: int | None,
    whiten: bool,
    normalize: bool,
    make_skewness_positive: bool,
    values_range: tuple[float, float] | None,
    clip: bool,
    **pca_kwargs,
) -> ScalarImage:
    """Project the channels of a 4D image onto their principal components.

    Every voxel is treated as a sample and every channel as a feature.

    Args:
        embeddings: Image whose array has layout ``(c, x, y, z)``.
        num_components: Number of components computed by the PCA.
        keep_components: If not ``None``, only the first ``keep_components``
            components are returned. The selection happens last, so the
            other steps see all ``num_components`` components.
        whiten: Passed as ``whiten`` to :class:`sklearn.decomposition.PCA`.
        normalize: If ``True``, divide all components by the standard
            deviation of the first one. Skipped when that deviation is zero.
        make_skewness_positive: If ``True``, flip the sign of every component
            whose skewness is negative.
        values_range: ``(vmin, vmax)`` pair mapped linearly to ``[0, 1]``.
            If ``None``, the minimum and maximum of the projected data are
            used. A caller-supplied range must have ``vmin != vmax``.
        clip: If ``True``, clip the rescaled values to ``[0, 1]``.
        **pca_kwargs: Extra keyword arguments for
            :class:`sklearn.decomposition.PCA`.

    Returns:
        A new ``ScalarImage`` with the projected components as channels and
        the affine of ``embeddings``.
    """
    # Adapted from https://github.com/facebookresearch/capi/blob/main/eval_visualizations.py
    # When the first component has unit standard deviation, a values range of
    # (-2.3, 2.3) maps about 98% of a normally distributed component
    # inside [0, 1].
    sklearn = get_sklearn()
    PCA = sklearn.decomposition.PCA

    data = embeddings.numpy()
    _, size_x, size_y, size_z = data.shape
    # One row per voxel, one column per channel.
    X = rearrange(data, 'c x y z -> (x y z) c')
    pca = PCA(n_components=num_components, whiten=whiten, **pca_kwargs)
    # Transpose so that each row is one component.
    projected: np.ndarray = pca.fit_transform(X).T
    if normalize:
        first_std = projected[0].std()
        # A constant first component has zero deviation; dividing by it would
        # fill the output with NaN, so the data is left unscaled in that case.
        if first_std > 0:
            projected /= first_std
    if make_skewness_positive:
        for component in projected:
            # Raw moments are used in place of central moments.
            second_cumulant = np.mean(component**2)
            if second_cumulant == 0:
                # All-zero component: the skewness is undefined (0 / 0) and
                # there is no sign to flip.
                continue
            third_cumulant = np.mean(component**3)
            skewness = third_cumulant / second_cumulant ** (3 / 2)
            if skewness < 0:
                # ``component`` is a view, so this flips the row in place.
                component *= -1
    grid: np.ndarray = rearrange(
        projected,
        'c (x y z) -> c x y z',
        x=size_x,
        y=size_y,
        z=size_z,
    )
    if values_range is not None:
        vmin, vmax = values_range
    else:
        vmin, vmax = grid.min(), grid.max()
        if vmax == vmin:
            # Constant grid: use a unit span so that every value maps to 0
            # instead of producing NaN from 0 / 0.
            vmax = vmin + 1
    grid = (grid - vmin) / (vmax - vmin)
    if clip:
        grid = np.clip(grid, 0, 1)
    if keep_components is not None:
        grid = grid[:keep_components]
    return ScalarImage(tensor=grid, affine=embeddings.affine)
149