Completed
Pull Request — develop (#33)
by
unknown
01:19
created

Rayleigh._do_interp()   A

Complexity

Conditions 1

Size

Total Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 5
Bugs 0 Features 0
Metric Value
cc 1
c 5
b 0
f 0
dl 0
loc 6
rs 9.4285
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
# Copyright (c) 2016-2018 Pytroll
5
6
# Author(s):
7
8
#   Adam.Dybbroe <[email protected]>
9
#   Martin Raspaud <[email protected]>
10
11
# This program is free software: you can redistribute it and/or modify
12
# it under the terms of the GNU General Public License as published by
13
# the Free Software Foundation, either version 3 of the License, or
14
# (at your option) any later version.
15
16
# This program is distributed in the hope that it will be useful,
17
# but WITHOUT ANY WARRANTY; without even the implied warranty of
18
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
19
# GNU General Public License for more details.
20
21
# You should have received a copy of the GNU General Public License
22
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
23
24
"""Atmospheric correction of shortwave imager bands in the wavelength range 400
25
to 800 nm
26
27
"""
28
29
import os
30
import time
31
import logging
32
from six import integer_types
33
34
import h5py
35
import numpy as np
36
37
try:
38
    from dask.array import (where, zeros, clip, rad2deg, deg2rad, cos, arccos,
39
                            atleast_2d, Array, map_blocks, from_array)
40
    HAVE_DASK = True
41
    # use serializable h5py wrapper to make sure files are closed properly
42
    import h5pickle as h5py
43
except ImportError:
44
    from numpy import where, zeros, clip, rad2deg, deg2rad, cos, arccos, atleast_2d
45
    map_blocks = None
46
    from_array = None
47
    Array = None
48
    HAVE_DASK = False
49
50
from geotiepoints.multilinear import MultilinearInterpolator
51
from pyspectral.rsr_reader import RelativeSpectralResponse
52
from pyspectral.utils import (INSTRUMENTS, RAYLEIGH_LUT_DIRS,
53
                              AEROSOL_TYPES, ATMOSPHERES,
54
                              BANDNAMES,
55
                              download_luts, get_central_wave,
56
                              get_bandname_from_wavelength)
57
58
# Module-level logger, named after this module, used by the classes and
# functions in this file.
LOG = logging.getLogger(__name__)
59
60
61
class BandFrequencyOutOfRange(ValueError):
    """Exception when the band frequency is out of the visible range.

    Subclasses ``ValueError`` so callers that already catch ``ValueError``
    keep working.
    """
66
67
68
class Rayleigh(object):
    """Container for the atmospheric correction of satellite imager bands.

    This class removes background contributions of Rayleigh scattering of
    molecules and Mie scattering and absorption by aerosols.

    """

    def __init__(self, platform_name, sensor, **kwargs):
        """Initialize class and determine LUT to use.

        Args:
            platform_name: Name of the satellite platform.
            sensor: Name of the instrument. It is replaced by the entry in
                ``INSTRUMENTS`` for ``platform_name`` when that entry
                differs, and any ``/`` is stripped from the final name.
            **kwargs: Optional settings:
                ``atmosphere`` (default ``'us-standard'``), must be in
                ``ATMOSPHERES``;
                ``aerosol_type`` (default ``'marine_clean_aerosol'``), must
                be in ``AEROSOL_TYPES``.

        Raises:
            AttributeError: If the atmosphere or aerosol type is unsupported.
            IOError: If the LUT file is still missing after the download
                attempt.
        """
        self.platform_name = platform_name
        self.sensor = sensor
        self.coeff_filename = None

        atm_type = kwargs.get('atmosphere', 'us-standard')
        if atm_type not in ATMOSPHERES:
            raise AttributeError('Atmosphere type not supported! ' +
                                 'Need to be one of {}'.format(str(ATMOSPHERES)))

        aerosol_type = kwargs.get('aerosol_type', 'marine_clean_aerosol')
        if aerosol_type not in AEROSOL_TYPES:
            raise AttributeError('Aerosol type not supported! ' +
                                 'Need to be one of {0}'.format(str(AEROSOL_TYPES)))

        rayleigh_dir = RAYLEIGH_LUT_DIRS[aerosol_type]

        LOG.info("Atmosphere chosen: %s", atm_type)

        # Try fix instrument naming
        instr = INSTRUMENTS.get(platform_name, sensor)
        if instr != sensor:
            sensor = instr
            LOG.warning("Inconsistent sensor/satellite input - " +
                        "sensor set to %s", sensor)

        self.sensor = sensor.replace('/', '')

        ext = atm_type.replace(' ', '_')
        lutname = "rayleigh_lut_{0}.h5".format(ext)
        self.reflectance_lut_filename = os.path.join(rayleigh_dir, lutname)
        if not os.path.exists(self.reflectance_lut_filename):
            LOG.warning(
                "No lut file %s on disk", self.reflectance_lut_filename)
            LOG.info("Will download from internet...")
            download_luts(aerosol_type=aerosol_type)

        # isfile() is False for a missing path as well, so a separate
        # exists() test is not needed here.
        if not os.path.isfile(self.reflectance_lut_filename):
            raise IOError('pyspectral file for Rayleigh scattering correction ' +
                          'does not exist! Filename = ' +
                          str(self.reflectance_lut_filename))

        LOG.debug('LUT filename: %s', str(self.reflectance_lut_filename))

        # Lazily filled cache of the LUT and its coordinate axes, see
        # get_reflectance_lut().
        self._rayl = None
        self._wvl_coord = None
        self._azid_coord = None
        self._satz_sec_coord = None
        self._sunz_sec_coord = None

    def get_effective_wavelength(self, bandname):
        """Get the effective wavelength with Rayleigh scattering in mind.

        Args:
            bandname: Either a band name (``str``) or a wavelength given as
                a number. A numeric value must lie strictly between 0.4 and
                0.8 (microns, per the error message below).

        Returns:
            The central wavelength of the band's spectral response weighted
            by ``1 / wavelength**4``. If no spectral responses can be loaded
            and ``bandname`` is numeric, ``bandname`` itself is returned.

        Raises:
            IOError: If no spectral responses are available and ``bandname``
                is not numeric.
            BandFrequencyOutOfRange: If a numeric ``bandname`` is outside the
                open interval (0.4, 0.8).
        """
        try:
            rsr = RelativeSpectralResponse(self.platform_name, self.sensor)
        except IOError:
            LOG.exception(
                "No spectral responses for this platform and sensor: %s %s", self.platform_name, self.sensor)
            if isinstance(bandname, (float, integer_types)):
                LOG.warning(
                    "Effective wavelength is set to the requested band wavelength = %f", bandname)
                return bandname
            raise

        if isinstance(bandname, str):
            # Translate a generic band name to the sensor specific one when
            # a mapping exists, otherwise keep the name as given.
            bandname = BANDNAMES.get(self.sensor, BANDNAMES['generic']).get(bandname, bandname)
        elif isinstance(bandname, (float, integer_types)):
            if not 0.4 < bandname < 0.8:
                raise BandFrequencyOutOfRange(
                    'Requested band frequency should be between 0.4 and 0.8 microns!')
            bandname = get_bandname_from_wavelength(self.sensor, bandname, rsr.rsr)

        detector = rsr.rsr[bandname]['det-1']
        wvl = detector['wavelength']
        resp = detector['response']

        cwvl = get_central_wave(wvl, resp, weight=1. / wvl**4)
        LOG.debug("Band name: %s  Effective wavelength: %f", bandname, cwvl)

        return cwvl

    def get_reflectance_lut(self):
        """Read the LUT with reflectances as a function of wavelength, satellite zenith
        secant, azimuth difference angle, and sun zenith secant.

        The file is read once; later calls return the cached values.

        Returns:
            Tuple ``(reflectance, wavelengths, azimuth_difference,
            satellite_zenith_secant, sun_zenith_secant)``.
        """
        if self._rayl is None:
            (self._rayl,
             self._wvl_coord,
             self._azid_coord,
             self._satz_sec_coord,
             self._sunz_sec_coord) = get_reflectance_lut(
                 self.reflectance_lut_filename)
        return (self._rayl, self._wvl_coord, self._azid_coord,
                self._satz_sec_coord, self._sunz_sec_coord)

    @staticmethod
    def _do_interp(minterp, sunzsec, azidiff, satzsec):
        """Evaluate the interpolator on one block of angle arrays.

        The three inputs are flattened and stacked into a 3 x N array of
        points (sun zenith secant, 180 - azimuth difference, satellite zenith
        secant); the result is reshaped to the shape of ``sunzsec``.
        """
        interp_points2 = np.vstack((sunzsec.ravel(),
                                    180 - azidiff.ravel(),
                                    satzsec.ravel()))
        res = minterp(interp_points2)
        return res.reshape(sunzsec.shape)

    def get_reflectance(self, sun_zenith, sat_zenith, azidiff, bandname,
                        redband=None):
        """Get the reflectance from the three sun-sat angles.

        Args:
            sun_zenith: Array of sun zenith angles in degrees.
            sat_zenith: Array of satellite zenith angles in degrees.
            azidiff: Array of azimuth difference angles (interpolated at
                ``180 - azidiff``).
            bandname: Band name or wavelength, passed on to
                :meth:`get_effective_wavelength`.
            redband: Optional array used to damp the correction where its
                value is 20 or more (presumably a red band reflectance in
                percent - TODO confirm).

        Returns:
            Array of the same shape as ``sun_zenith`` with values clipped to
            the range 0 to 100. All zeros if the effective wavelength is
            outside the LUT wavelength range. A dask array is returned only
            if dask is available and ``sun_zenith`` was a dask array.
        """
        # Get wavelength in nm for band:
        wvl = self.get_effective_wavelength(bandname) * 1000.0
        rayl, wvl_coord, azid_coord, satz_sec_coord, sunz_sec_coord = \
            self.get_reflectance_lut()

        # force dask arrays
        compute = False
        if HAVE_DASK and not isinstance(sun_zenith, Array):
            compute = True
            sun_zenith = from_array(sun_zenith, chunks=sun_zenith.shape)
            sat_zenith = from_array(sat_zenith, chunks=sat_zenith.shape)
            azidiff = from_array(azidiff, chunks=azidiff.shape)
            if redband is not None:
                redband = from_array(redband, chunks=redband.shape)

        # Limit the angles to the range covered by the LUT secant axes.
        clip_angle = rad2deg(arccos(1. / sunz_sec_coord.max()))
        sun_zenith = clip(sun_zenith, 0, clip_angle)
        sunzsec = 1. / cos(deg2rad(sun_zenith))
        clip_angle = rad2deg(arccos(1. / satz_sec_coord.max()))
        sat_zenith = clip(sat_zenith, 0, clip_angle)
        satzsec = 1. / cos(deg2rad(sat_zenith))
        shape = sun_zenith.shape

        if not wvl_coord.min() < wvl < wvl_coord.max():
            LOG.warning(
                "Effective wavelength for band %s outside 400-800 nm range!",
                str(bandname))
            LOG.info(
                "Set the rayleigh/aerosol reflectance contribution to zero!")
            if HAVE_DASK:
                # Fall back to the sun zenith chunking when redband is None
                # or is not a chunked (dask) array.
                chunks = getattr(redband, 'chunks', sun_zenith.chunks)
                res = zeros(shape, chunks=chunks)
                return res.compute() if compute else res
            return zeros(shape)

        # Linear interpolation in wavelength between the two enclosing LUT
        # nodes. The range check above guarantees 1 <= idx <= len - 1.
        idx = np.searchsorted(wvl_coord, wvl)
        wvl1 = wvl_coord[idx - 1]
        wvl2 = wvl_coord[idx]

        fac = (wvl2 - wvl) / (wvl2 - wvl1)
        raylwvl = fac * rayl[idx - 1, :, :, :] + (1 - fac) * rayl[idx, :, :, :]
        tic = time.time()

        smin = [sunz_sec_coord[0], azid_coord[0], satz_sec_coord[0]]
        smax = [sunz_sec_coord[-1], azid_coord[-1], satz_sec_coord[-1]]
        orders = [
            len(sunz_sec_coord), len(azid_coord), len(satz_sec_coord)]
        minterp = MultilinearInterpolator(smin, smax, orders)
        minterp.set_values(atleast_2d(raylwvl.ravel()))

        if HAVE_DASK:
            ipn = map_blocks(self._do_interp, minterp, sunzsec, azidiff,
                             satzsec, dtype=raylwvl.dtype,
                             chunks=azidiff.chunks)
        else:
            ipn = self._do_interp(minterp, sunzsec, azidiff, satzsec)

        # NOTE: with dask the interpolation is lazy, so this mostly measures
        # graph construction time.
        LOG.debug("Time - Interpolation: %f", time.time() - tic)

        res = ipn * 100
        if redband is not None:
            # Float literals keep this a true division for integer redband
            # values under Python 2 as well.
            res = where(redband < 20., res,
                        (1 - (redband - 20.) / 80.) * res)

        res = clip(res, 0, 100)
        if compute:
            res = res.compute()
        return res
259
260
261
def get_reflectance_lut(filename):
    """Read the LUT with reflectances as a function of wavelength, satellite
    zenith secant, azimuth difference angle, and sun zenith secant.

    Args:
        filename: Path of the HDF5 LUT file. It must hold the datasets
            ``reflectance``, ``wavelengths``, ``azimuth_difference``,
            ``satellite_zenith_secant`` and ``sun_zenith_secant``.

    Returns:
        Tuple ``(reflectance, wavelengths, azimuth_difference,
        satellite_zenith_secant, sun_zenith_secant)``. With dask available
        these are dask arrays backed by the open file (the wavelengths are
        persisted in memory); otherwise they are read fully into memory and
        the file is closed.
    """
    h5f = h5py.File(filename, 'r')
    # The dask arrays read from the file lazily, so the handle may only be
    # released here when the data was loaded eagerly or when reading failed.
    keep_open = False
    try:
        tab = h5f['reflectance']
        wvl = h5f['wavelengths']
        azidiff = h5f['azimuth_difference']
        satellite_zenith_secant = h5f['satellite_zenith_secant']
        sun_zenith_secant = h5f['sun_zenith_secant']

        if HAVE_DASK:
            tab = from_array(tab, chunks=(10, 10, 10, 10))
            # wvl_coord is used in a lot of non-dask functions, keep in memory
            wvl = from_array(wvl, chunks=(100,)).persist()
            azidiff = from_array(azidiff, chunks=(1000,))
            satellite_zenith_secant = from_array(satellite_zenith_secant,
                                                 chunks=(1000,))
            sun_zenith_secant = from_array(sun_zenith_secant,
                                           chunks=(1000,))
            keep_open = True
        else:
            # load all of the data we are going to use in to memory
            tab = tab[:]
            wvl = wvl[:]
            azidiff = azidiff[:]
            satellite_zenith_secant = satellite_zenith_secant[:]
            sun_zenith_secant = sun_zenith_secant[:]
    finally:
        if not keep_open:
            h5f.close()

    return tab, wvl, azidiff, satellite_zenith_secant, sun_zenith_secant
294
295
296
# if __name__ == "__main__":
297
#     SHAPE = (1000, 3000)
298
#     NDIM = SHAPE[0] * SHAPE[1]
299
#     SUNZ = np.ma.arange(
300
#         NDIM / 2, NDIM + NDIM / 2).reshape(SHAPE) * 60. / float(NDIM)
301
#     SATZ = np.ma.arange(NDIM).reshape(SHAPE) * 60. / float(NDIM)
302
#     AZIDIFF = np.ma.arange(NDIM).reshape(SHAPE) * 179.9 / float(NDIM)
303
#     rfl = this.get_reflectance(SUNZ, SATZ, AZIDIFF, 'M4')
304