Passed
Pull Request — main (#1345)
by Fernando
01:40
created

torchio.transforms.preprocessing.intensity.to   A

Complexity

Total Complexity 3

Size/Duplication

Total Lines 47
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 22
dl 0
loc 47
rs 10
c 0
b 0
f 0
wmc 3

2 Methods

Rating   Name   Duplication   Size   Complexity  
A To.apply_transform() 0 5 2
A To.__init__() 0 10 1
1
from __future__ import annotations
2
3
from typing import Any
4
5
import torch
6
7
from ....data.image import ScalarImage
8
from ....data.subject import Subject
9
from ...intensity_transform import IntensityTransform
10
11
12
class To(IntensityTransform):
    """Convert the image tensor data type and/or device.

    Args:
        target: First argument to :func:`torch.Tensor.to`.
        to_kwargs: Additional keyword arguments to pass to
            :func:`torch.Tensor.to`. If ``None`` (the default), no extra
            keyword arguments are passed.
        **kwargs: Forwarded unchanged to the parent class constructor.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> ct = tio.datasets.Slicer('CTChest').CT_chest
        >>> clamp = tio.Clamp(out_min=-1000, out_max=1000)
        >>> ct_clamped = clamp(ct)
        >>> rescale = tio.RescaleIntensity(in_min_max=(-1000, 1000), out_min_max=(0, 255))
        >>> ct_rescaled = rescale(ct_clamped)
        >>> to_uint8 = tio.To(torch.uint8)
        >>> ct_uint8 = to_uint8(ct_rescaled)
    """

    def __init__(
        self,
        target: str | torch.dtype | torch.device,
        to_kwargs: dict[str, Any] | None = None,
        **kwargs,
    ):
        """Store the conversion target and the optional extra keyword arguments."""
        super().__init__(**kwargs)
        self.target = target
        # Normalize a missing mapping to an empty dict so that it can always
        # be expanded with ``**`` in ``apply_transform``. This also makes the
        # single-argument call shown in the class docstring valid.
        self.to_kwargs = {} if to_kwargs is None else to_kwargs
        # NOTE(review): presumably read by the base transform machinery to
        # know which constructor arguments define this transform -- confirm.
        self.args_names = ['target', 'to_kwargs']

    def apply_transform(self, subject: Subject) -> Subject:
        """Replace the data of each selected image with its converted tensor.

        Args:
            subject: Subject whose images (as returned by ``get_images``)
                are converted.

        Returns:
            The same ``subject`` instance, after its images were updated.
        """
        for image in self.get_images(subject):
            # Every image handled here is expected to be a ScalarImage.
            assert isinstance(image, ScalarImage)
            image.set_data(image.data.to(self.target, **self.to_kwargs))
        return subject
47