Passed
Push — main ( aa4098...a77a3c )
by Fernando
01:29
created

torchio.transforms.preprocessing.spatial.to_reference_space   A

Complexity

Total Complexity 6

Size/Duplication

Total Lines 55
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 33
dl 0
loc 55
rs 10
c 0
b 0
f 0
wmc 6

3 Methods

Rating   Name   Duplication   Size   Complexity  
A ToReferenceSpace.from_tensor() 0 4 1
A ToReferenceSpace.apply_transform() 0 6 2
A ToReferenceSpace.__init__() 0 5 2

1 Function

Rating   Name   Duplication   Size   Complexity  
A build_image_from_reference() 0 11 1
1
import numpy as np
2
import torch
3
4
from ....data.image import Image
5
from ....data.subject import Subject
6
from ...spatial_transform import SpatialTransform
7
from .resample import Resample
8
9
10
class ToReferenceSpace(SpatialTransform):
    """Give images the spatial metadata of a reference space.

    A typical use case is a neural network embedding: the spatial metadata is
    lost when the network is applied to a raw tensor. Restoring meaningful
    metadata makes it possible to visualize the embedding or to process it
    further, e.g., to resample a segmentation output.

    Args:
        reference: TorchIO image that defines the reference space.
        **kwargs: Keyword arguments forwarded to the parent class constructor.

    Raises:
        TypeError: If ``reference`` is not an instance of ``Image``.

    Example:

    >>> import torchio as tio
    >>> image = tio.datasets.FPG().t1
    >>> embedding_tensor = my_network(image.tensor)  # we lose metadata here
    >>> embedding_image = tio.ToReferenceSpace.from_tensor(embedding_tensor, image)
    """

    def __init__(self, reference: Image, **kwargs):
        super().__init__(**kwargs)
        if isinstance(reference, Image):
            self.reference = reference
        else:
            raise TypeError('The reference must be a TorchIO image')

    def apply_transform(self, subject: Subject) -> Subject:
        """Update every image returned by ``get_images`` and return the subject.

        For each image, a counterpart is built from the image data and the
        stored reference; its data and affine are then written back to the
        image, so the subject is modified in place.
        """
        for current in self.get_images(subject):
            referenced = build_image_from_reference(current.data, self.reference)
            current.set_data(referenced.data)
            current.affine = referenced.affine
        return subject

    @staticmethod
    def from_tensor(tensor: torch.Tensor, reference: Image) -> Image:
        """Build a TorchIO image from a tensor and a reference image.

        Thin wrapper around :func:`build_image_from_reference`, usable without
        instantiating the transform.
        """
        return build_image_from_reference(tensor, reference)
42
43
44
def build_image_from_reference(tensor: torch.Tensor, reference: Image) -> Image:
    """Wrap a tensor in an image whose affine is derived from a reference.

    The reference is resampled (nearest-neighbor interpolation) to a spacing
    scaled by the ratio between its spatial shape and the last three
    dimensions of the tensor. The affine of the resampled reference is then
    attached to the tensor.

    Args:
        tensor: Data for the new image. Its last three dimensions are taken
            as the spatial shape.
        reference: Image providing the spatial shape, the spacing and the
            class of the result.

    Returns:
        An instance of the same class as the resampled reference, holding
        ``tensor`` and the affine of the resampled reference.
    """
    reference_shape = np.array(reference.spatial_shape)
    tensor_shape = np.array(tensor.shape[-3:])
    # Values above 1 mean the tensor grid is coarser than the reference grid
    shape_ratio = reference_shape / tensor_shape
    reference_spacing = np.array(reference.spacing)
    target_spacing = reference_spacing * shape_ratio
    to_target_spacing = Resample(target_spacing, image_interpolation='nearest')
    resampled_reference = to_target_spacing(reference)
    image_class = resampled_reference.__class__
    return image_class(tensor=tensor, affine=resampled_reference.affine)
55