Passed
Push — main ( c95fa1...287682 )
by Fernando
01:32
created

To.__init__()   A

Complexity

Conditions 2

Size

Total Lines 12
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 11
nop 4
dl 0
loc 12
rs 9.85
c 0
b 0
f 0
1
from __future__ import annotations
2
3
from typing import Any
4
5
import torch
6
7
from ....data.image import ScalarImage
8
from ....data.subject import Subject
9
from ...intensity_transform import IntensityTransform
10
11
12
class To(IntensityTransform):
    """Convert the image tensor data type and/or device.

    This transform is a thin wrapper around :func:`torch.Tensor.to`.

    Args:
        target: First argument to :func:`torch.Tensor.to`, e.g. a
            :class:`torch.dtype`, a :class:`torch.device` or a device string.
        to_kwargs: Additional keyword arguments to pass to
            :func:`torch.Tensor.to`. A shallow copy is stored, so later
            changes to the caller's dictionary do not affect the transform.
        **kwargs: Keyword arguments forwarded to the parent
            :class:`IntensityTransform` constructor.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> ct = tio.datasets.Slicer('CTChest').CT_chest
        >>> clamp = tio.Clamp(out_min=-1000, out_max=1000)
        >>> ct_clamped = clamp(ct)
        >>> rescale = tio.RescaleIntensity(in_min_max=(-1000, 1000), out_min_max=(0, 255))
        >>> ct_rescaled = rescale(ct_clamped)
        >>> to_uint8 = tio.To(torch.uint8)
        >>> ct_uint8 = to_uint8(ct_rescaled)
    """

    def __init__(
        self,
        target: str | torch.dtype | torch.device,
        to_kwargs: dict[str, Any] | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.target = target
        # Store a copy so that mutating the caller's dict after construction
        # cannot silently change what this transform does.
        self.to_kwargs = {} if to_kwargs is None else dict(to_kwargs)
        # NOTE(review): presumably read by the base transform machinery
        # (e.g. to reproduce/describe the transform) — confirm.
        self.args_names = ['target', 'to_kwargs']

    def apply_transform(self, subject: Subject) -> Subject:
        """Convert the data of each image returned by ``get_images``.

        Args:
            subject: Subject whose images are converted.

        Returns:
            The same ``subject`` instance; each processed image has its data
            replaced via ``set_data``.
        """
        for image in self.get_images(subject):
            # Intensity transforms are expected to receive scalar images only;
            # the assert also narrows the type for static checkers.
            assert isinstance(image, ScalarImage)
            image.set_data(image.data.to(self.target, **self.to_kwargs))
        return subject
50