1
|
|
|
from __future__ import annotations |
2
|
|
|
|
3
|
|
|
from typing import Any |
4
|
|
|
|
5
|
|
|
import torch |
6
|
|
|
|
7
|
|
|
from ....data.image import ScalarImage |
8
|
|
|
from ....data.subject import Subject |
9
|
|
|
from ...intensity_transform import IntensityTransform |
10
|
|
|
|
11
|
|
|
|
12
|
|
|
class To(IntensityTransform):
    """Convert the image tensor data type and/or device.

    This transform is a thin wrapper around :func:`torch.Tensor.to`.

    Args:
        target: First argument to :func:`torch.Tensor.to`, e.g. a
            :class:`torch.dtype` such as ``torch.uint8``, a
            :class:`torch.device`, or a device string such as ``'cpu'``.
        to_kwargs: Additional keyword arguments to pass to
            :func:`torch.Tensor.to`. The mapping is copied at construction
            time, so later changes to the caller's dictionary do not affect
            the transform. ``None`` means no extra keyword arguments.
        **kwargs: Keyword arguments forwarded to the parent
            :class:`IntensityTransform` constructor.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> ct = tio.datasets.Slicer('CTChest').CT_chest
        >>> clamp = tio.Clamp(out_min=-1000, out_max=1000)
        >>> ct_clamped = clamp(ct)
        >>> rescale = tio.RescaleIntensity(in_min_max=(-1000, 1000), out_min_max=(0, 255))
        >>> ct_rescaled = rescale(ct_clamped)
        >>> to_uint8 = tio.To(torch.uint8)
        >>> ct_uint8 = to_uint8(ct_rescaled)
    """

    def __init__(
        self,
        target: str | torch.dtype | torch.device,
        to_kwargs: dict[str, Any] | None = None,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.target = target
        # Copy so that later mutation of the caller's dict cannot silently
        # change this transform's behaviour or its recorded arguments.
        self.to_kwargs: dict[str, Any] = {} if to_kwargs is None else dict(to_kwargs)
        # Names of the constructor arguments stored on the instance
        # (presumably used by the base class to record/reproduce the
        # transform — TODO confirm against Transform).
        self.args_names = ['target', 'to_kwargs']

    def apply_transform(self, subject: Subject) -> Subject:
        """Convert the data of each image selected by ``get_images``.

        Each image's tensor is passed through
        ``tensor.to(self.target, **self.to_kwargs)`` and written back with
        ``set_data``.

        Args:
            subject: Subject whose images are converted.

        Returns:
            The same ``subject`` instance, with its images' data replaced.
        """
        for image in self.get_images(subject):
            # Type narrowing: this transform is only expected to see
            # scalar (intensity) images.
            assert isinstance(image, ScalarImage)
            converted = image.data.to(self.target, **self.to_kwargs)
            image.set_data(converted)
        return subject
50
|
|
|
|