|
1
|
|
|
import numpy as np |
|
2
|
|
|
import torch |
|
3
|
|
|
from einops import rearrange |
|
4
|
|
|
|
|
5
|
|
|
from ....data.image import Image |
|
6
|
|
|
from ....data.image import ScalarImage |
|
7
|
|
|
from ....external.imports import get_sklearn |
|
8
|
|
|
from ....visualization import build_image_from_input_and_output |
|
9
|
|
|
|
|
10
|
|
|
|
|
11
|
|
|
def _pca(
    data: torch.Tensor,
    num_components: int = 6,
    whiten: bool = True,
    clip_range: tuple[float, float] | None = (-2.3, 2.3),
    normalize: bool = True,
    make_skewness_positive: bool = True,
    **pca_kwargs,
) -> torch.Tensor:
    """Project a 4D feature volume onto its principal components.

    The input is unpacked as ``(c, x, y, z)``; every spatial location is
    treated as one sample with ``c`` features, PCA is fitted on those samples
    and the projected values are rearranged back onto the spatial grid and
    rescaled to ``[0, 1]``.

    Args:
        data: 4D tensor laid out as ``(c, x, y, z)``.
        num_components: Forwarded to scikit-learn's ``PCA`` as
            ``n_components``.
        whiten: Forwarded to scikit-learn's ``PCA`` as ``whiten``.
        clip_range: ``(vmin, vmax)`` mapped linearly onto ``[0, 1]``; values
            outside are clipped. If ``None``, the minimum and maximum of the
            projected data are used instead.
        normalize: If ``True``, divide all components by the standard
            deviation of the first component.
        make_skewness_positive: If ``True``, flip the sign of every component
            whose skewness estimate is negative.
        **pca_kwargs: Extra keyword arguments forwarded to ``PCA``.

    Returns:
        Tensor with the components on the first axis followed by the
        ``(x, y, z)`` axes of ``data``, with values in ``[0, 1]``.

    Raises:
        ValueError: If ``data`` does not have exactly four dimensions.
    """
    # Adapted from https://github.com/facebookresearch/capi/blob/main/eval_visualizations.py
    # For a unit-variance, roughly Gaussian component, +/-2.3 sits close to the
    # 1st/99th percentiles (2.326), so about 98% of values land inside [0, 1]
    # with the default clip range.
    sklearn = get_sklearn()
    PCA = sklearn.decomposition.PCA

    _, size_x, size_y, size_z = data.shape
    # scikit-learn operates on NumPy arrays. Convert explicitly so tensors that
    # require grad or live on another device do not fail during the implicit
    # conversion inside scikit-learn.
    array = data.detach().cpu().numpy()
    X = rearrange(array, 'c x y z -> (x y z) c')
    pca = PCA(n_components=num_components, whiten=whiten, **pca_kwargs)
    # Transpose so that each row holds one component.
    projected: np.ndarray = pca.fit_transform(X).T
    if normalize:
        first_std = projected[0].std()
        # Degenerate (e.g. constant) input yields a zero or NaN spread; skip
        # the rescale instead of dividing by it.
        if first_std > 0:
            projected /= first_std
    if make_skewness_positive:
        # Each row is a view into ``projected``, so the in-place sign flip
        # below modifies the array itself.
        for component in projected:
            third_cumulant = np.mean(component**3)
            second_cumulant = np.mean(component**2)
            skewness = third_cumulant / second_cumulant ** (3 / 2)
            if skewness < 0:
                component *= -1
    grid: np.ndarray = rearrange(
        projected.T,
        '(x y z) c -> c x y z',
        x=size_x,
        y=size_y,
        z=size_z,
    )
    if clip_range is not None:
        vmin, vmax = clip_range
    else:
        vmin, vmax = grid.min(), grid.max()
        if vmax == vmin:
            # A constant grid has no range to stretch; return zeros rather
            # than the NaNs a 0 / 0 division would produce.
            return torch.from_numpy(np.zeros_like(grid))
    grid = (grid - vmin) / (vmax - vmin)
    return torch.from_numpy(grid.clip(0, 1))
|
51
|
|
|
|
|
52
|
|
|
|
|
53
|
|
|
def build_pca_image(
    embeddings: torch.Tensor, image: Image, keep_components: int | None = 3
) -> ScalarImage:
    """Build an image from the PCA projection of ``embeddings``.

    Args:
        embeddings: Tensor passed to ``_pca`` with its default settings.
        image: Passed together with the projection to
            ``build_image_from_input_and_output``.
            NOTE(review): presumably supplies the spatial metadata for the
            result; confirm against the visualization helper.
        keep_components: Number of leading components to keep. ``None`` keeps
            every component returned by ``_pca``.

    Returns:
        The result of ``build_image_from_input_and_output`` for the selected
        components.
    """
    components = _pca(embeddings)
    # Only truncate when a limit was actually requested.
    selected = (
        components if keep_components is None else components[:keep_components]
    )
    return build_image_from_input_and_output(selected, image)
|
61
|
|
|
|