from __future__ import annotations

from typing import Any

import numpy as np
from einops import rearrange

from ....data.image import ScalarImage
from ....data.subject import Subject
from ....external.imports import get_sklearn
from ...intensity_transform import IntensityTransform


class PCA(IntensityTransform):
    """Compute principal component analysis (PCA) of an image.

    PCA can be useful to visualize embeddings generated by a neural network.
    See for example Figure 8 in `Cluster and Predict Latent Patches for
    Improved Masked Image Modeling <https://arxiv.org/abs/2502.08769>`_.

    Args:
        num_components: Number of components to compute.
        keep_components: Number of components to keep in the output image.
            If ``None``, all components are kept.
        whiten: If ``True``, the components are normalized to have unit variance.
        normalize: If ``True``, all components are divided by the standard
            deviation of the first component.
        make_skewness_positive: If ``True``, the skewness of each component is
            made positive by multiplying the component by -1 if its skewness is
            negative.
        values_range: If not ``None``, these values are linearly mapped to
            :math:`[0, 1]`.
        clip: If ``True``, the output values are clipped to :math:`[0, 1]`.
        pca_kwargs: Additional keyword arguments to pass to
            :class:`sklearn.decomposition.PCA`.

    Example:

        >>> import torchio as tio
        >>> from torchio.visualization import build_image_from_reference
        >>> ct = my_preprocessed_ct_image  # Assume this is a preprocessed CT image
        >>> ct
        ScalarImage(shape: (1, 240, 480, 480); spacing: (1.50, 0.75, 0.75); orientation: SLP+; dtype: torch.FloatTensor; memory: 210.9 MiB)
        >>> embedding_tensor = model(ct.data[None])[0]  # `model` is some pre-trained neural network
        >>> embedding_image = ToReferenceSpace(ct)(embedding_tensor)
        >>> embedding_image
        ScalarImage(shape: (512, 24, 24, 24); spacing: (15.00, 15.00, 15.00); orientation: SLP+; dtype: torch.FloatTensor; memory: 27.0 MiB)
        >>> pca = tio.PCA()(embedding_image)
        >>> pca
        ScalarImage(shape: (3, 24, 24, 24); spacing: (15.00, 15.00, 15.00); orientation: SLP+; dtype: torch.FloatTensor; memory: 162.0 KiB)
    """

    def __init__(
        self,
        num_components: int = 6,
        *,
        keep_components: int | None = 3,
        whiten: bool = True,
        normalize: bool = True,
        make_skewness_positive: bool = True,
        values_range: tuple[float, float] | None = (-2.3, 2.3),
        clip: bool = True,
        pca_kwargs: dict[str, Any] | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_components = num_components
        self.keep_components = keep_components
        self.whiten = whiten
        self.normalize = normalize
        self.make_skewness_positive = make_skewness_positive
        self.values_range = values_range
        self.clip = clip
        self.pca_kwargs = pca_kwargs
        self.args_names = [
            'num_components',
            'keep_components',
            'whiten',
            'normalize',
            'make_skewness_positive',
            'values_range',
            'clip',
            'pca_kwargs',
        ]

    def apply_transform(self, subject: Subject) -> Subject:
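        # Replace each image's channels with its principal components; only the
        # tensor changes, so the affine and spatial metadata are preserved.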
        for image in self.get_images(subject):
            kwargs = {} if self.pca_kwargs is None else self.pca_kwargs
            pca_image = _compute_pca(
                image,
                num_components=self.num_components,
                keep_components=self.keep_components,
                whiten=self.whiten,
                normalize=self.normalize,
                make_skewness_positive=self.make_skewness_positive,
                values_range=self.values_range,
                clip=self.clip,
                **kwargs,
            )
            image.set_data(pca_image.data)
        return subject


def _compute_pca(
    embeddings: ScalarImage,
    num_components: int,
    keep_components: int | None,
    whiten: bool,
    normalize: bool,
    make_skewness_positive: bool,
    values_range: tuple[float, float] | None,
    clip: bool,
    **pca_kwargs,
) -> ScalarImage:
    # Adapted from https://github.com/facebookresearch/capi/blob/main/eval_visualizations.py
    # The default range (-2.3, 2.3) spans about ±2.3σ of a standard-normal
    # variable, so roughly 98% of values map inside [0, 1].
    sklearn = get_sklearn()
    PCA = sklearn.decomposition.PCA

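    # Treat each voxel as a sample and each channel as a feature: flatten the
    # spatial axes into a (num_voxels, num_channels) matrix for scikit-learn.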
    data = embeddings.numpy()
    _, size_x, size_y, size_z = data.shape
    X = rearrange(data, 'c x y z -> (x y z) c')
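    # With whiten=True, scikit-learn rescales each projected component to have
    # unit variance across the voxel samples.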
    pca = PCA(n_components=num_components, whiten=whiten, **pca_kwargs)
    projected: np.ndarray = pca.fit_transform(X).T
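    # Divide all components by the standard deviation of the first (dominant)
    # one, giving it unit spread while preserving their relative scales.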
    if normalize:
        projected /= projected[0].std()
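    # The sign of a PCA component is arbitrary: flip any component whose
    # skewness (third standardized moment; the mean is zero after PCA) is
    # negative, so all components follow the same sign convention.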
    if make_skewness_positive:
        for component in projected:
            third_cumulant = np.mean(component**3)
            second_cumulant = np.mean(component**2)
            skewness = third_cumulant / second_cumulant ** (3 / 2)
            if skewness < 0:
                component *= -1
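    # Reshape the flat per-voxel scores back into a (components, x, y, z) volume.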
    grid: np.ndarray = rearrange(
        projected,
        'c (x y z) -> c x y z',
        x=size_x,
        y=size_y,
        z=size_z,
    )
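    # Map values_range (or the observed min/max when it is None) linearly onto
    # [0, 1], optionally clipping values that fall outside that interval.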
    if values_range is not None:
        vmin, vmax = values_range
    else:
        vmin, vmax = grid.min(), grid.max()
    grid = (grid - vmin) / (vmax - vmin)
    if clip:
        grid = np.clip(grid, 0, 1)
    if keep_components is not None:
        grid = grid[:keep_components]
    return ScalarImage(tensor=grid, affine=embeddings.affine)