1
|
|
|
from torchio.data.image import ScalarImage |
2
|
|
|
from typing import Optional |
3
|
|
|
import torch |
4
|
|
|
from ....data.subject import Subject |
5
|
|
|
from ...intensity_transform import IntensityTransform |
6
|
|
|
|
7
|
|
|
|
8
|
|
|
class Clamp(IntensityTransform):
    """Clamp intensity values to a certain range.

    Args:
        out_min: Lower bound. See :func:`torch.clamp`. If ``None``, no
            lower bound is applied.
        out_max: Upper bound. See :func:`torch.clamp`. If ``None``, no
            upper bound is applied.
        **kwargs: Forwarded unchanged to the parent class constructor.

    If both ``out_min`` and ``out_max`` are ``None``, the intensities are
    left unchanged.

    Example:
        >>> import torchio as tio
        >>> ct = tio.ScalarImage('ct_scan.nii.gz')
        >>> ct_air, ct_bone = -1000, 1000
        >>> clamp = tio.Clamp(out_min=ct_air, out_max=ct_bone)
        >>> ct_clamped = clamp(ct)
    """
    def __init__(
            self,
            out_min: Optional[float] = None,
            out_max: Optional[float] = None,
            **kwargs
            ):
        super().__init__(**kwargs)
        self.out_min, self.out_max = out_min, out_max

    def __repr__(self):
        """Return ``ClassName(out_min=..., out_max=...)``."""
        string = (
            f'{self.__class__.__name__}'
            f'(out_min={self.out_min}, out_max={self.out_max})'
        )
        return string

    def apply_transform(self, subject: Subject) -> Subject:
        """Clamp every image selected by ``get_images`` and return the subject.

        Images are modified in place through :meth:`apply_clamp`; the same
        ``subject`` instance that was passed in is returned.
        """
        for image in self.get_images(subject):
            self.apply_clamp(image)
        return subject

    def apply_clamp(self, image: ScalarImage):
        """Replace the data of ``image`` with its clamped version."""
        image.set_data(self.clamp(image.data))

    def clamp(self, tensor: torch.Tensor) -> torch.Tensor:
        """Clamp ``tensor`` to ``[out_min, out_max]``.

        Args:
            tensor: Values to clamp.

        Returns:
            The clamped tensor. If neither bound is set, ``tensor`` itself
            is returned (not a copy).
        """
        # torch raises a RuntimeError when both bounds are None, which is
        # exactly the default configuration of this transform. Treat "no
        # bounds" as the identity operation instead of crashing.
        if self.out_min is None and self.out_max is None:
            return tensor
        return tensor.clamp(self.out_min, self.out_max)
47
|
|
|
|