1
|
|
|
import numpy as np |
2
|
|
|
import torch |
3
|
|
|
from einops import rearrange |
4
|
|
|
|
5
|
|
|
from ....data.image import Image |
6
|
|
|
from ....data.image import ScalarImage |
7
|
|
|
from ....external.imports import get_sklearn |
8
|
|
|
from ....visualization import build_image_from_input_and_output |
9
|
|
|
|
10
|
|
|
|
11
|
|
|
def _pca(
    data: torch.Tensor,
    num_components: int = 6,
    whiten: bool = True,
    clip_range: tuple[float, float] | None = (-2.3, 2.3),
    normalize: bool = True,
    make_skewness_positive: bool = True,
    **pca_kwargs,
) -> torch.Tensor:
    """Project a 4D feature volume onto its principal components.

    Every spatial location is treated as one sample and the first axis as
    the feature axis. The projected components are rescaled to ``[0, 1]``
    so they can be displayed as image intensities.

    Args:
        data: Tensor whose shape unpacks as ``(channels, x, y, z)``.
        num_components: Forwarded to scikit-learn's PCA as ``n_components``.
        whiten: Forwarded to scikit-learn's PCA as ``whiten``.
        clip_range: ``(vmin, vmax)`` mapped linearly onto ``[0, 1]`` before
            clipping. If ``None``, the minimum and maximum of the projected
            data are used instead. With the default ``(-2.3, 2.3)``, about
            98% of a standard-normal variable falls inside the range.
        normalize: If ``True``, divide all components by the standard
            deviation of the first one.
        make_skewness_positive: If ``True``, flip the sign of every
            component whose skewness is negative, which makes the otherwise
            arbitrary sign of each component deterministic.
        **pca_kwargs: Extra keyword arguments for scikit-learn's PCA.

    Returns:
        Tensor with one channel per principal component and the spatial
        shape of ``data``, with values in ``[0, 1]``.
    """
    # Adapted from https://github.com/facebookresearch/capi/blob/main/eval_visualizations.py
    sklearn = get_sklearn()
    PCA = sklearn.decomposition.PCA

    _, size_x, size_y, size_z = data.shape
    # scikit-learn operates on NumPy arrays: drop any autograd history and
    # move the data to host memory so GPU or grad-tracking tensors also work.
    array = data.detach().cpu().numpy()
    X = rearrange(array, 'c x y z -> (x y z) c')
    pca = PCA(n_components=num_components, whiten=whiten, **pca_kwargs)
    projected: np.ndarray = pca.fit_transform(X).T
    if normalize:
        first_std = projected[0].std()
        # A constant first component has zero spread; dividing would only
        # produce inf/NaN, so leave the values untouched in that case.
        if first_std > 0:
            projected /= first_std
    if make_skewness_positive:
        for component in projected:
            # Skewness is the third moment divided by a non-negative power of
            # the second moment, so its sign is the sign of the third moment.
            # Testing it directly avoids a 0/0 on constant components.
            if np.mean(component**3) < 0:
                component *= -1  # in-place: `component` is a view of a row
    grid: np.ndarray = rearrange(
        projected.T,
        '(x y z) c -> c x y z',
        x=size_x,
        y=size_y,
        z=size_z,
    )
    if clip_range is not None:
        vmin, vmax = clip_range
    else:
        vmin, vmax = grid.min(), grid.max()
    span = vmax - vmin
    if span == 0:
        # Degenerate range (e.g. constant data): use unit width so the
        # result is finite instead of NaN.
        span = 1
    grid = (grid - vmin) / span
    return torch.from_numpy(grid.clip(0, 1))
51
|
|
|
|
52
|
|
|
|
53
|
|
|
def build_pca_image(
    embeddings: torch.Tensor,
    image: Image,
    keep_components: int | None = 3,
) -> ScalarImage:
    """Build an image from the leading PCA components of ``embeddings``.

    Args:
        embeddings: Feature tensor passed to ``_pca`` with its default
            settings.
        image: Passed as the second argument to
            ``build_image_from_input_and_output`` together with the PCA
            result.
        keep_components: Number of leading components to keep. ``None``
            keeps every component returned by ``_pca``.
            NOTE(review): the default of 3 presumably targets an RGB-style
            rendering; confirm against callers.

    Returns:
        The object returned by ``build_image_from_input_and_output``.
    """
    components = _pca(embeddings)
    selected = components if keep_components is None else components[:keep_components]
    return build_image_from_input_and_output(selected, image)
61
|
|
|
|