|
1
|
|
|
from __future__ import annotations |
|
2
|
|
|
|
|
3
|
|
|
from typing import Any |
|
4
|
|
|
|
|
5
|
|
|
import numpy as np |
|
6
|
|
|
from einops import rearrange |
|
7
|
|
|
|
|
8
|
|
|
from ....data.image import ScalarImage |
|
9
|
|
|
from ....data.subject import Subject |
|
10
|
|
|
from ....external.imports import get_sklearn |
|
11
|
|
|
from ...intensity_transform import IntensityTransform |
|
12
|
|
|
|
|
13
|
|
|
|
|
14
|
|
|
class PCA(IntensityTransform):
    """Compute principal component analysis (PCA) of an image.

    PCA can be useful to visualize embeddings generated by a neural network.
    See for example Figure 8 in `Cluster and Predict Latent Patches for
    Improved Masked Image Modeling <https://arxiv.org/abs/2502.08769>`_.

    Each voxel is treated as one sample and each channel as one feature, so
    the transformed image has one channel per returned principal component.

    Args:
        num_components: Number of components to compute.
        keep_components: Number of components to keep in the output image.
            If ``None``, all components are kept.
        whiten: If ``True``, the components are normalized to have unit variance.
        normalize: If ``True``, all components are divided by the standard
            deviation of the first component.
        make_skewness_positive: If ``True``, the skewness of each component is
            made positive by multiplying the component by -1 if its skewness is
            negative.
        values_range: If not ``None``, these values are linearly mapped to
            :math:`[0, 1]`. If ``None``, the minimum and maximum of all
            computed components are used instead.
        clip: If ``True``, the output values are clipped to :math:`[0, 1]`.
        pca_kwargs: Additional keyword arguments to pass to
            :class:`sklearn.decomposition.PCA`. They must not contain
            ``n_components`` or ``whiten``, which are already set from
            ``num_components`` and ``whiten``.

    Example:

        >>> import torchio as tio
        >>> from torchio.visualization import build_image_from_reference
        >>> ct = my_preprocessed_ct_image  # Assume this is a preprocessed CT image
        >>> ct
        ScalarImage(shape: (1, 240, 480, 480); spacing: (1.50, 0.75, 0.75); orientation: SLP+; dtype: torch.FloatTensor; memory: 210.9 MiB)
        >>> embedding_tensor = model(ct.data[None])[0]  # `model` is some pre-trained neural network
        >>> embedding_image = ToReferenceSpace(ct)(embedding_tensor)
        >>> embedding_image
        ScalarImage(shape: (512, 24, 24, 24); spacing: (15.00, 15.00, 15.00); orientation: SLP+; dtype: torch.FloatTensor; memory: 27.0 MiB)
        >>> pca = tio.PCA()(embedding_image)
        >>> pca
        ScalarImage(shape: (3, 24, 24, 24); spacing: (15.00, 15.00, 15.00); orientation: SLP+; dtype: torch.FloatTensor; memory: 162.0 KiB)
    """

    def __init__(
        self,
        num_components: int = 6,
        *,
        keep_components: int | None = 3,
        whiten: bool = True,
        normalize: bool = True,
        make_skewness_positive: bool = True,
        values_range: tuple[float, float] | None = (-2.3, 2.3),
        clip: bool = True,
        pca_kwargs: dict[str, Any] | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_components = num_components
        self.keep_components = keep_components
        self.whiten = whiten
        self.normalize = normalize
        self.make_skewness_positive = make_skewness_positive
        self.values_range = values_range
        self.clip = clip
        self.pca_kwargs = pca_kwargs
        # Names of the constructor arguments stored on the instance.
        self.args_names = [
            'num_components',
            'keep_components',
            'whiten',
            'normalize',
            'make_skewness_positive',
            'values_range',
            'clip',
            'pca_kwargs',
        ]

    def apply_transform(self, subject: Subject) -> Subject:
        """Replace the data of each selected image with its PCA projection.

        Args:
            subject: Subject whose images (as selected by ``get_images``) are
                transformed in place.

        Returns:
            The same ``subject`` instance, with modified image data.
        """
        # The extra sklearn arguments are identical for every image, so build
        # them once rather than on each iteration.
        pca_kwargs = {} if self.pca_kwargs is None else self.pca_kwargs
        for image in self.get_images(subject):
            pca_image = _compute_pca(
                image,
                num_components=self.num_components,
                keep_components=self.keep_components,
                whiten=self.whiten,
                normalize=self.normalize,
                make_skewness_positive=self.make_skewness_positive,
                values_range=self.values_range,
                clip=self.clip,
                **pca_kwargs,
            )
            # The channel count of the image becomes the number of returned
            # components.
            image.set_data(pca_image.data)
        return subject
|
102
|
|
|
|
|
103
|
|
|
|
|
104
|
|
|
def _compute_pca(
    embeddings: ScalarImage,
    num_components: int,
    keep_components: int | None,
    whiten: bool,
    normalize: bool,
    make_skewness_positive: bool,
    values_range: tuple[float, float] | None,
    clip: bool,
    **pca_kwargs,
) -> ScalarImage:
    """Project the channels of a 4D image onto their principal components.

    Adapted from https://github.com/facebookresearch/capi/blob/main/eval_visualizations.py

    Every voxel is one sample and every channel is one feature. The projected
    components are optionally rescaled and sign-flipped. They are then mapped
    linearly to [0, 1], optionally clipped and truncated, and reshaped back
    onto the voxel grid.

    Args:
        embeddings: Image whose data has shape ``(C, X, Y, Z)``.
        num_components: Number of principal components to compute.
        keep_components: Number of leading components to return, or ``None``
            to return all of them.
        whiten: Forwarded to :class:`sklearn.decomposition.PCA`.
        normalize: If ``True``, divide all components by the standard
            deviation of the first one. This is skipped if that deviation is 0.
        make_skewness_positive: If ``True``, flip the sign of every component
            whose skewness is negative.
        values_range: ``(vmin, vmax)`` mapped linearly to ``(0, 1)``. If
            ``None``, the minimum and maximum of all computed components are
            used.
        clip: If ``True``, clip the mapped values to ``[0, 1]``.
        **pca_kwargs: Extra keyword arguments for
            :class:`sklearn.decomposition.PCA`.

    Returns:
        A new :class:`ScalarImage` with the affine of ``embeddings``, whose
        channels are the (kept) components.

    Raises:
        ValueError: If ``values_range`` has identical lower and upper bounds.
    """
    # Validate before the (potentially expensive) PCA fit.
    if values_range is not None and values_range[0] == values_range[1]:
        raise ValueError(
            f'values_range must have two distinct bounds, got {values_range}',
        )

    sklearn = get_sklearn()
    PCA = sklearn.decomposition.PCA

    data = embeddings.numpy()
    _, size_x, size_y, size_z = data.shape
    # One row per voxel, one column per channel.
    samples = rearrange(data, 'c x y z -> (x y z) c')
    pca = PCA(n_components=num_components, whiten=whiten, **pca_kwargs)
    # Transposed so that each row holds one component over all voxels.
    projected: np.ndarray = pca.fit_transform(samples).T

    if normalize:
        first_std = projected[0].std()
        # A constant first component has zero spread. Dividing by it would
        # fill the output with NaN/inf, so leave the values unscaled instead.
        if first_std > 0:
            projected /= first_std

    if make_skewness_positive:
        for component in projected:
            # skewness = m3 / m2**1.5. The denominator is non-negative, so the
            # sign of the skewness is the sign of the third moment m3. Testing
            # m3 directly avoids 0/0 warnings on all-zero components.
            if np.mean(component**3) < 0:
                component *= -1  # in-place on a row view of `projected`

    grid: np.ndarray = rearrange(
        projected,
        'c (x y z) -> c x y z',
        x=size_x,
        y=size_y,
        z=size_z,
    )

    # For the default range (-2.3, 2.3): about 98% of a standard-normal
    # variable lies within ±2.3, so for roughly unit-variance components most
    # values land inside [0, 1].
    if values_range is not None:
        vmin, vmax = values_range
    else:
        vmin, vmax = grid.min(), grid.max()
    span = vmax - vmin
    if span == 0:
        # Only reachable for constant data. Every value equals vmin, which is
        # the lower end of the target interval.
        grid = np.zeros_like(grid)
    else:
        grid = (grid - vmin) / span

    if clip:
        grid = np.clip(grid, 0, 1)
    if keep_components is not None:
        grid = grid[:keep_components]
    return ScalarImage(tensor=grid, affine=embeddings.affine)
|
151
|
|
|
|