| Total Complexity | 3 |
| Total Lines | 32 |
| Duplicated Lines | 0 % |
| Changes | 0 | ||
| 1 | from .weighted import WeightedSampler |
||
| 2 | |||
| 3 | |||
class LabelSampler(WeightedSampler):
    r"""Extract random patches containing labeled voxels.

    This iterable dataset yields patches whose center voxel value is
    greater than 0.5 in the :py:attr:`label_name` image: the label image is
    binarized with a ``> 0.5`` threshold and the result is used as the
    sampling probability map. For integer-valued label maps this is
    equivalent to selecting voxels greater than 0.

    Args:
        patch_size: Tuple of integers :math:`(d, h, w)` to generate patches
            of size :math:`d \times h \times w`.
            If a single number :math:`n` is provided,
            :math:`d = h = w = n`.
        label_name: Name of the label image in the sample that will be used to
            generate the sampling probability map.
    """

    def __init__(self, patch_size, label_name):
        # The label image itself serves as the probability map. Its name is
        # read back below as ``self.probability_map_name``, presumably set by
        # WeightedSampler from the ``probability_map`` argument — confirm.
        super().__init__(patch_size, probability_map=label_name)

    def get_probability_map(self, sample):
        """Return binarized image for sampling.

        Args:
            sample: Subject sample expected to contain the image named by
                ``self.probability_map_name``.

        Returns:
            The element-wise comparison ``image.data > 0.5`` (a boolean mask,
            assuming ``data`` is an array/tensor — confirm against caller).

        Raises:
            KeyError: If the label image is not present in ``sample``.
        """
        name = self.probability_map_name
        if name not in sample:
            message = (
                f'Image "{name}"'
                f' not found in subject sample: {sample}'
            )
            raise KeyError(message)
        # Threshold at 0.5: identical to ``> 0`` for integer labels, while for
        # float/soft labels only values above 0.5 become sampling candidates.
        return sample[name].data > 0.5
||
| 32 |