Completed
Pull Request — develop (#33)
by
unknown
34s
created

Rayleigh.get_reflectance_lut()   A

Complexity

Conditions 2

Size

Total Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 2
Bugs 0 Features 0
Metric Value
cc 2
dl 0
loc 14
rs 9.4285
c 2
b 0
f 0
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
# Copyright (c) 2016-2018 Pytroll
5
6
# Author(s):
7
8
#   Adam.Dybbroe <[email protected]>
9
#   Martin Raspaud <[email protected]>
10
11
# This program is free software: you can redistribute it and/or modify
12
# it under the terms of the GNU General Public License as published by
13
# the Free Software Foundation, either version 3 of the License, or
14
# (at your option) any later version.
15
16
# This program is distributed in the hope that it will be useful,
17
# but WITHOUT ANY WARRANTY; without even the implied warranty of
18
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
19
# GNU General Public License for more details.
20
21
# You should have received a copy of the GNU General Public License
22
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
23
24
"""Atmospheric correction of shortwave imager bands in the wavelength range
400 to 800 nm.
"""
28
29
import os
30
import time
31
import logging
32
from six import integer_types
33
34
import h5py
35
import numpy as np
36
37
try:
38
    from dask.array import (where, zeros, clip, rad2deg, deg2rad, cos, arccos,
39
                            atleast_2d, Array, map_blocks, from_array)
40
    HAVE_DASK = True
41
    try:
42
        # use serializable h5py wrapper to make sure files are closed properly
43
        import h5pickle as h5py
44
    except ImportError:
45
        pass
46
except ImportError:
47
    from numpy import where, zeros, clip, rad2deg, deg2rad, cos, arccos, atleast_2d
48
    map_blocks = None
49
    from_array = None
50
    Array = None
51
    HAVE_DASK = False
52
53
from geotiepoints.multilinear import MultilinearInterpolator
54
from pyspectral.rsr_reader import RelativeSpectralResponse
55
from pyspectral.utils import (INSTRUMENTS, RAYLEIGH_LUT_DIRS,
56
                              AEROSOL_TYPES, ATMOSPHERES,
57
                              BANDNAMES,
58
                              download_luts, get_central_wave,
59
                              get_bandname_from_wavelength)
60
61
LOG = logging.getLogger(__name__)
62
63
64
class BandFrequencyOutOfRange(ValueError):

    """Exception when the band frequency is out of the visible range.

    Subclasses ``ValueError`` so callers that already catch ``ValueError``
    keep working.
    """
69
70
71
class Rayleigh(object):

    """Container for the atmospheric correction of satellite imager bands.

    This class removes background contributions of Rayleigh scattering of
    molecules and Mie scattering and absorption by aerosols.

    """

    def __init__(self, platform_name, sensor, **kwargs):
        """Initialize class and determine LUT to use.

        Args:
            platform_name: Name of the satellite platform.
            sensor: Name of the sensor. It is replaced by the entry found in
                ``INSTRUMENTS`` for ``platform_name`` when the two disagree.
            **kwargs: Optional settings:
                ``atmosphere`` (default ``'us-standard'``), must be a key of
                ``ATMOSPHERES``;
                ``aerosol_type`` (default ``'marine_clean_aerosol'``), must
                be in ``AEROSOL_TYPES``.

        Raises:
            AttributeError: If the atmosphere or aerosol type is unsupported.
            IOError: If the LUT file is still missing after the download
                attempt.
        """
        self.platform_name = platform_name
        self.sensor = sensor
        self.coeff_filename = None

        atm_type = kwargs.get('atmosphere', 'us-standard')
        if atm_type not in ATMOSPHERES:
            raise AttributeError('Atmosphere type not supported! ' +
                                 'Need to be one of {}'.format(str(ATMOSPHERES)))

        aerosol_type = kwargs.get('aerosol_type', 'marine_clean_aerosol')
        if aerosol_type not in AEROSOL_TYPES:
            raise AttributeError('Aerosol type not supported! ' +
                                 'Need to be one of {0}'.format(str(AEROSOL_TYPES)))

        rayleigh_dir = RAYLEIGH_LUT_DIRS[aerosol_type]

        # An unsupported atmosphere has already raised above, so no second
        # membership check is needed here.
        LOG.info("Atmosphere chosen: %s", atm_type)

        # Try fix instrument naming
        instr = INSTRUMENTS.get(platform_name, sensor)
        if instr != sensor:
            sensor = instr
            LOG.warning("Inconsistent sensor/satellite input - "
                        "sensor set to %s", sensor)

        self.sensor = sensor.replace('/', '')

        ext = atm_type.replace(' ', '_')
        lutname = "rayleigh_lut_{0}.h5".format(ext)
        self.reflectance_lut_filename = os.path.join(rayleigh_dir, lutname)
        if not os.path.exists(self.reflectance_lut_filename):
            LOG.warning(
                "No lut file %s on disk", self.reflectance_lut_filename)
            LOG.info("Will download from internet...")
            download_luts(aerosol_type=aerosol_type)

        # isfile() is False for a missing path as well, so it covers both
        # the "does not exist" and the "is not a regular file" case.
        if not os.path.isfile(self.reflectance_lut_filename):
            raise IOError('pyspectral file for Rayleigh scattering correction ' +
                          'does not exist! Filename = ' +
                          str(self.reflectance_lut_filename))

        LOG.debug('LUT filename: %s', str(self.reflectance_lut_filename))

        # Lazily filled cache of the LUT contents, see get_reflectance_lut()
        self._rayl = None
        self._wvl_coord = None
        self._azid_coord = None
        self._satz_sec_coord = None
        self._sunz_sec_coord = None

    def get_effective_wavelength(self, bandname):
        """Get the effective wavelength with Rayleigh scattering in mind.

        Args:
            bandname: Either a band name (``str``) or a number. A number must
                lie strictly between 0.4 and 0.8 and is mapped to a band name
                with ``get_bandname_from_wavelength``.

        Returns:
            The central wavelength of the band's ``det-1`` response, computed
            by ``get_central_wave`` with a ``1 / wavelength**4`` weight. If no
            spectral responses can be loaded and ``bandname`` is a number,
            that number is returned unchanged.

        Raises:
            IOError: If no spectral responses can be loaded and ``bandname``
                is not a number.
            BandFrequencyOutOfRange: If a numeric ``bandname`` is not strictly
                between 0.4 and 0.8.
        """
        try:
            rsr = RelativeSpectralResponse(self.platform_name, self.sensor)
        except IOError:
            LOG.exception(
                "No spectral responses for this platform and sensor: %s %s",
                self.platform_name, self.sensor)
            if isinstance(bandname, (float, integer_types)):
                # Best effort: fall back on the wavelength the caller gave us
                LOG.warning(
                    "Effective wavelength is set to the requested band wavelength = %f",
                    bandname)
                return bandname
            raise

        if isinstance(bandname, str):
            # Translate to the sensor specific name, keep it if unknown
            sensor_bands = BANDNAMES.get(self.sensor, BANDNAMES['generic'])
            bandname = sensor_bands.get(bandname, bandname)
        elif isinstance(bandname, (float, integer_types)):
            if not 0.4 < bandname < 0.8:
                raise BandFrequencyOutOfRange(
                    'Requested band frequency should be between 0.4 and 0.8 microns!')
            bandname = get_bandname_from_wavelength(self.sensor, bandname, rsr.rsr)

        detector = rsr.rsr[bandname]['det-1']
        wvl = detector['wavelength']
        resp = detector['response']

        cwvl = get_central_wave(wvl, resp, weight=1. / wvl**4)
        LOG.debug("Band name: %s  Effective wavelength: %f", bandname, cwvl)

        return cwvl

    def get_reflectance_lut(self):
        """Read the LUT with reflectances as a function of wavelength, satellite zenith
        secant, azimuth difference angle, and sun zenith secant.

        The file is only read on the first call; the result is cached on the
        instance and returned directly afterwards.

        Returns:
            Tuple ``(reflectance, wavelengths, azimuth_difference,
            satellite_zenith_secant, sun_zenith_secant)``.
        """
        if self._rayl is None:
            (self._rayl,
             self._wvl_coord,
             self._azid_coord,
             self._satz_sec_coord,
             self._sunz_sec_coord) = get_reflectance_lut(
                 self.reflectance_lut_filename)[:5]
        return (self._rayl, self._wvl_coord, self._azid_coord,
                self._satz_sec_coord, self._sunz_sec_coord)

    def get_reflectance(self, sun_zenith, sat_zenith, azidiff, bandname,
                        redband=None):
        """Get the reflectance from the three sun-sat angles.

        Args:
            sun_zenith: Array of sun zenith angles in degrees.
            sat_zenith: Array of satellite zenith angles in degrees.
            azidiff: Array of azimuth difference angles; the LUT is looked up
                at ``180 - azidiff`` (presumably degrees -- confirm against
                the LUT definition).
            bandname: Band name or wavelength, passed on to
                ``get_effective_wavelength``.
            redband: Optional array. Where it is 20 or more the result is
                scaled by ``1 - (redband - 20) / 80``.

        Returns:
            Array with the shape of ``sun_zenith``, clipped to the range
            0 to 100. All zeros if the effective wavelength lies outside the
            wavelength range of the LUT. When dask is available and
            ``sun_zenith`` is not a dask array, the inputs are wrapped in
            dask arrays and the computed result is returned.
        """
        # Get wavelength in nm for band:
        wvl = self.get_effective_wavelength(bandname) * 1000.0
        rayl, wvl_coord, azid_coord, satz_sec_coord, sunz_sec_coord = \
            self.get_reflectance_lut()

        # force dask arrays
        compute = False
        if HAVE_DASK and not isinstance(sun_zenith, Array):
            compute = True
            sun_zenith = from_array(sun_zenith, chunks=sun_zenith.shape)
            sat_zenith = from_array(sat_zenith, chunks=sat_zenith.shape)
            azidiff = from_array(azidiff, chunks=azidiff.shape)
            if redband is not None:
                redband = from_array(redband, chunks=redband.shape)

        # Keep the zenith angles inside the secant range covered by the LUT
        clip_angle = rad2deg(arccos(1. / sunz_sec_coord.max()))
        sun_zenith = clip(sun_zenith, 0, clip_angle)
        sunzsec = 1. / cos(deg2rad(sun_zenith))
        clip_angle = rad2deg(arccos(1. / satz_sec_coord.max()))
        sat_zenith = clip(sat_zenith, 0, clip_angle)
        satzsec = 1. / cos(deg2rad(sat_zenith))
        shape = sun_zenith.shape

        if not wvl_coord.min() < wvl < wvl_coord.max():
            LOG.warning(
                "Effective wavelength for band %s outside 400-800 nm range!",
                str(bandname))
            LOG.info(
                "Set the rayleigh/aerosol reflectance contribution to zero!")
            if not HAVE_DASK:
                return zeros(shape)
            chunks = sun_zenith.chunks if redband is None \
                else redband.chunks
            res = zeros(shape, chunks=chunks)
            return res.compute() if compute else res

        # Linear interpolation between the two neighbouring LUT wavelengths
        idx = np.searchsorted(wvl_coord, wvl)
        wvl1 = wvl_coord[idx - 1]
        wvl2 = wvl_coord[idx]

        fac = (wvl2 - wvl) / (wvl2 - wvl1)
        raylwvl = fac * rayl[idx - 1, :, :, :] + (1 - fac) * rayl[idx, :, :, :]
        tic = time.time()

        smin = [sunz_sec_coord[0], azid_coord[0], satz_sec_coord[0]]
        smax = [sunz_sec_coord[-1], azid_coord[-1], satz_sec_coord[-1]]
        orders = [
            len(sunz_sec_coord), len(azid_coord), len(satz_sec_coord)]
        minterp = MultilinearInterpolator(smin, smax, orders)

        f_3d_grid = raylwvl
        minterp.set_values(atleast_2d(f_3d_grid.ravel()))

        def _do_interp(minterp, sunzsec, azidiff, satzsec):
            """Interpolate one block of angles and restore its shape."""
            interp_points2 = np.vstack((sunzsec.ravel(),
                                        180 - azidiff.ravel(),
                                        satzsec.ravel()))
            res = minterp(interp_points2)
            return res.reshape(sunzsec.shape)

        if HAVE_DASK:
            ipn = map_blocks(_do_interp, minterp, sunzsec, azidiff,
                             satzsec, dtype=raylwvl.dtype,
                             chunks=azidiff.chunks)
        else:
            ipn = _do_interp(minterp, sunzsec, azidiff, satzsec)

        # Lazy %-formatting, consistent with the other log calls
        LOG.debug("Time - Interpolation: %f", time.time() - tic)

        res = ipn * 100
        if redband is not None:
            # 80. forces true division also for integer input on Python 2
            res = where(redband < 20., res,
                        (1 - (redband - 20) / 80.) * res)

        res = clip(res, 0, 100)
        if compute:
            res = res.compute()
        return res
262
263
264
def get_reflectance_lut(filename):
    """Read the LUT with reflectances as a function of wavelength, satellite
    zenith secant, azimuth difference angle, and sun zenith secant.

    Args:
        filename: Path of the HDF5 file holding the datasets ``reflectance``,
            ``wavelengths``, ``azimuth_difference``,
            ``satellite_zenith_secant`` and ``sun_zenith_secant``.

    Returns:
        Tuple ``(reflectance, wavelengths, azimuth_difference,
        satellite_zenith_secant, sun_zenith_secant)``. With dask available
        these are dask arrays wrapping the datasets of the file, which is
        then left open; otherwise they are in-memory copies and the file is
        closed before returning.

    Raises:
        Whatever ``h5py`` raises for an unreadable file or a missing dataset;
        the file is closed before the error propagates.
    """
    h5f = h5py.File(filename, 'r')

    # Only the dask branch may leave the file open: its arrays wrap the
    # datasets of this file. In every other case, including errors, close it.
    keep_open = False
    try:
        tab = h5f['reflectance']
        wvl = h5f['wavelengths']
        azidiff = h5f['azimuth_difference']
        satellite_zenith_secant = h5f['satellite_zenith_secant']
        sun_zenith_secant = h5f['sun_zenith_secant']

        if HAVE_DASK:
            tab = from_array(tab, chunks=(10, 10, 10, 10))
            # wvl_coord is used in a lot of non-dask functions, keep in memory
            wvl = from_array(wvl, chunks=(100,)).persist()
            azidiff = from_array(azidiff, chunks=(1000,))
            satellite_zenith_secant = from_array(satellite_zenith_secant,
                                                 chunks=(1000,))
            sun_zenith_secant = from_array(sun_zenith_secant,
                                           chunks=(1000,))
            keep_open = True
        else:
            # load all of the data we are going to use in to memory
            tab = tab[:]
            wvl = wvl[:]
            azidiff = azidiff[:]
            satellite_zenith_secant = satellite_zenith_secant[:]
            sun_zenith_secant = sun_zenith_secant[:]
    finally:
        if not keep_open:
            h5f.close()

    return tab, wvl, azidiff, satellite_zenith_secant, sun_zenith_secant
297
298
299
# if __name__ == "__main__":
300
#     SHAPE = (1000, 3000)
301
#     NDIM = SHAPE[0] * SHAPE[1]
302
#     SUNZ = np.ma.arange(
303
#         NDIM / 2, NDIM + NDIM / 2).reshape(SHAPE) * 60. / float(NDIM)
304
#     SATZ = np.ma.arange(NDIM).reshape(SHAPE) * 60. / float(NDIM)
305
#     AZIDIFF = np.ma.arange(NDIM).reshape(SHAPE) * 179.9 / float(NDIM)
306
#     rfl = this.get_reflectance(SUNZ, SATZ, AZIDIFF, 'M4')
307