Passed
Pull Request — main (#1349)
by Fernando
01:33
created

PCA.__init__()   A

Complexity

Conditions 1

Size

Total Lines 31
Code Lines 30

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 30
nop 11
dl 0
loc 31
rs 9.16
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more or different data.

There are several approaches to avoid long parameter lists:

1
from typing import Any
2
3
import numpy as np
4
from einops import rearrange
5
6
from ....data.image import ScalarImage
7
from ....data.subject import Subject
8
from ....external.imports import get_sklearn
9
from ...intensity_transform import IntensityTransform
10
11
12
class PCA(IntensityTransform):
    """Compute principal component analysis (PCA) of an image.

    PCA can be useful to visualize embeddings generated by a neural network.
    See for example Figure 8 in `Cluster and Predict Latent Patches for
    Improved Masked Image Modeling <https://arxiv.org/abs/2502.08769>`_.

    Each voxel is treated as one sample and each channel as one feature, so a
    multi-channel image is replaced by its projection onto its leading
    principal components.

    Args:
        num_components: Number of components to compute.
        keep_components: Number of components to keep in the output image.
            If ``None``, all components are kept.
        whiten: If ``True``, the components are normalized to have unit variance.
        normalize: If ``True``, all components are divided by the standard
            deviation of the first component.
        make_skewness_positive: If ``True``, the skewness of each component is
            made positive by multiplying the component by -1 if its skewness is
            negative.
        values_range: If not ``None``, these values are linearly mapped to
            :math:`[0, 1]`. If ``None``, the minimum and maximum of the output
            components are used instead. The default, :math:`(-2.3, 2.3)`, is
            about :math:`\\pm 2.3\\sigma`: if the (whitened) components are
            roughly standard normal, about 98% of the values fall inside it.
        clip: If ``True``, the output values are clipped to :math:`[0, 1]`.
        pca_kwargs: Additional keyword arguments to pass to
            :class:`sklearn.decomposition.PCA`. ``n_components`` and
            ``whiten`` are controlled by ``num_components`` and ``whiten``
            and must not be included here.
        **kwargs: Forwarded to :class:`IntensityTransform`.

    Raises:
        ValueError: If ``pca_kwargs`` contains ``n_components`` or ``whiten``.

    Example:

    >>> import torchio as tio
    >>> from torchio.visualization import build_image_from_reference
    >>> ct = my_preprocessed_ct_image  # Assume this is a preprocessed CT image
    >>> ct
    ScalarImage(shape: (1, 240, 480, 480); spacing: (1.50, 0.75, 0.75); orientation: SLP+; dtype: torch.FloatTensor; memory: 210.9 MiB)
    >>> embedding_tensor = model(ct.data[None])[0]  # `model` is some pre-trained neural network
    >>> embedding_image = build_image_from_reference(embedding_tensor, ct)
    >>> embedding_image
    ScalarImage(shape: (512, 24, 24, 24); spacing: (15.00, 15.00, 15.00); orientation: SLP+; dtype: torch.FloatTensor; memory: 27.0 MiB)
    >>> pca = tio.PCA()(embedding_image)
    >>> pca
    ScalarImage(shape: (3, 24, 24, 24); spacing: (15.00, 15.00, 15.00); orientation: SLP+; dtype: torch.FloatTensor; memory: 162.0 KiB)
    """

    def __init__(
        self,
        num_components: int = 6,
        *,
        keep_components: int | None = 3,
        whiten: bool = True,
        normalize: bool = True,
        make_skewness_positive: bool = True,
        values_range: tuple[float, float] | None = (-2.3, 2.3),
        clip: bool = True,
        pca_kwargs: dict[str, Any] | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if pca_kwargs is not None:
            # These keys are always passed explicitly to sklearn's PCA; letting
            # them through would fail later with "multiple values for keyword".
            reserved = {'n_components', 'whiten'}.intersection(pca_kwargs)
            if reserved:
                message = (
                    f'pca_kwargs must not contain {sorted(reserved)};'
                    ' use the "num_components" and "whiten" arguments instead'
                )
                raise ValueError(message)
        self.num_components = num_components
        self.keep_components = keep_components
        self.whiten = whiten
        self.normalize = normalize
        self.make_skewness_positive = make_skewness_positive
        self.values_range = values_range
        self.clip = clip
        self.pca_kwargs = pca_kwargs
        self.args_names = [
            'num_components',
            'keep_components',
            'whiten',
            'normalize',
            'make_skewness_positive',
            'values_range',
            'clip',
            'pca_kwargs',
        ]

    def apply_transform(self, subject: Subject) -> Subject:
        """Replace the data of each selected image with its PCA projection."""
        # The extra sklearn arguments do not depend on the image: build them once.
        pca_kwargs = {} if self.pca_kwargs is None else self.pca_kwargs
        for image in self.get_images(subject):
            pca_image = _compute_pca(
                image,
                num_components=self.num_components,
                keep_components=self.keep_components,
                whiten=self.whiten,
                normalize=self.normalize,
                make_skewness_positive=self.make_skewness_positive,
                values_range=self.values_range,
                clip=self.clip,
                **pca_kwargs,
            )
            image.set_data(pca_image.data)
        return subject
100
101
102
def _compute_pca(
    embeddings: ScalarImage,
    num_components: int,
    keep_components: int | None,
    whiten: bool,
    normalize: bool,
    make_skewness_positive: bool,
    values_range: tuple[float, float] | None,
    clip: bool,
    **pca_kwargs,
) -> ScalarImage:
    """Project the channels of an image onto their principal components.

    Every voxel is one sample whose features are the image channels, so a
    ``(C, X, Y, Z)`` image becomes a ``(K, X, Y, Z)`` image, where ``K`` is
    ``num_components`` (or at most ``keep_components`` when given).

    Adapted from
    https://github.com/facebookresearch/capi/blob/main/eval_visualizations.py

    Args:
        embeddings: Image whose channels are the features to decompose.
        num_components: Number of principal components to compute.
        keep_components: Number of leading components kept in the output.
            If ``None``, all computed components are kept.
        whiten: Forwarded to :class:`sklearn.decomposition.PCA`.
        normalize: If ``True``, divide all components by the standard
            deviation of the first one (skipped if that deviation is zero).
        make_skewness_positive: If ``True``, flip the sign of every component
            whose skewness is negative.
        values_range: Values linearly mapped to ``0`` and ``1``. If ``None``,
            the minimum and maximum of the kept components are used. A
            degenerate range (both ends equal) maps everything to ``0``.
        clip: If ``True``, clip the output to :math:`[0, 1]`.
        **pca_kwargs: Extra keyword arguments for
            :class:`sklearn.decomposition.PCA`.

    Returns:
        Image whose channels are the rescaled components, with the same
        affine as ``embeddings``.
    """
    sklearn = get_sklearn()
    PCA = sklearn.decomposition.PCA

    data = embeddings.numpy()
    _, size_x, size_y, size_z = data.shape
    X = rearrange(data, 'c x y z -> (x y z) c')
    pca = PCA(n_components=num_components, whiten=whiten, **pca_kwargs)
    projected: np.ndarray = pca.fit_transform(X).T  # (components, voxels)

    if normalize:
        first_std = projected[0].std()
        # A constant first component would otherwise turn everything into NaN/inf.
        if first_std > 0:
            projected /= first_std

    if keep_components is not None:
        # Drop unused components before rescaling so that an automatic value
        # range (values_range=None) reflects only what ends up in the output.
        projected = projected[:keep_components]

    if make_skewness_positive:
        # sklearn centers the data, so each projected component has zero mean
        # and its third raw moment is its third central moment. The skewness
        # denominator is non-negative, so the sign of the skewness is the sign
        # of this moment; testing it directly avoids a 0/0 for flat components.
        third_moments = np.mean(projected**3, axis=1)
        projected[third_moments < 0] *= -1

    grid: np.ndarray = rearrange(
        projected,
        'c (x y z) -> c x y z',
        x=size_x,
        y=size_y,
        z=size_z,
    )
    if values_range is not None:
        vmin, vmax = values_range
    else:
        vmin, vmax = grid.min(), grid.max()
    span = vmax - vmin
    if span == 0:
        # Degenerate range: avoid dividing by zero and map everything to 0.
        grid = np.zeros_like(grid)
    else:
        grid = (grid - vmin) / span
    if clip:
        grid = np.clip(grid, 0, 1)
    return ScalarImage(tensor=grid, affine=embeddings.affine)
149