Passed
Push — main ( aa4098...a77a3c )
by Fernando
01:29
created

PCA.__init__()   A

Complexity

Conditions 1

Size

Total Lines 31
Code Lines 30

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 30
nop 11
dl 0
loc 31
rs 9.16
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

There are several approaches to avoid long parameter lists:

1
from __future__ import annotations
2
3
from typing import Any
4
5
import numpy as np
6
from einops import rearrange
7
8
from ....data.image import ScalarImage
9
from ....data.subject import Subject
10
from ....external.imports import get_sklearn
11
from ...intensity_transform import IntensityTransform
12
13
14
class PCA(IntensityTransform):
    """Project the channels of an image onto its principal components.

    Principal component analysis (PCA) is a convenient way to visualize
    embeddings generated by a neural network. See for example Figure 8 in
    `Cluster and Predict Latent Patches for Improved Masked Image Modeling
    <https://arxiv.org/abs/2502.08769>`_.

    Args:
        num_components: Number of components to compute.
        keep_components: Number of components to keep in the output image.
            If ``None``, all components are kept.
        whiten: If ``True``, the components are normalized to have unit
            variance.
        normalize: If ``True``, all components are divided by the standard
            deviation of the first component.
        make_skewness_positive: If ``True``, the skewness of each component is
            made positive by multiplying the component by -1 if its skewness is
            negative.
        values_range: If not ``None``, these values are linearly mapped to
            :math:`[0, 1]`.
        clip: If ``True``, the output values are clipped to :math:`[0, 1]`.
        pca_kwargs: Additional keyword arguments to pass to
            :class:`sklearn.decomposition.PCA`.
        **kwargs: Forwarded unchanged to the parent class constructor.

    Example:

    >>> import torchio as tio
    >>> from torchio.visualization import build_image_from_reference
    >>> ct = my_preprocessed_ct_image  # Assume this is a preprocessed CT image
    >>> ct
    ScalarImage(shape: (1, 240, 480, 480); spacing: (1.50, 0.75, 0.75); orientation: SLP+; dtype: torch.FloatTensor; memory: 210.9 MiB)
    >>> embedding_tensor = model(ct.data[None])[0]  # `model` is some pre-trained neural network
    >>> embedding_image = ToReferenceSpace(ct)(embedding_tensor)
    >>> embedding_image
    ScalarImage(shape: (512, 24, 24, 24); spacing: (15.00, 15.00, 15.00); orientation: SLP+; dtype: torch.FloatTensor; memory: 27.0 MiB)
    >>> pca = tio.PCA()(embedding_image)
    >>> pca
    ScalarImage(shape: (3, 24, 24, 24); spacing: (15.00, 15.00, 15.00); orientation: SLP+; dtype: torch.FloatTensor; memory: 162.0 KiB)
    """

    def __init__(
        self,
        num_components: int = 6,
        *,
        keep_components: int | None = 3,
        whiten: bool = True,
        normalize: bool = True,
        make_skewness_positive: bool = True,
        values_range: tuple[float, float] | None = (-2.3, 2.3),
        clip: bool = True,
        pca_kwargs: dict[str, Any] | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Names of the attributes that hold this transform's arguments.
        self.args_names = [
            'num_components',
            'keep_components',
            'whiten',
            'normalize',
            'make_skewness_positive',
            'values_range',
            'clip',
            'pca_kwargs',
        ]

        # Decomposition settings.
        self.num_components = num_components
        self.keep_components = keep_components
        self.whiten = whiten
        self.pca_kwargs = pca_kwargs

        # Post-processing settings applied to the projected components.
        self.normalize = normalize
        self.make_skewness_positive = make_skewness_positive
        self.values_range = values_range
        self.clip = clip

    def apply_transform(self, subject: Subject) -> Subject:
        """Replace the data of each selected image with its PCA projection.

        Args:
            subject: Subject whose images (as returned by ``get_images``) are
                transformed in place.

        Returns:
            The same ``subject`` instance, with the image data replaced.
        """
        # None of these depend on the image, so assemble them only once.
        options = {
            'num_components': self.num_components,
            'keep_components': self.keep_components,
            'whiten': self.whiten,
            'normalize': self.normalize,
            'make_skewness_positive': self.make_skewness_positive,
            'values_range': self.values_range,
            'clip': self.clip,
        }
        extra = self.pca_kwargs if self.pca_kwargs is not None else {}

        for image in self.get_images(subject):
            # Unpacking both mappings (instead of merging them) keeps the
            # TypeError raised when ``pca_kwargs`` repeats one of the options.
            projection = _compute_pca(image, **options, **extra)
            image.set_data(projection.data)
        return subject
102
103
104
def _compute_pca(
    embeddings: ScalarImage,
    num_components: int,
    keep_components: int | None,
    whiten: bool,
    normalize: bool,
    make_skewness_positive: bool,
    values_range: tuple[float, float] | None,
    clip: bool,
    **pca_kwargs,
) -> ScalarImage:
    """Project the channels of an image onto its principal components.

    Adapted from
    https://github.com/facebookresearch/capi/blob/main/eval_visualizations.py

    Every voxel is treated as one sample whose features are the channel
    values. The projected components are optionally rescaled, sign-flipped,
    mapped to :math:`[0, 1]`, clipped and truncated, in that order.

    Args:
        embeddings: Image with data of shape ``(c, x, y, z)``.
        num_components: Number of components computed by the PCA.
        keep_components: Number of leading components returned. If ``None``,
            all computed components are returned.
        whiten: Passed to :class:`sklearn.decomposition.PCA`.
        normalize: If ``True``, divide all components by the standard
            deviation of the first component.
        make_skewness_positive: If ``True``, negate every component whose
            skewness is negative.
        values_range: ``(vmin, vmax)`` pair linearly mapped to ``(0, 1)``. If
            ``None``, the minimum and maximum over *all* computed components
            (not only the kept ones) are used.
        clip: If ``True``, clip the mapped values to :math:`[0, 1]`.
        **pca_kwargs: Additional keyword arguments for
            :class:`sklearn.decomposition.PCA`.

    Returns:
        A new ``ScalarImage`` holding the components, with the affine of
        ``embeddings``.

    Raises:
        ValueError: If ``values_range`` has two equal bounds, as the linear
            mapping would be undefined.
    """
    # NOTE: the transform's default ``values_range`` is (-2.3, 2.3). For a
    # standard-normal variable, about 98% of the values lie within ±2.3σ and
    # are therefore mapped inside [0, 1].
    sklearn = get_sklearn()
    PCA = sklearn.decomposition.PCA

    data = embeddings.numpy()
    _, size_x, size_y, size_z = data.shape
    # One row per voxel, one column per channel
    X = rearrange(data, 'c x y z -> (x y z) c')
    pca = PCA(n_components=num_components, whiten=whiten, **pca_kwargs)
    # Transpose so that each row is one component
    projected: np.ndarray = pca.fit_transform(X).T

    if normalize:
        first_std = projected[0].std()
        # A constant first component cannot be used as a scale: dividing by
        # zero would fill the output with inf/NaN, so leave the data as is
        if first_std > 0:
            projected /= first_std

    if make_skewness_positive:
        for component in projected:
            # The skewness is E[c^3] / E[c^2]^(3/2) (the data is centered by
            # the PCA) and the denominator is never negative, so the sign of
            # the third moment is enough. This also avoids a 0/0 division for
            # constant components.
            if np.mean(component**3) < 0:
                component *= -1  # in place: each row is a view of `projected`

    grid: np.ndarray = rearrange(
        projected,
        'c (x y z) -> c x y z',
        x=size_x,
        y=size_y,
        z=size_z,
    )

    if values_range is not None:
        vmin, vmax = values_range
        if vmin == vmax:
            message = (
                'The two values in values_range must be different,'
                f' but got {values_range}'
            )
            raise ValueError(message)
        grid = (grid - vmin) / (vmax - vmin)
    else:
        vmin, vmax = grid.min(), grid.max()
        span = vmax - vmin
        if span > 0:
            grid = (grid - vmin) / span
        else:
            # Constant data: there is no range to stretch, map all to 0
            grid = np.zeros_like(grid)

    if clip:
        grid = np.clip(grid, 0, 1)
    if keep_components is not None:
        grid = grid[:keep_components]
    return ScalarImage(tensor=grid, affine=embeddings.affine)
151