Passed
Pull Request — main (#1349)
by Fernando
01:29
created

torchio.transforms.preprocessing.intensity.pca   A

Complexity

Total Complexity 8

Size/Duplication

Total Lines 61
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 49
dl 0
loc 61
rs 10
c 0
b 0
f 0
wmc 8

2 Functions

Rating   Name   Duplication   Size   Complexity  
B _pca() 0 40 6
A build_pca_image() 0 8 2
1
import numpy as np
2
import torch
3
from einops import rearrange
4
5
from ....data.image import Image
6
from ....data.image import ScalarImage
7
from ....external.imports import get_sklearn
8
from ....visualization import build_image_from_input_and_output
9
10
11
def _pca(
    data: torch.Tensor,
    num_components: int = 6,
    whiten: bool = True,
    clip_range: tuple[float, float] | None = (-2.3, 2.3),
    normalize: bool = True,
    make_skewness_positive: bool = True,
    **pca_kwargs,
) -> torch.Tensor:
    """Project a multichannel volume onto its principal components.

    Each voxel is treated as one sample and each channel as one feature.
    The projected components are optionally rescaled and sign-flipped, then
    mapped linearly to ``[0, 1]`` so they can be displayed as intensities.

    Args:
        data: 4D tensor laid out as ``(channels, x, y, z)``.
        num_components: Passed to scikit-learn's ``PCA`` as ``n_components``.
        whiten: Passed to scikit-learn's ``PCA`` as ``whiten``.
        clip_range: ``(vmin, vmax)`` interval mapped linearly onto ``[0, 1]``;
            values outside it are clipped. If ``None``, the minimum and
            maximum of the projected data are used instead.
        normalize: If ``True``, divide all components by the standard
            deviation of the first one.
        make_skewness_positive: If ``True``, flip the sign of every component
            whose third moment (and therefore skewness) is negative.
        **pca_kwargs: Additional keyword arguments forwarded to ``PCA``.

    Returns:
        Tensor of shape ``(components, x, y, z)`` with values in ``[0, 1]``.
    """
    # Adapted from https://github.com/facebookresearch/capi/blob/main/eval_visualizations.py
    # With the default clip range, about 98% of a standard-normal variable
    # lies within +/-2.3 and therefore maps strictly inside [0, 1].
    sklearn = get_sklearn()
    PCA = sklearn.decomposition.PCA

    _, size_x, size_y, size_z = data.shape
    # scikit-learn works on NumPy arrays: convert explicitly so that tensors
    # that require grad or live on an accelerator are handled as well.
    array = data.detach().cpu().numpy()
    features = rearrange(array, 'c x y z -> (x y z) c')
    pca = PCA(n_components=num_components, whiten=whiten, **pca_kwargs)
    projected: np.ndarray = pca.fit_transform(features).T

    if normalize:
        scale = projected[0].std()
        # A zero spread (e.g., constant input) would turn everything into
        # inf/NaN, so leave the data untouched in that case.
        if scale > 0:
            projected /= scale

    if make_skewness_positive:
        # The second moment is non-negative, so the sign of the skewness is
        # the sign of the third moment; no need to compute the ratio.
        third_moment = np.mean(projected**3, axis=1)
        projected[third_moment < 0] *= -1

    grid: np.ndarray = rearrange(
        projected.T,
        '(x y z) c -> c x y z',
        x=size_x,
        y=size_y,
        z=size_z,
    )

    if clip_range is not None:
        vmin, vmax = clip_range
    else:
        vmin, vmax = grid.min(), grid.max()

    span = vmax - vmin
    if span == 0:
        # Degenerate interval (e.g., constant volume): avoid 0/0 -> NaN.
        grid = np.zeros_like(grid)
    else:
        grid = (grid - vmin) / span
    return torch.from_numpy(grid.clip(0, 1))
51
52
53
def build_pca_image(
    embeddings: torch.Tensor, image: Image, keep_components: int | None = 3
) -> ScalarImage:
    """Build an image from the PCA projection of some embeddings.

    The embeddings are projected with ``_pca`` using its default settings and
    the result is handed to ``build_image_from_input_and_output`` together
    with ``image``.

    Args:
        embeddings: Tensor passed to ``_pca``.
        image: Image passed to ``build_image_from_input_and_output`` alongside
            the projected components.
        keep_components: Number of leading components to keep. If ``None``,
            all components returned by ``_pca`` are kept.

    Returns:
        The image built by ``build_image_from_input_and_output``.
    """
    components = _pca(embeddings)
    if keep_components is None:
        return build_image_from_input_and_output(components, image)
    leading = components[:keep_components]
    return build_image_from_input_and_output(leading, image)
61