Completed
Push — master ( b6035d...10c423 )
by Andy
32s
created

Censors   A

Complexity

Total Complexity 14

Size/Duplication

Total Lines 86
Duplicated Lines 0 %

Importance

Changes 2
Bugs 1 Features 0
Metric Value
c 2
b 1
f 0
dl 0
loc 86
rs 10
wmc 14

7 Methods

Rating   Name   Duplication   Size   Complexity  
A setdefault() 0 4 2
B update() 0 11 5
B __setitem__() 0 25 3
A __init__() 0 6 1
A num_pixels() 0 3 1
A __getstate__() 0 6 1
A label_names() 0 3 1
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
"""
5
Utilities to deal with wavelength censoring.
6
"""
7
8
from __future__ import (division, print_function, absolute_import,
9
                        unicode_literals)
10
11
__all__ = ["Censors", "create_mask", "design_matrix_mask"]
12
13
import numpy as np
14
15
from .vectorizer.base import BaseVectorizer
16
17
18
class Censors(dict):

    """
    A dictionary sub-class that allows for label censoring masks to be
    applied on a per-pixel basis to CannonModel objects.

    Keys must be names listed in `label_names`, and values are boolean masks
    with exactly `num_pixels` entries. Every entry (whether given at
    construction, via `update`, `setdefault` or item assignment) is validated
    by `__setitem__`.

    :param label_names:
        A list containing the label names that form the model vectorizer.

    :param num_pixels:
        The number of pixels per star.

    :param items: [optional]
        A dictionary containing label names as keys and masks as values.

    Any further keyword arguments are treated as additional ``label=mask``
    entries and are validated in the same way.
    """

    def __init__(self, label_names, num_pixels, items=None, **kwargs):
        # Start from an empty dict: dict.__init__(**kwargs) would store the
        # keyword entries without going through the validating __setitem__.
        super(Censors, self).__init__()
        # These must exist before any entry is inserted, because
        # __setitem__ validates against them.
        self._label_names = tuple(label_names)
        self._num_pixels = int(num_pixels)
        self.update(items or {}, **kwargs)
        return None

    def __setitem__(self, label_name, mask):
        """
        Update an entry in the pixel censoring dictionary.

        :param label_name:
            The name of the label to apply the censoring to.

        :param mask:
            A boolean mask with a size that equals the number of pixels per star.
            Note that a mask value of `True` indicates the label is censored at
            the given pixel, and therefore that label will not contribute to
            the spectral flux at that pixel.

        :raises ValueError:
            If `label_name` is not one of `label_names`, or if the flattened
            mask does not have `num_pixels` entries.
        """

        if label_name not in self.label_names:
            raise ValueError(
                "unrecognized label name '{}' for censoring".format(label_name))

        # Always store a flat boolean copy, whatever array-like was given.
        mask = np.array(mask).flatten().astype(bool)
        if mask.size != self.num_pixels:
            raise ValueError("'{}' censoring mask has wrong size ({} != {})"\
                .format(label_name, mask.size, self.num_pixels))

        dict.__setitem__(self, label_name, mask)
        return None

    def update(self, *args, **kwargs):
        """
        Insert entries from an optional mapping (or iterable of pairs) and
        from keyword arguments, validating each one via `__setitem__`.

        :raises TypeError:
            If more than one positional argument is given.
        """
        if args:
            if len(args) > 1:
                raise TypeError("update expected at most 1 arguments, got {}"\
                    .format(len(args)))
            other = dict(args[0])
            for key in other:
                self[key] = other[key]

        for key in kwargs:
            self[key] = kwargs[key]

    def setdefault(self, key, value=None):
        """
        Return the mask for `key`, first storing `value` (validated) if `key`
        is not already present.
        """
        if key not in self:
            self[key] = value
        return self[key]

    def __getstate__(self):
        """ Return the state of the censoring mask in a serializable form. """
        return dict(
            label_names=self.label_names,
            num_pixels=self.num_pixels,
            items=dict(self.items()))

    def __reduce__(self):
        """
        Support `pickle` and the `copy` module.

        The default dict-subclass protocol re-inserts the entries through
        `__setitem__` before the instance attributes are restored, so the
        validation would fail. Rebuild through the constructor instead.
        """
        return (self.__class__,
                (self.label_names, self.num_pixels, dict(self.items())))

    @property
    def label_names(self):
        """ The tuple of label names that may be censored. """
        return self._label_names

    @property
    def num_pixels(self):
        """ The number of pixels every censoring mask must have. """
        return self._num_pixels
104
105
106
def create_mask(dispersion, censored_regions):
    """
    Return a boolean censoring mask based on a structured list of (start, end)
    regions.

    :param dispersion:
        An array (or array-like) of dispersion values.

    :param censored_regions:
        A list of two-length tuples containing the `(start, end)` points of a
        censored region. A single `(start, end)` pair is also accepted. A
        bound of `None` leaves that side of the region open; both bounds are
        inclusive. An empty list censors nothing.

    :returns:
        A boolean mask with one entry per dispersion value, `True` where the
        pixel falls inside any censored region.
    """

    dispersion = np.asarray(dispersion).ravel()
    mask = np.zeros(dispersion.size, dtype=bool)

    if len(censored_regions) == 0:
        return mask

    # A single (start, end) pair has a scalar (or None) as its first
    # element rather than a nested pair; np.ndim also covers numpy scalars.
    if np.ndim(censored_regions[0]) == 0:
        censored_regions = [censored_regions]

    for start, end in censored_regions:
        # Only None means "unbounded"; a bound of 0 is a real value.
        lower = -np.inf if start is None else start
        upper = +np.inf if end is None else end
        mask |= (dispersion >= lower) & (dispersion <= upper)

    return mask
135
136
137
def design_matrix_mask(censors, vectorizer):
    """
    Return a mask of which indices in the design matrix columns should be
    used for a given pixel.

    :param censors:
        A censoring dictionary.

    :param vectorizer:
        The model vectorizer.

    :returns:
        A boolean array of shape ``(censors.num_pixels, 1 + len(vectorizer.terms))``.
        A `False` entry means that design matrix column must not be used at
        that pixel, because its term involves a label censored there.

    :raises TypeError:
        If `censors` is not a `Censors` instance or `vectorizer` is not a
        `BaseVectorizer` instance.
    """

    if not isinstance(censors, Censors):
        raise TypeError("censors must be a Censors class")

    if not isinstance(vectorizer, BaseVectorizer):
        raise TypeError("vectorizer must be a sub-class of BaseVectorizer")

    terms = list(vectorizer.terms)

    # Parse all the terms once-off: map each censored label name to the
    # design matrix columns whose term involves that label. Labels without
    # a censoring entry are irrelevant and skipped.
    # NOTE(review): label_index is looked up in censors.label_names, which
    # assumes the censors and the vectorizer share the same label ordering.
    censored_columns = {}
    for i, term in enumerate(terms):
        for label_index, _ in term:
            label_name = censors.label_names[label_index]
            if label_name in censors:
                # +1 because the first design matrix column is the pivot point.
                censored_columns.setdefault(label_name, []).append(1 + i)

    # One pivot column plus one column per term (correct even with no terms).
    mask = np.ones((censors.num_pixels, 1 + len(terms)), dtype=bool)
    for label_name, pixel_mask in censors.items():
        columns = censored_columns.get(label_name)
        if not columns:
            # The censored label does not appear in any term.
            continue
        # Switch off every (censored pixel, affected column) pair at once.
        mask[np.ix_(np.flatnonzero(pixel_mask), columns)] = False

    return mask
187