Passed
Pull Request — master (#420)
by Fernando
01:14
created

torchio.transforms.preprocessing.label.one_hot   A

Complexity

Total Complexity 4

Size/Duplication

Total Lines 31
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 19
dl 0
loc 31
rs 10
c 0
b 0
f 0
wmc 4

2 Methods

Rating   Name   Duplication   Size   Complexity  
A OneHot.__init__() 0 8 1
A OneHot.apply_transform() 0 8 3
1
import torch.nn.functional as F  # noqa: N812
2
3
from .label_transform import LabelTransform
4
5
6
class OneHot(LabelTransform):
    r"""Reencode label maps using one-hot encoding.

    Each selected label map must be a 4D tensor whose first dimension has
    size 1. It is replaced by a 4D tensor whose first dimension indexes the
    classes: channel ``c`` is ``1`` where the input label equals ``c`` and
    ``0`` elsewhere. The output keeps the dtype of the input data.

    Args:
        num_classes: See :func:`~torch.nn.functional.one_hot`. ``-1`` (the
            default) or ``None`` lets the number of classes be inferred from
            the largest label value of each image.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Raises:
        ValueError: When applied to an image that is not a 4D tensor with a
            single channel.
    """
    def __init__(
            self,
            num_classes: int = -1,
            **kwargs
            ):
        super().__init__(**kwargs)
        # None is also accepted and treated like -1 in apply_transform
        self.num_classes = num_classes
        # NOTE(review): presumably the list of constructor arguments the base
        # Transform records; num_classes is not listed here — confirm intent.
        self.args_names = []

    def apply_transform(self, subject):
        # Loop-invariant: normalize the "infer the number of classes" value once
        num_classes = -1 if self.num_classes is None else self.num_classes
        for image in self.get_images(subject):
            data = image.data
            if data.ndim != 4 or data.shape[0] != 1:
                message = (
                    'The input must be a 4D tensor with a single channel,'
                    ' but its shape is {}'.format(tuple(data.shape))
                )
                raise ValueError(message)
            # Index the channel instead of calling squeeze(): squeeze() would
            # also drop spatial dimensions of size 1 (e.g. 2D images), which
            # breaks the 4-axis permutation below
            label_map = data[0]
            one_hot = F.one_hot(label_map.long(), num_classes=num_classes)
            # F.one_hot appends the class axis last; move it to the front
            image.set_data(one_hot.permute(3, 0, 1, 2).type(data.type()))
        return subject
31