Test Failed
Pull Request — master (#708)
by Daniil, created 03:14

fluo_fitters.base_fluo_fitter.lorentzian()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 4
dl 0
loc 3
rs 10
c 0
b 0
f 0
# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: base_fluo_fitter
   :platform: Unix
   :synopsis: a base fitting plugin

.. moduleauthor:: Aaron Parsons <[email protected]>

"""

import logging
from savu.plugins.plugin import Plugin
from savu.plugins.driver.cpu_plugin import CpuPlugin
import peakutils as pe
import numpy as np
import xraylib as xl
from flupy.algorithms.xrf_calculations.transitions_and_shells import \
    shells, transitions
from flupy.algorithms.xrf_calculations.escape import *
from flupy.xrf_data_handling import XRFDataset
from copy import deepcopy

class BaseFluoFitter(Plugin, CpuPlugin):
    """
    This plugin fits peaks in either XRD or XRF data.
    :param in_datasets: A list of the input dataset(s). Default: [].
    :param out_datasets: A list of the output dataset(s). Default: ["FitWeights", "FitWidths", "FitAreas", "residuals"].
    :param width_guess: An initial guess at the peak width. Default: 0.02.
    :param mono_energy: The monochromator energy in keV. Default: 18.0.
    :param peak_shape: The peak shape to fit. Default: "gaussian".
    :param pileup_cutoff_keV: The pile-up cut-off energy in keV. Default: 5.5.
    :param include_pileup: Include pile-up peaks (1) or not (0). Default: 1.
    :param include_escape: Include escape peaks (1) or not (0). Default: 1.
    :param fitted_energy_range_keV: The fitted energy range in keV. Default: [2.,18.].
    :param elements: The elements to fit. Default: ['Zn','Cu', 'Ar'].
    """

    def __init__(self, name="BaseFluoFitter"):
        super(BaseFluoFitter, self).__init__(name)

    def base_pre_process(self):
        in_meta_data = self.get_in_meta_data()[0]
        try:
            _foo = in_meta_data.get("PeakIndex")[0]
            logging.debug('Using the positions in the peak index')
        except KeyError:
            logging.debug("No Peak Index in the metadata")
            logging.debug("Calculating the positions from energy")
#             idx = self.setPositions(in_meta_data)
            logging.debug("The index is"+str(self.idx))
            in_meta_data.set('PeakIndex', self.idx)
            in_meta_data.set('PeakEnergy', self.axis[self.idx])

    def setup(self):
        # set up the output datasets that are created by the plugin
        logging.debug('setting up the fluorescence fitting')
        in_dataset, out_datasets = self.get_datasets()
        in_pData, out_pData = self.get_plugin_datasets()
        in_meta_data = in_dataset[0].meta_data

        shape = in_dataset[0].get_shape()
        in_pData[0].plugin_data_setup('SPECTRUM', self.get_max_frames())

        axis_labels = ['-1.PeakIndex.pixel.unit']
        pattern_list = ['SINOGRAM', 'PROJECTION']

        fitAreas = out_datasets[0]
        fitWidths = out_datasets[1]
        fitHeights = out_datasets[2]
        self.length = shape[-1]
        idx = self.setPositions(in_meta_data)
        logging.debug("in the setup the index is"+str(idx))
        numpeaks = len(idx)
        new_shape = shape[:-1] + (numpeaks,)

        channel = {'core_dims': (-1,), 'slice_dims': list(range(len(shape)-1))}
        fitAreas.create_dataset(patterns={in_dataset[0]: pattern_list},
                                axis_labels={in_dataset[0]: axis_labels},
                                shape=new_shape)
        fitAreas.add_pattern("CHANNEL", **channel)
        out_pData[0].plugin_data_setup('CHANNEL', self.get_max_frames())

        fitWidths.create_dataset(patterns={in_dataset[0]: pattern_list},
                                 axis_labels={in_dataset[0]: axis_labels},
                                 shape=new_shape)
        fitWidths.add_pattern("CHANNEL", **channel)
        out_pData[1].plugin_data_setup('CHANNEL', self.get_max_frames())

        fitHeights.create_dataset(patterns={in_dataset[0]: pattern_list},
                                  axis_labels={in_dataset[0]: axis_labels},
                                  shape=new_shape)
        fitHeights.add_pattern("CHANNEL", **channel)
        out_pData[2].plugin_data_setup('CHANNEL', self.get_max_frames())

        residuals = out_datasets[3]
        residuals.create_dataset(in_dataset[0])
        residuals.set_shape(shape[:-1]+(len(self.axis),))
        out_pData[3].plugin_data_setup('SPECTRUM', self.get_max_frames())

        for i in range(len(out_datasets)):
            out_datasets[i].meta_data = deepcopy(in_meta_data)
            mData = out_datasets[i].meta_data
            mData.set("PeakEnergy", self.axis[self.idx])
            mData.set('PeakIndex', self.idx)

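The setup() method above replaces the spectrum axis of the input with a per-peak axis on the first three outputs. A minimal standalone sketch of that shape arithmetic (the numbers are illustrative, not taken from a real scan):

    # illustrative only: a (y, x, energy) scan with 4096 energy channels
    shape = (151, 121, 4096)
    numpeaks = 7                                  # e.g. len(idx) returned by setPositions
    new_shape = shape[:-1] + (numpeaks,)          # shape used for the per-peak outputs
    assert new_shape == (151, 121, 7)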
    def setPositions(self, in_meta_data):
        paramdict = XRFDataset().paramdict
        paramdict["FitParams"]["pileup_cutoff_keV"] = \
            self.parameters["pileup_cutoff_keV"]
        paramdict["FitParams"]["include_pileup"] = \
            self.parameters["include_pileup"]
        paramdict["FitParams"]["include_escape"] = \
            self.parameters["include_escape"]
        paramdict["FitParams"]["fitted_energy_range_keV"] = \
            self.parameters["fitted_energy_range_keV"]
        if self.parameters['mono_energy'] is None:
            paramdict["Experiment"]["incident_energy_keV"] = \
                in_meta_data.get("mono_energy")
        else:
            paramdict["Experiment"]["incident_energy_keV"] = \
                self.parameters['mono_energy']
        paramdict["Experiment"]["elements"] = \
            self.parameters["elements"]
        engy = self.findLines(paramdict)
        # make it an index since this is what find peak will also give us
#         print 'basefluo meta is:'+str(in_meta_data.get_dictionary().keys())
        axis = self.axis = in_meta_data.get("energy")
        dq = axis[1]-axis[0]
        logging.debug("the peak energies are:"+str(engy))
        logging.debug("the offset is"+str(axis[0]))
        self.idx = np.round((engy-axis[0])/dq).astype(int)

        return self.idx

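setPositions() maps each calculated line energy onto the nearest channel of the dataset's energy axis. A minimal numeric sketch of that conversion, assuming a uniformly spaced axis as the code does (the axis and line energies below are made up):

    import numpy as np

    axis = np.arange(0.0, 20.48, 0.01)            # illustrative energy axis, keV
    dq = axis[1] - axis[0]                        # channel width
    engy = np.array([8.04, 8.63])                 # made-up mean line energies, keV
    idx = np.round((engy - axis[0]) / dq).astype(int)
    # idx -> array([804, 863]), the nearest channel for each line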
    def findLines(self, paramdict=XRFDataset().paramdict):
        """
        Calculates the line energies to fit
        """
        # Incident Energy  used in the experiment
        # Energy range to use for fitting
        pileup_cut_off = paramdict["FitParams"]["pileup_cutoff_keV"]
        include_pileup = paramdict["FitParams"]["include_pileup"]
        include_escape = paramdict["FitParams"]["include_escape"]
        fitting_range = paramdict["FitParams"]["fitted_energy_range_keV"]
#         x = paramdict["FitParams"]["mca_energies_used"]
        energy = paramdict["Experiment"]["incident_energy_keV"]
        detectortype = 'Vortex_SDD_Xspress'
        fitelements = paramdict["Experiment"]["elements"]
        peakpos = []
        escape_peaks = []
        for _j, el in enumerate(fitelements):
            z = xl.SymbolToAtomicNumber(str(el))
            for i, shell in enumerate(shells):
                if(xl.EdgeEnergy(z, shell) < energy - 0.5):
                    linepos = 0.0
                    count = 0.0
                    for line in transitions[i]:
                        en = xl.LineEnergy(z, line)
                        if(en > 0.0):
                            linepos += en
                            count += 1.0
                    if(count == 0.0):
                        break
                    # mean line energy: true division (floor division would
                    # truncate the keV value)
                    linepos = linepos / count
                    if(linepos > fitting_range[0] and
                            linepos < fitting_range[1]):
                        peakpos.append(linepos)
        peakpos = np.array(peakpos)
        too_low = set(list(peakpos[peakpos > fitting_range[0]]))   # above the lower bound
        too_high = set(list(peakpos[peakpos < fitting_range[1]]))  # below the upper bound
        bar = list(too_low & too_high)  # intersection: peaks inside the fitted range
        bar = np.unique(bar)
        peakpos = list(bar)
        peaks = []
        peaks.extend(peakpos)
        if(include_escape):
            for i in range(len(peakpos)):
                escape_energy = calc_escape_energy(peakpos[i], detectortype)[0]

Issue (Comprehensibility Best Practice): the variable calc_escape_energy does not seem to be defined; it is presumably provided by the wildcard import from flupy.algorithms.xrf_calculations.escape.

                if (escape_energy > fitting_range[0]):
                    if (escape_energy < fitting_range[1]):
                        escape_peaks.extend([escape_energy])
    #         print escape_peaks
            peaks.extend(escape_peaks)

        if(include_pileup):  # applies just to the fluorescence lines
            pileup_peaks = []
            peakpos1 = np.array(peakpos)
            peakpos_high = peakpos1[peakpos1 > pileup_cut_off]
            peakpos_high = list(peakpos_high)
            for i in range(len(peakpos_high)):
                foo = [peakpos_high[i] + x for x in peakpos_high[i:]]
                foo = np.array(foo)
                pileup_peaks.extend(foo)
            pileup_peaks = np.unique(sorted(pileup_peaks))
            peaks.extend(pileup_peaks)
        peakpos = peaks
        peakpos = np.array(peakpos)
        too_low = set(list(peakpos[peakpos > fitting_range[0]]))
        too_high = set(list(peakpos[peakpos < fitting_range[1] - 0.5]))
        bar = list(too_low & too_high)  # intersection again, with a 0.5 keV margin at the top
        bar = np.unique(bar)
        peakpos = list(bar)
        peakpos = np.unique(peakpos)
#         print peakpos
        return peakpos

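The pile-up branch in findLines() builds sum peaks from every pairwise combination (including a line with itself) of the lines above pileup_cutoff_keV. A minimal sketch of that combination step with made-up energies:

    import numpy as np

    peakpos_high = [6.40, 7.06, 8.04]             # made-up line energies above the cut-off, keV
    pileup_peaks = []
    for i in range(len(peakpos_high)):
        # each line summed with itself and with every later line
        pileup_peaks.extend(peakpos_high[i] + x for x in peakpos_high[i:])
    pileup_peaks = np.unique(sorted(pileup_peaks))
    # approximately [12.8, 13.46, 14.12, 14.44, 15.1, 16.08]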
Issue (Duplication): this code seems to be duplicated in your project.

    def getAreas(self, fun, x, positions, fitmatrix):
        rest = fitmatrix
        numargsinp = self.getFitFunctionNumArgs(str(fun.__name__))  # 2 in
        npts = len(fitmatrix) // numargsinp
        weights = rest[:npts]
        widths = rest[npts:2*npts]
        areas = []
        for ii in range(len(weights)):
            areas.append(np.sum(fun(weights[ii],
                                    widths[ii],
                                    x,
                                    positions[ii],
                                    )))
        return weights, widths, np.array(areas)

    def getFitFunctionNumArgs(self, key):
        self.lookup = {
                       "lorentzian": 2,
                       "gaussian": 2
                       }
        return self.lookup[key]

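getAreas() assumes the fit vector packs all peak weights first and then all peak widths, which is why it slices at npts. A minimal sketch of that unpacking (the vector below is made up; three peaks, two parameters each):

    import numpy as np

    fitmatrix = np.array([10., 20., 30.,          # weights, one per peak
                          0.02, 0.03, 0.04])      # widths, one per peak
    npts = len(fitmatrix) // 2                    # 2 parameters per peak
    weights = fitmatrix[:npts]
    widths = fitmatrix[npts:2 * npts]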
    def get_max_frames(self):
        return 'single'

    def nOutput_datasets(self):
        return 4

    def getFitFunction(self, key):
        self.lookup = {
                       "lorentzian": lorentzian,
                       "gaussian": gaussian
                       }
        return self.lookup[key]

    def _resid(self, p, fun, y, x, pos):
        r = y-self._spectrum_sum(fun, x, pos, *p)

        return r

Issue (Duplication): this code seems to be duplicated in your project.

    def dfunc(self, p, fun, y, x, pos):
        if fun.__name__ == 'gaussian' or fun.__name__ == 'lorentzian': # took the lorentzian out. Weird
            rest = p
            npts = len(p) // 2
            a = rest[:npts]
            sig = rest[npts:2*npts]
            mu = pos
            if fun.__name__ == 'gaussian':
                da = self.spectrum_sum_dfun(fun, 1./a, x, mu, *p)
                dsig_mult = np.zeros((npts, len(x)))
                for i in range(npts):
                    dsig_mult[i] = ((x-mu[i])**2) / sig[i]**3
                dsig = self.spectrum_sum_dfun(fun, dsig_mult, x, mu, *p)
                op = np.concatenate([-da, -dsig])
            elif fun.__name__ == 'lorentzian':
                da = self.spectrum_sum_dfun(fun, 1./a, x, mu, *p)
                dsig = np.zeros((npts, len(x)))
                for i in range(npts):
                    nom = 8 * a[i] * sig[i] * (x - mu[i]) ** 2
                    denom = (sig[i]**2 + 4.0 * (x - mu[i])**2)**2
                    dsig[i] = nom / denom
                op = np.concatenate([-da, -dsig])
        else:
            op = None
        return op

Issue: the variable op does not seem to be defined for all execution paths.

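The width derivative used in the lorentzian branch of dfunc() can be checked symbolically. A minimal sketch (sympy is assumed to be available; it is not a dependency of this plugin):

    import sympy as sp

    a, sig, x, mu = sp.symbols('a sig x mu', positive=True)
    lor = a / (1 + (2 * (mu - x) / sig) ** 2)     # same form as lorentzian() below
    dlor_dsig = sp.simplify(sp.diff(lor, sig))
    # equivalent to 8*a*sig*(x - mu)**2 / (sig**2 + 4*(x - mu)**2)**2,
    # i.e. the nom/denom expression in dfunc()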
    def _spectrum_sum(self, fun, x, positions, *p):
        rest = np.abs(p)
        npts = len(p) // 2
        weights = rest[:npts]
        widths = rest[npts:2*npts]
        spec = np.zeros((len(x),))
        for ii in range(len(weights)):
            spec += fun(weights[ii], widths[ii], x, positions[ii])
        return spec

Issue (Duplication): this code seems to be duplicated in your project.

    def spectrum_sum_dfun(self, fun, multiplier, x, pos, *p):
        rest = p
        npts = len(p) // 2
        weights = rest[:npts]
        widths = rest[npts:2*npts]
        positions = pos
    #    print(len(positions))
        spec = np.zeros((npts, len(x)))
        #print "len x is "+str(len(spec))
    #    print len(spec), type(spec)
    #    print len(positions), type(positions)
    #    print len(weights), type(weights)
        for ii in range(len(weights)):
            spec[ii] = multiplier[ii]*fun(weights[ii],
                                          widths[ii],
                                          x, positions[ii])
        return spec

def lorentzian(a, w, x, c):
    y = a / (1.0 + (2.0 * (c - x) / w) ** 2)
    return y


def gaussian(a, w, x, c):
    return pe.gaussian(x, a, c, w)
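A minimal standalone sketch of how the module-level peak functions combine into a model spectrum, mirroring what _spectrum_sum() does (all values are made up):

    import numpy as np

    x = np.linspace(2.0, 18.0, 1601)              # energy axis, keV
    positions = [8.04, 8.63]                      # made-up peak centres, keV
    weights = [100.0, 40.0]                       # peak heights
    widths = [0.15, 0.15]                         # peak widths

    spec = np.zeros_like(x)
    for a, w, c in zip(weights, widths, positions):
        spec += a / (1.0 + (2.0 * (c - x) / w) ** 2)   # same expression as lorentzian()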