Completed
Pull Request — develop (#33)
by
unknown
35s
created

Rayleigh._do_interp()   A

Complexity

Conditions 1

Size

Total Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 5
Bugs 0 Features 0
Metric Value
cc 1
c 5
b 0
f 0
dl 0
loc 6
rs 9.4285
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
# Copyright (c) 2016-2018 Pytroll
5
6
# Author(s):
7
8
#   Adam.Dybbroe <[email protected]>
9
#   Martin Raspaud <[email protected]>
10
11
# This program is free software: you can redistribute it and/or modify
12
# it under the terms of the GNU General Public License as published by
13
# the Free Software Foundation, either version 3 of the License, or
14
# (at your option) any later version.
15
16
# This program is distributed in the hope that it will be useful,
17
# but WITHOUT ANY WARRANTY; without even the implied warranty of
18
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
19
# GNU General Public License for more details.
20
21
# You should have received a copy of the GNU General Public License
22
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
23
24
"""Atmospheric correction of shortwave imager bands in the wavelength range 400
25
to 800 nm
26
27
"""
28
29
import os
30
import time
31
import logging
32
from six import integer_types
33
34
import h5py
35
import numpy as np
36
37
try:
38
    from dask.array import where, zeros, map_blocks, from_array, clip, Array
39
    HAVE_DASK = True
40
except ImportError:
41
    from numpy import where, zeros, clip
42
    map_blocks = None
43
    from_array = None
44
    Array = None
45
    HAVE_DASK = False
46
47
from geotiepoints.multilinear import MultilinearInterpolator
48
from pyspectral.rsr_reader import RelativeSpectralResponse
49
from pyspectral.utils import (INSTRUMENTS, RAYLEIGH_LUT_DIRS,
50
                              AEROSOL_TYPES, ATMOSPHERES,
51
                              BANDNAMES,
52
                              download_luts, get_central_wave,
53
                              get_bandname_from_wavelength)
54
55
# Module-level logger, named after this module via ``__name__``.
LOG = logging.getLogger(__name__)
56
57
58
class BandFrequencyOutOfRange(ValueError):
    """Raised when a requested band wavelength is outside the visible range.

    Subclasses ``ValueError`` so callers catching ``ValueError`` still
    handle it.
    """
63
64
65
class Rayleigh(object):

    """Container for the atmospheric correction of satellite imager bands.

    This class removes background contributions of Rayleigh scattering of
    molecules and Mie scattering and absorption by aerosols.

    """

    def __init__(self, platform_name, sensor, **kwargs):
        """Initialize class and determine LUT to use.

        Args:
            platform_name: Name of the satellite platform. Used to look up
                the expected instrument name in ``INSTRUMENTS``.
            sensor: Name of the sensor. Replaced by the ``INSTRUMENTS`` entry
                for the platform when the two disagree.
            **kwargs: Optional settings:
                ``atmosphere`` (default ``'us-standard'``), must be a key of
                ``ATMOSPHERES``;
                ``aerosol_type`` (default ``'marine_clean_aerosol'``), must
                be in ``AEROSOL_TYPES``.

        Raises:
            AttributeError: If the atmosphere or aerosol type is unsupported.
            IOError: If the LUT file is still missing after a download
                attempt.
        """
        self.platform_name = platform_name
        self.sensor = sensor
        self.coeff_filename = None

        atm_type = kwargs.get('atmosphere', 'us-standard')
        if atm_type not in ATMOSPHERES:
            raise AttributeError('Atmosphere type not supported! ' +
                                 'Need to be one of {}'.format(str(ATMOSPHERES)))

        aerosol_type = kwargs.get('aerosol_type', 'marine_clean_aerosol')
        if aerosol_type not in AEROSOL_TYPES:
            raise AttributeError('Aerosol type not supported! ' +
                                 'Need to be one of {0}'.format(str(AEROSOL_TYPES)))

        rayleigh_dir = RAYLEIGH_LUT_DIRS[aerosol_type]

        LOG.info("Atmosphere chosen: %s", atm_type)

        # Try fix instrument naming
        instr = INSTRUMENTS.get(platform_name, sensor)
        if instr != sensor:
            sensor = instr
            LOG.warning("Inconsistent sensor/satellite input - " +
                        "sensor set to %s", sensor)

        self.sensor = sensor.replace('/', '')

        # The LUT file name is derived from the atmosphere type, with
        # blanks replaced by underscores.
        ext = atm_type.replace(' ', '_')
        lutname = "rayleigh_lut_{0}.h5".format(ext)
        self.reflectance_lut_filename = os.path.join(rayleigh_dir, lutname)
        if not os.path.exists(self.reflectance_lut_filename):
            LOG.warning(
                "No lut file %s on disk", self.reflectance_lut_filename)
            LOG.info("Will download from internet...")
            download_luts(aerosol_type=aerosol_type)

        # isfile() is False for a missing path too, so one check covers both
        # "does not exist" and "exists but is not a regular file".
        if not os.path.isfile(self.reflectance_lut_filename):
            raise IOError('pyspectral file for Rayleigh scattering correction ' +
                          'does not exist! Filename = ' +
                          str(self.reflectance_lut_filename))

        LOG.debug('LUT filename: %s', str(self.reflectance_lut_filename))

        # LUT contents are loaded lazily by get_reflectance_lut().
        self._rayl = None
        self._wvl_coord = None
        self._azid_coord = None
        self._satz_sec_coord = None
        self._sunz_sec_coord = None

    def get_effective_wavelength(self, bandname):
        """Get the effective wavelength with Rayleigh scattering in mind.

        Args:
            bandname: Either a band name (string) or a wavelength given as a
                float/integer. A numeric value must lie strictly between 0.4
                and 0.8 (microns, per the error message below).

        Returns:
            The central wavelength of the band's ``det-1`` response weighted
            by ``1 / wavelength**4``. If no spectral responses can be read
            and ``bandname`` is numeric, ``bandname`` is returned unchanged.

        Raises:
            IOError: If no spectral responses are available and ``bandname``
                is not numeric.
            BandFrequencyOutOfRange: If a numeric ``bandname`` is outside the
                0.4-0.8 interval.
        """
        try:
            rsr = RelativeSpectralResponse(self.platform_name, self.sensor)
        except IOError:
            LOG.exception(
                "No spectral responses for this platform and sensor: %s %s", self.platform_name, self.sensor)
            if isinstance(bandname, (float, integer_types)):
                LOG.warning(
                    "Effective wavelength is set to the requested band wavelength = %f", bandname)
                return bandname
            raise

        if isinstance(bandname, str):
            # Map to the sensor specific band name, falling back to the
            # generic mapping and finally to the given name itself.
            bandname = BANDNAMES.get(self.sensor, BANDNAMES['generic']).get(bandname, bandname)
        elif isinstance(bandname, (float, integer_types)):
            if not(0.4 < bandname < 0.8):
                raise BandFrequencyOutOfRange(
                    'Requested band frequency should be between 0.4 and 0.8 microns!')
            bandname = get_bandname_from_wavelength(self.sensor, bandname, rsr.rsr)

        detector = rsr.rsr[bandname]['det-1']
        wvl, resp = detector['wavelength'], detector['response']

        cwvl = get_central_wave(wvl, resp, weight=1. / wvl**4)
        LOG.debug("Band name: %s  Effective wavelength: %f", bandname, cwvl)

        return cwvl

    def get_reflectance_lut(self):
        """Read the LUT with reflectances as a function of wavelength, satellite zenith
        secant, azimuth difference angle, and sun zenith secant.

        The file is read on the first call only; later calls return the
        values cached on the instance.

        Returns:
            Tuple ``(reflectance, wavelengths, azimuth_difference,
            satellite_zenith_secant, sun_zenith_secant)``.
        """
        if self._rayl is None:
            lut_vars = get_reflectance_lut(self.reflectance_lut_filename)
            self._rayl = lut_vars[0]
            self._wvl_coord = lut_vars[1]
            self._azid_coord = lut_vars[2]
            self._satz_sec_coord = lut_vars[3]
            self._sunz_sec_coord = lut_vars[4]
        return (self._rayl, self._wvl_coord, self._azid_coord,
                self._satz_sec_coord, self._sunz_sec_coord)

    def get_reflectance(self, sun_zenith, sat_zenith, azidiff, bandname,
                        redband=None):
        """Get the reflectance from the three sun-sat angles.

        Args:
            sun_zenith: Array of sun zenith angles in degrees.
            sat_zenith: Array of satellite zenith angles in degrees.
            azidiff: Array of azimuth difference angles; the LUT is
                interpolated at ``180 - azidiff``.
            bandname: Band name or wavelength, passed on to
                ``get_effective_wavelength``.
            redband: Optional array. Where it is 20 or above the result is
                scaled by ``1 - (redband - 20) / 80``.

        Returns:
            Array with the shape of ``sun_zenith``, multiplied by 100 and
            clipped to [0, 100] (presumably percent - confirm against the
            LUT). All zeros when the effective wavelength falls outside the
            LUT wavelength range. When dask is available and the input was
            not a dask array, the result is computed before it is returned.
        """
        # Get wavelength in nm for band:
        wvl = self.get_effective_wavelength(bandname) * 1000.0
        rayl, wvl_coord, azid_coord, satz_sec_coord, sunz_sec_coord = \
            self.get_reflectance_lut()

        # force dask arrays
        compute = False
        if HAVE_DASK and not isinstance(sun_zenith, Array):
            compute = True
            sun_zenith = from_array(sun_zenith, chunks=sun_zenith.shape)
            sat_zenith = from_array(sat_zenith, chunks=sat_zenith.shape)
            azidiff = from_array(azidiff, chunks=azidiff.shape)
            if redband is not None:
                redband = from_array(redband, chunks=redband.shape)

        # Limit the zenith angles to the largest secant covered by the LUT.
        clip_angle = np.rad2deg(np.arccos(1. / sunz_sec_coord.max()))
        sun_zenith = clip(sun_zenith, 0, clip_angle)
        sunzsec = 1. / np.cos(np.deg2rad(sun_zenith))
        clip_angle = np.rad2deg(np.arccos(1. / satz_sec_coord.max()))
        sat_zenith = clip(sat_zenith, 0, clip_angle)
        satzsec = 1. / np.cos(np.deg2rad(sat_zenith))

        shape = sun_zenith.shape

        if not(wvl_coord.min() < wvl < wvl_coord.max()):
            LOG.warning(
                "Effective wavelength for band %s outside 400-800 nm range!",
                str(bandname))
            LOG.info(
                "Set the rayleigh/aerosol reflectance contribution to zero!")
            if HAVE_DASK:
                chunks = sun_zenith.chunks if redband is None \
                    else redband.chunks
                res = zeros(shape, chunks=chunks)
                return res.compute() if compute else res
            else:
                return zeros(shape)

        # Linear interpolation in wavelength between the two neighbouring
        # LUT entries.
        idx = np.searchsorted(wvl_coord, wvl)
        wvl1 = wvl_coord[idx - 1]
        wvl2 = wvl_coord[idx]

        fac = (wvl2 - wvl) / (wvl2 - wvl1)
        raylwvl = fac * rayl[idx - 1, :, :, :] + (1 - fac) * rayl[idx, :, :, :]
        tic = time.time()

        smin = [sunz_sec_coord[0], azid_coord[0], satz_sec_coord[0]]
        smax = [sunz_sec_coord[-1], azid_coord[-1], satz_sec_coord[-1]]
        orders = [
            len(sunz_sec_coord), len(azid_coord), len(satz_sec_coord)]
        minterp = MultilinearInterpolator(smin, smax, orders)

        f_3d_grid = raylwvl
        minterp.set_values(np.atleast_2d(f_3d_grid.ravel()))

        def _do_interp(minterp, sunzsec, azidiff, satzsec):
            """Interpolate one block of angles and restore its shape."""
            interp_points2 = np.vstack((sunzsec.ravel(),
                                        180 - azidiff.ravel(),
                                        satzsec.ravel()))
            res = minterp(interp_points2)
            return res.reshape(sunzsec.shape)

        if HAVE_DASK:
            ipn = map_blocks(_do_interp, minterp, sunzsec, azidiff,
                             satzsec, dtype=raylwvl.dtype,
                             chunks=azidiff.chunks)
        else:
            ipn = _do_interp(minterp, sunzsec, azidiff, satzsec)

        # Lazy logger arguments: the message is only built when debug
        # logging is enabled.
        LOG.debug("Time - Interpolation: %f", time.time() - tic)

        ipn *= 100
        res = ipn
        if redband is not None:
            res = where(redband < 20., res,
                        (1 - (redband - 20) / 80) * res)

        res = clip(res, 0, 100)
        if compute:
            res = res.compute()
        return res
257
258
259
def get_reflectance_lut(filename):
    """Read the LUT with reflectances as a function of wavelength, satellite
    zenith secant, azimuth difference angle, and sun zenith secant.

    Args:
        filename: Path of the HDF5 LUT file. It must contain the datasets
            ``reflectance``, ``wavelengths``, ``azimuth_difference``,
            ``satellite_zenith_secant`` and ``sun_zenith_secant``.

    Returns:
        Tuple ``(reflectance, wavelengths, azimuth_difference,
        satellite_zenith_secant, sun_zenith_secant)``. With dask available
        the entries are dask arrays, otherwise they are in-memory copies of
        the datasets.
    """
    h5f = h5py.File(filename, 'r')
    keep_open = False
    try:
        tab = h5f['reflectance']
        wvl = h5f['wavelengths']
        azidiff = h5f['azimuth_difference']
        satellite_zenith_secant = h5f['satellite_zenith_secant']
        sun_zenith_secant = h5f['sun_zenith_secant']

        if HAVE_DASK:
            tab = from_array(tab, chunks=(10, 10, 10, 10))
            # wvl_coord is used in a lot of non-dask functions, keep in memory
            wvl = from_array(wvl, chunks=(100,)).persist()
            azidiff = from_array(azidiff, chunks=(1000,))
            satellite_zenith_secant = from_array(satellite_zenith_secant,
                                                 chunks=(1000,))
            sun_zenith_secant = from_array(sun_zenith_secant,
                                           chunks=(1000,))
            # The dask arrays wrap the h5py datasets and read from them on
            # demand, so the file handle has to stay open.
            keep_open = True
        else:
            # load all of the data we are going to use in to memory
            tab = tab[:]
            wvl = wvl[:]
            azidiff = azidiff[:]
            satellite_zenith_secant = satellite_zenith_secant[:]
            sun_zenith_secant = sun_zenith_secant[:]
    finally:
        # Everything needed has been copied (or reading failed): do not
        # leak the file handle.
        if not keep_open:
            h5f.close()

    return tab, wvl, azidiff, satellite_zenith_secant, sun_zenith_secant
291
292
293
# if __name__ == "__main__":
294
#     SHAPE = (1000, 3000)
295
#     NDIM = SHAPE[0] * SHAPE[1]
296
#     SUNZ = np.ma.arange(
297
#         NDIM / 2, NDIM + NDIM / 2).reshape(SHAPE) * 60. / float(NDIM)
298
#     SATZ = np.ma.arange(NDIM).reshape(SHAPE) * 60. / float(NDIM)
299
#     AZIDIFF = np.ma.arange(NDIM).reshape(SHAPE) * 179.9 / float(NDIM)
300
#     rfl = this.get_reflectance(SUNZ, SATZ, AZIDIFF, 'M4')
301