Completed
Push — develop ( 74ca54...4e7c62 )
by Angeline
14s queued 12s
created

test_Apex.TestApexInit.test_set_epoch()   A

Complexity

Conditions 1

Size

Total Lines 14
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 10
dl 0
loc 14
rs 9.9
c 0
b 0
f 0
cc 1
nop 2
1
# -*- coding: utf-8 -*-
2
"""Test the apexpy.Apex class

Notes
-----
Whenever function outputs are tested against hard-coded numbers, the test
results (numbers) were obtained by running the code that is tested.  Therefore,
these tests below only check that nothing changes when refactoring, etc., and
not if the results are actually correct.

These results are expected to change when IGRF is updated.

"""
14
15
import datetime as dt
16
import numpy as np
17
import os
18
import pytest
19
import warnings
20
21
from apexpy import fortranapex as fa
22
from apexpy import Apex, ApexHeightError, helpers
23
24
25
@pytest.fixture()
def igrf_file(tmp_path):
    """Temporarily remove the packaged IGRF coefficient file.

    The coefficient file is moved into a pytest-managed temporary directory
    for the duration of the requesting test and moved back afterwards.

    Parameters
    ----------
    tmp_path : pathlib.Path
        Built-in pytest fixture providing a unique per-test directory.

    Yields
    ------
    original_file : str
        Full path where the coefficient file normally lives (it is absent
        from this location while the test runs).

    """
    import shutil

    # Ensure the coefficient file exists before touching it
    original_file = os.path.join(os.path.dirname(helpers.__file__),
                                 'igrf13coeffs.txt')
    tmp_file = os.path.join(str(tmp_path), "temp_coeff.txt")
    assert os.path.isfile(original_file)

    # Move the coefficient file out of the way.  `shutil.move` falls back to
    # copy-and-delete when the temporary directory is on another filesystem,
    # where `os.rename` would raise OSError; it also avoids writing into the
    # current working directory.
    shutil.move(original_file, tmp_file)
    try:
        yield original_file
    finally:
        # Always move the coefficient file back
        shutil.move(tmp_file, original_file)
41
42
43
def test_set_epoch_file_error(igrf_file):
    """Check that Apex raises OSError when the IGRF coefficients are absent."""
    expected_start = "File {:} does not exist".format(igrf_file)

    # Initializing Apex without the coefficient file must fail
    with pytest.raises(OSError) as raised:
        Apex()

    message = str(raised.value)
    assert message.startswith(expected_start)
    return
51
52
53
class TestApexInit():
    """Test the Apex class initialization and its epoch/height setters."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = None
        self.test_date = dt.datetime.utcnow()
        self.test_refh = 0
        self.bad_file = 'foo/path/to/datafile.blah'

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.test_date, self.test_refh, self.bad_file

    def eval_date(self):
        """Evaluate the times in self.test_date and self.apex_out."""
        # `dt.datetime` is a subclass of `dt.date`, so one check covers both
        if isinstance(self.test_date, dt.date):
            self.test_date = helpers.toYearFraction(self.test_date)

        # Assert the times are the same on the order of tens of seconds.
        # Necessary to evaluate the current UTC
        np.testing.assert_almost_equal(self.test_date, self.apex_out.year, 6)
        return

    def eval_refh(self):
        """Evaluate reference height in self.test_refh and self.apex_out."""
        eval_str = "".join(["expected reference height [",
                            "{:}] not equal to Apex ".format(self.test_refh),
                            "reference height ",
                            "[{:}]".format(self.apex_out.refh)])
        assert self.test_refh == self.apex_out.refh, eval_str
        return

    def test_init_defaults(self):
        """Test Apex class default initialization."""
        self.apex_out = Apex()
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("in_date",
                             [2015, 2015.5, dt.date(2015, 1, 1),
                              dt.datetime(2015, 6, 1, 18, 23, 45)])
    def test_init_date(self, in_date):
        """Test Apex class with date initialization."""
        self.test_date = in_date
        self.apex_out = Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("new_date", [2015, 2015.5])
    def test_set_epoch(self, new_date):
        """Test successful setting of Apex epoch after initialization."""
        # Evaluate the default initialization
        self.apex_out = Apex()
        self.eval_date()
        self.eval_refh()

        # Update the epoch
        self.test_date = new_date
        self.apex_out.set_epoch(new_date)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("in_refh", [0.0, 300.0, 30000.0, -1.0])
    def test_init_refh(self, in_refh):
        """Test Apex class with reference height initialization."""
        self.test_refh = in_refh
        self.apex_out = Apex(refh=self.test_refh)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("new_refh", [0.0, 300.0, 30000.0, -1.0])
    def test_set_refh(self, new_refh):
        """Test the method used to set the reference height after the init."""
        # Verify the defaults are set
        self.apex_out = Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()

        # Update to a new reference height and test
        self.test_refh = new_refh
        self.apex_out.set_refh(new_refh)
        self.eval_refh()
        return

    def test_init_with_bad_datafile(self):
        """Test raises IOError with non-existent datafile input."""
        with pytest.raises(IOError) as oerr:
            Apex(datafile=self.bad_file)
        assert str(oerr.value).startswith('Data file does not exist')
        return

    def test_init_with_bad_fortranlib(self):
        """Test raises IOError with non-existent fortranlib input."""
        with pytest.raises(IOError) as oerr:
            Apex(fortranlib=self.bad_file)
        assert str(oerr.value).startswith('Fortran library does not exist')
        return
152
153
154
class TestApexMethod():
    """Test the Apex methods."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = Apex(date=2000, refh=300)
        self.in_lat = 60
        self.in_lon = 15
        self.in_alt = 100

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.in_lat, self.in_lon, self.in_alt

    def get_input_args(self, method_name, precision=0.0):
        """Set the input arguments for the different Apex methods.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method
        precision : float
            Value for the precision (default=0.0)

        Returns
        -------
        in_args : list
            List of the appropriate input arguments

        """
        in_args = [self.in_lat, self.in_lon, self.in_alt]

        # Add precision, if needed
        if method_name in ["_qd2geo", "apxq2g", "apex2geo", "qd2geo",
                           "_apex2geo"]:
            in_args.append(precision)

        # Add a reference height, if needed (matches refh in setup_method)
        if method_name in ["apxg2all"]:
            in_args.append(300)

        # Add a vector flag, if needed
        if method_name in ["apxg2all", "apxg2q"]:
            in_args.append(1)

        return in_args

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                              ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                              ("_qd2geo", "apxq2g", slice(None)),
                              ("_basevec", "apxg2q", slice(2, 4, 1))])
    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon", [-179, -90, 0, 90, 180])
    def test_fortran_scalar_input(self, apex_method, fortran_method, fslice,
                                  lat, lon):
        """Tests Apex/fortran interface consistency for scalars."""
        # Set the input coordinates
        self.in_lat = lat
        self.in_lon = lon

        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(fa, fortran_method)

        # Get the appropriate input arguments
        apex_args = self.get_input_args(apex_method)
        fortran_args = self.get_input_args(fortran_method)

        # Evaluate the equivalent function calls
        np.testing.assert_allclose(apex_func(*apex_args),
                                   fortran_func(*fortran_args)[fslice])
        return

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                              ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                              ("_qd2geo", "apxq2g", slice(None)),
                              ("_basevec", "apxg2q", slice(2, 4, 1))])
    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon1,lon2", [(180, 180), (-180, -180),
                                           (180, -180), (-180, 180),
                                           (-345, 15), (375, 15)])
    def test_fortran_longitude_rollover(self, apex_method, fortran_method,
                                        fslice, lat, lon1, lon2):
        """Tests Apex/fortran interface consistency for longitude rollover."""
        # Set the fixed input coordinate
        self.in_lat = lat

        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(fa, fortran_method)

        # Get the appropriate input arguments
        self.in_lon = lon1
        apex_args = self.get_input_args(apex_method)

        self.in_lon = lon2
        fortran_args = self.get_input_args(fortran_method)

        # Evaluate the equivalent function calls
        np.testing.assert_allclose(apex_func(*apex_args),
                                   fortran_func(*fortran_args)[fslice])
        return

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                              ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                              ("_qd2geo", "apxq2g", slice(None)),
                              ("_basevec", "apxg2q", slice(2, 4, 1))])
    def test_fortran_array_input(self, apex_method, fortran_method, fslice):
        """Tests Apex/fortran interface consistency for array input."""
        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(fa, fortran_method)

        # Set up the input arrays
        ref_lat = np.array([0, 30, 60, 90])
        ref_alt = np.array([100, 200, 300, 400])
        self.in_lat = ref_lat.reshape((2, 2))
        self.in_alt = ref_alt.reshape((2, 2))
        apex_args = self.get_input_args(apex_method)

        # Get the Apex class results
        aret = apex_func(*apex_args)

        # Get the fortran function results, one location at a time
        flats = list()
        flons = list()

        for i, lat in enumerate(ref_lat):
            self.in_lat = lat
            self.in_alt = ref_alt[i]
            fortran_args = self.get_input_args(fortran_method)
            fret = fortran_func(*fortran_args)[fslice]
            flats.append(fret[0])
            flons.append(fret[1])

        flats = np.array(flats)
        flons = np.array(flons)

        # Evaluate results
        try:
            # This returned value is array of floats
            np.testing.assert_allclose(aret[0].astype(float),
                                       flats.reshape((2, 2)).astype(float))
            np.testing.assert_allclose(aret[1].astype(float),
                                       flons.reshape((2, 2)).astype(float))
        except ValueError:
            # This returned value is array of arrays
            alats = aret[0].reshape((4,))
            alons = aret[1].reshape((4,))
            for i, flat in enumerate(flats):
                np.testing.assert_array_almost_equal(alats[i], flat, 2)
                np.testing.assert_array_almost_equal(alons[i], flons[i], 2)

        return

    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon", [-179, -90, 0, 90, 180])
    def test_geo2apexall_scalar(self, lat, lon):
        """Test Apex/fortran geo2apexall interface consistency for scalars."""
        # Get the Apex and Fortran results
        aret = self.apex_out._geo2apexall(lat, lon, self.in_alt)
        fret = fa.apxg2all(lat, lon, self.in_alt, 300, 1)

        # Evaluate each element in the results
        for aval, fval in zip(aret, fret):
            np.testing.assert_allclose(aval, fval)
        return

    def test_geo2apexall_array(self):
        """Test Apex/fortran geo2apexall interface consistency for arrays."""
        # Set the input
        self.in_lat = np.array([0, 30, 60, 90])
        self.in_alt = np.array([100, 200, 300, 400])

        # Get the Apex class results
        aret = self.apex_out._geo2apexall(self.in_lat.reshape((2, 2)),
                                          self.in_lon,
                                          self.in_alt.reshape((2, 2)))

        # For each lat/alt pair, get the Fortran results
        fret = list()
        for i, lat in enumerate(self.in_lat):
            fret.append(fa.apxg2all(lat, self.in_lon, self.in_alt[i], 300, 1))

        # Cycle through all returned values
        for i, ret in enumerate(aret):
            try:
                # This returned value is array of floats
                np.testing.assert_allclose(ret.astype(float),
                                           np.array([[fret[0][i], fret[1][i]],
                                                     [fret[2][i], fret[3][i]]],
                                                    dtype=float))
            except ValueError:
                # This returned value is array of arrays
                ret = ret.reshape((4,))
                for j, single_fret in enumerate(fret):
                    np.testing.assert_allclose(ret[j], single_fret[i])
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_consistency(self, in_coord, out_coord):
        """Test the self-consistency of the Apex convert method."""
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Define the method name
        method_name = "2".join([in_coord, out_coord])

        # Get the method and method inputs
        convert_kwargs = {'height': self.in_alt, 'precision': 0.0}
        apex_args = self.get_input_args(method_name)
        apex_method = getattr(self.apex_out, method_name)

        # Define the slice needed to get equivalent output from the named method
        mslice = slice(0, -1, 1) if out_coord == "geo" else slice(None)

        # Get output using convert and named method
        convert_out = self.apex_out.convert(self.in_lat, self.in_lon, in_coord,
                                            out_coord, **convert_kwargs)
        method_out = apex_method(*apex_args)[mslice]

        # Compare both outputs, should be identical
        np.testing.assert_allclose(convert_out, method_out)
        return

    @pytest.mark.parametrize("bound_lat", [90, -90])
    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_at_lat_boundary(self, bound_lat, in_coord, out_coord):
        """Test the conversion at the latitude boundary, with allowed excess."""
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = self.apex_out.convert(bound_lat, 0, in_coord, out_coord)
        excess_out = self.apex_out.convert(excess_lat, 0, in_coord, out_coord)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    def test_convert_qd2apex_at_equator(self):
        """Test the quasi-dipole to apex conversion at the magnetic equator."""
        eq_out = self.apex_out.convert(lat=0.0, lon=0, source='qd', dest='apex',
                                       height=320.0)
        close_out = self.apex_out.convert(lat=0.001, lon=0, source='qd',
                                          dest='apex', height=320.0)
        np.testing.assert_allclose(eq_out, close_out, atol=1e-4)
        return

    @pytest.mark.parametrize("src", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("dest", ["geo", "apex", "qd"])
    def test_convert_withnan(self, src, dest):
        """Test Apex.convert success with NaN input."""
        if src == dest:
            pytest.skip("Test not needed for same src and dest coordinates")

        num_nans = 5
        in_loc = np.arange(0, 10, dtype=float)
        in_loc[:num_nans] = np.nan

        out_loc = self.apex_out.convert(in_loc, in_loc, src, dest, height=320)

        for out in out_loc:
            assert np.all(np.isnan(out[:num_nans])), "NaN output expected"
            assert np.all(np.isfinite(out[num_nans:])), "Finite output expected"

        return

    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_convert_invalid_lat(self, bad_lat):
        """Test convert raises ValueError for invalid latitudes."""

        with pytest.raises(ValueError):
            self.apex_out.convert(bad_lat, 0, 'geo', 'geo')
        return

    @pytest.mark.parametrize("coords", [("foobar", "geo"), ("geo", "foobar")])
    def test_convert_invalid_transformation(self, coords):
        """Test raises NotImplementedError for bad coordinates."""
        with pytest.raises(NotImplementedError):
            self.apex_out.convert(0, 0, *coords)
        return

    @pytest.mark.parametrize("method_name, out_comp",
                             [("geo2apex",
                               (55.94841766357422, 94.10684204101562)),
                              ("apex2geo",
                               (51.476322174072266, -66.22817993164062,
                                5.727287771151168e-06)),
                              ("geo2qd",
                               (56.531288146972656, 94.10684204101562)),
                              ("apex2qd", (60.498401178276744, 15.0)),
                              ("qd2apex", (59.49138097045895, 15.0))])
    def test_method_scalar_input(self, method_name, out_comp):
        """Test the user method against set values with scalars."""
        # Get the desired methods
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(self.in_lat, self.in_lon, self.in_alt)

        # Evaluate the user output
        np.testing.assert_allclose(user_out, out_comp)

        for out_val in user_out:
            assert np.asarray(out_val).shape == (), "output is not a scalar"
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("method_args, out_shape",
                             [([[60, 60], 15, 100], (2,)),
                              ([60, [15, 15], 100], (2,)),
                              ([60, 15, [100, 100]], (2,)),
                              ([[50, 60], [15, 16], [100, 200]], (2,))])
    def test_method_broadcast_input(self, in_coord, out_coord, method_args,
                                    out_shape):
        """Test the user method with inputs that require some broadcasting."""
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(*method_args)

        # Evaluate the user output
        for out_val in user_out:
            assert hasattr(out_val, 'shape'), "output coordinate isn't np.array"
            assert out_val.shape == out_shape
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_method_invalid_lat(self, in_coord, out_coord, bad_lat):
        """Test user methods raise ValueError for invalid latitudes."""
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        with pytest.raises(ValueError):
            user_method(bad_lat, 15, 100)
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("bound_lat", [90, -90])
    def test_method_at_lat_boundary(self, in_coord, out_coord, bound_lat):
        """Test user methods at the latitude boundary, with allowed excess."""
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        # Get a latitude just beyond the limit
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = user_method(bound_lat, 0, 100)
        excess_out = user_method(excess_lat, 0, 100)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    def test_geo2apex_undefined_warning(self):
        """Test geo2apex warning and fill values for an undefined location."""

        # Update the apex object
        self.apex_out = Apex(date=2000, refh=10000)

        # Get the output and the warnings
        with warnings.catch_warnings(record=True) as warn_rec:
            user_lat, user_lon = self.apex_out.geo2apex(0, 0, 0)

        assert np.isnan(user_lat)
        assert np.isfinite(user_lon)
        assert len(warn_rec) == 1
        assert issubclass(warn_rec[-1].category, UserWarning)
        assert 'latitude set to NaN where' in str(warn_rec[-1].message)
        return

    @pytest.mark.parametrize("method_name", ["apex2qd", "qd2apex"])
    @pytest.mark.parametrize("delta_h", [1.0e-6, -1.0e-6])
    def test_quasidipole_apexheight_close(self, method_name, delta_h):
        """Test quasi-dipole success with a height close to the reference."""
        qd_method = getattr(self.apex_out, method_name)
        in_args = [0, 15, self.apex_out.refh + delta_h]
        out_coords = qd_method(*in_args)

        for i, out_val in enumerate(out_coords):
            np.testing.assert_almost_equal(out_val, in_args[i], decimal=3)
        return

    @pytest.mark.parametrize("method_name, hinc, msg",
                             [("apex2qd", 1.0, "is > apex height"),
                              ("qd2apex", -1.0, "is < reference height")])
    def test_quasidipole_raises_apexheight(self, method_name, hinc, msg):
        """Test quasi-dipole methods raise ApexHeightError for bad heights."""
        qd_method = getattr(self.apex_out, method_name)

        with pytest.raises(ApexHeightError) as aerr:
            qd_method(0, 15, self.apex_out.refh + hinc)

        # Inspect the raised exception itself (not the ExceptionInfo wrapper)
        # and accept the expected text at any position, including the start
        assert msg in str(aerr.value)
        return
570
571
572
class TestApexMLTMethods():
    """Test the Apex Magnetic Local Time (MLT) methods."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = Apex(date=2000, refh=300)
        self.in_time = dt.datetime(2000, 2, 3, 4, 5, 6)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.in_time

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    def test_convert_to_mlt(self, in_coord):
        """Test the conversions to MLT using Apex convert."""

        # Get the magnetic longitude: geographic input must be converted
        # first, while apex/qd input longitudes are used directly
        if in_coord == "geo":
            mlon = self.apex_out.geo2apex(60, 15, 100)[1]
        else:
            mlon = 15

        # Get the output MLT values
        convert_mlt = self.apex_out.convert(60, 15, in_coord, 'mlt',
                                            height=100, ssheight=2e5,
                                            datetime=self.in_time)[1]
        method_mlt = self.apex_out.mlon2mlt(mlon, self.in_time, ssheight=2e5)

        # Test the outputs
        np.testing.assert_allclose(convert_mlt, method_mlt)
        return

    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_mlt_to_lon(self, out_coord):
        """Test the conversions from MLT using Apex convert."""
        # Get the output longitudes
        convert_out = self.apex_out.convert(60, 15, 'mlt', out_coord,
                                            height=100, ssheight=2e5,
                                            datetime=self.in_time,
                                            precision=1e-2)
        mlon = self.apex_out.mlt2mlon(15, self.in_time, ssheight=2e5)

        if out_coord == "geo":
            method_out = self.apex_out.apex2geo(60, mlon, 100,
                                                precision=1e-2)[:-1]
        elif out_coord == "qd":
            method_out = self.apex_out.apex2qd(60, mlon, 100)
        else:
            method_out = (60, mlon)

        # Evaluate the outputs
        np.testing.assert_allclose(convert_out, method_out)
        return

    def test_convert_geo2mlt_nodate(self):
        """Test convert from geo to MLT raises ValueError with no datetime."""
        with pytest.raises(ValueError):
            self.apex_out.convert(60, 15, 'geo', 'mlt')
        return

    @pytest.mark.parametrize("mlon_kwargs,test_mlt",
                             [({}, 23.019629923502603),
                              ({"ssheight": 100000}, 23.026712036132814)])
    def test_mlon2mlt_scalar_inputs(self, mlon_kwargs, test_mlt):
        """Test mlon2mlt with scalar inputs."""
        mlt = self.apex_out.mlon2mlt(0, self.in_time, **mlon_kwargs)

        np.testing.assert_allclose(mlt, test_mlt)
        assert np.asarray(mlt).shape == ()
        return

    @pytest.mark.parametrize("mlt_kwargs,test_mlon",
                             [({}, 14.705535888671875),
                              ({"ssheight": 100000}, 14.599319458007812)])
    def test_mlt2mlon_scalar_inputs(self, mlt_kwargs, test_mlon):
        """Test mlt2mlon with scalar inputs."""
        mlon = self.apex_out.mlt2mlon(0, self.in_time, **mlt_kwargs)

        np.testing.assert_allclose(mlon, test_mlon)
        assert np.asarray(mlon).shape == ()
        return

    @pytest.mark.parametrize("mlon,test_mlt",
                             [([0, 180], [23.019261, 11.019261]),
                              (np.array([0, 180]), [23.019261, 11.019261]),
                              ([[0, 180], [0, 180]], [[23.019261, 11.019261],
                                                      [23.019261, 11.019261]]),
                              (range(0, 361, 30),
                               [23.01963, 1.01963, 3.01963, 5.01963, 7.01963,
                                9.01963, 11.01963, 13.01963, 15.01963, 17.01963,
                                19.01963, 21.01963, 23.01963])])
    def test_mlon2mlt_array(self, mlon, test_mlt):
        """Test mlon2mlt with array inputs."""
        mlt = self.apex_out.mlon2mlt(mlon, self.in_time)

        assert mlt.shape == np.asarray(test_mlt).shape
        np.testing.assert_allclose(mlt, test_mlt, rtol=1e-4)
        return

    @pytest.mark.parametrize("mlt,test_mlon",
                             [([0, 12], [14.705551, 194.705551]),
                              (np.array([0, 12]), [14.705551, 194.705551]),
                              ([[0, 12], [0, 12]], [[14.705551, 194.705551],
                                                    [14.705551, 194.705551]]),
                              (range(0, 25, 2),
                               [14.705551, 44.705551, 74.705551, 104.705551,
                                134.705551, 164.705551, 194.705551, 224.705551,
                                254.705551, 284.705551, 314.705551, 344.705551,
                                14.705551])])
    def test_mlt2mlon_array(self, mlt, test_mlon):
        """Test mlt2mlon with array inputs."""
        mlon = self.apex_out.mlt2mlon(mlt, self.in_time)

        assert mlon.shape == np.asarray(test_mlon).shape
        np.testing.assert_allclose(mlon, test_mlon, rtol=1e-4)
        return

    @pytest.mark.parametrize("method_name", ["mlon2mlt", "mlt2mlon"])
    def test_mlon2mlt_diffdates(self, method_name):
        """Test that MLT varies with universal time."""
        apex_method = getattr(self.apex_out, method_name)
        out1 = apex_method(0, self.in_time)
        out2 = apex_method(0, self.in_time + dt.timedelta(hours=1))

        assert out1 != out2
        return

    @pytest.mark.parametrize("mlt_offset", [1.0, 10.0])
    def test_mlon2mlt_offset(self, mlt_offset):
        """Test the time wrapping logic for the MLT."""
        mlt1 = self.apex_out.mlon2mlt(0.0, self.in_time)
        mlt2 = self.apex_out.mlon2mlt(-15.0 * mlt_offset,
                                      self.in_time) + mlt_offset

        np.testing.assert_allclose(mlt1, mlt2)
        return

    @pytest.mark.parametrize("mlon_offset", [15.0, 150.0])
    def test_mlt2mlon_offset(self, mlon_offset):
        """Test the time wrapping logic for the magnetic longitude."""
        mlon1 = self.apex_out.mlt2mlon(0, self.in_time)
        mlon2 = self.apex_out.mlt2mlon(mlon_offset / 15.0,
                                       self.in_time) - mlon_offset

        np.testing.assert_allclose(mlon1, mlon2)
        return

    @pytest.mark.parametrize("order", [["mlt", "mlon"], ["mlon", "mlt"]])
    @pytest.mark.parametrize("start_val", [0, 6, 12, 18, 22])
    def test_convert_and_return(self, order, start_val):
        """Test the conversion to magnetic longitude or MLT and back again."""
        first_method = getattr(self.apex_out, "2".join(order))
        second_method = getattr(self.apex_out, "2".join([order[1], order[0]]))

        middle_val = first_method(start_val, self.in_time)
        end_val = second_method(middle_val, self.in_time)

        np.testing.assert_allclose(start_val, end_val)
        return
731
732
733
class TestApexMapMethods():
    """Test the Apex height mapping methods."""

    def setup_method(self):
        """Initialize all tests."""
        # `setup_method` replaces the nose-style `setup`, which pytest
        # deprecated in 7.2 and no longer calls as of 8.0
        self.apex_out = Apex(date=2000, refh=300)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out

    def _check_ev_array_location(self, method_name, ivec, expected):
        """Run an E-field/velocity mapping method with one vectorized input.

        Parameters
        ----------
        method_name : str
            Name of the Apex mapping method, e.g. 'map_E_to_height' or
            'map_V_to_height'
        ivec : int
            Index of the input argument to duplicate into a two-element list.
            Indices 0-3 are the location inputs; 4 leaves the locations
            scalar, so only the (already 2-column) vector input is an array.
        expected : list-like
            Expected 3-element mapped vector for the base inputs

        """
        # Two identical input vectors stacked as columns, shape (3, 2)
        ev_vec = np.array([[1, 2, 3]] * 2).transpose()
        in_args = [60, 15, 100, 500, ev_vec]
        test_mapped = np.full(shape=(2, 3), fill_value=expected).transpose()

        # Update inputs for one vectorized value if this is a location input
        if ivec < 4:
            in_args[ivec] = [in_args[ivec], in_args[ivec]]

        # Get the mapped output and test the results
        mapped = getattr(self.apex_out, method_name)(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 10000],
                               [31.841466903686523, 17.916635513305664,
                                1.7075473124350538e-6]),
                              ([30, 170, 100, 500, False, 1e-2],
                               [25.727270126342773, 169.60546875,
                                0.00017573432705830783]),
                              ([60, 15, 100, 10000, True],
                               [-25.424888610839844, 27.310426712036133,
                                1.2074182222931995e-6]),
                              ([30, 170, 100, 500, True, 1e-2],
                               [-13.76642894744873, 164.24259948730469,
                                0.00056820799363777041])])
    def test_map_to_height(self, in_args, test_mapped):
        """Test the map_to_height function."""
        mapped = self.apex_out.map_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, atol=1e-6)
        return

    def test_map_to_height_same_height(self):
        """Test the map_to_height function when mapping to same height."""
        mapped = self.apex_out.map_to_height(60, 15, 100, 100, conjugate=False,
                                             precision=1e-10)
        np.testing.assert_allclose(mapped, (60.0, 15.000003814697266, 0.0),
                                   rtol=1e-5)
        return

    @pytest.mark.parametrize('ivec', range(0, 4))
    def test_map_to_height_array_location(self, ivec):
        """Test map_to_height with array input."""
        # Set the base input and output values
        in_args = [60, 15, 100, 100]
        test_mapped = np.full(shape=(2, 3),
                              fill_value=[60, 15.00000381, 0.0]).transpose()

        # Update inputs for one vectorized value
        in_args[ivec] = [in_args[ivec], in_args[ivec]]

        # Calculate and test function
        mapped = self.apex_out.map_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return

    @pytest.mark.parametrize("method_name,in_args",
                             [("map_to_height", [0, 15, 100, 10000]),
                              ("map_E_to_height",
                               [0, 15, 100, 10000, [1, 2, 3]]),
                              ("map_V_to_height",
                               [0, 15, 100, 10000, [1, 2, 3]])])
    def test_mapping_height_raises_ApexHeightError(self, method_name, in_args):
        """Test the mapping methods raise ApexHeightError above apex height."""
        apex_method = getattr(self.apex_out, method_name)

        with pytest.raises(ApexHeightError) as aerr:
            apex_method(*in_args)

        assert aerr.match("is > apex height")
        return

    @pytest.mark.parametrize("method_name",
                             ["map_E_to_height", "map_V_to_height"])
    @pytest.mark.parametrize("ev_input", [([1, 2, 3, 4, 5]),
                                          ([[1, 2], [3, 4], [5, 6], [7, 8]])])
    def test_mapping_EV_bad_shape(self, method_name, ev_input):
        """Test E/V mapping methods raise ValueError for badly shaped input."""
        apex_method = getattr(self.apex_out, method_name)
        in_args = [60, 15, 100, 500, ev_input]
        with pytest.raises(ValueError) as verr:
            apex_method(*in_args)

        assert "must be (3, N) or (3,) ndarray" in str(verr.value)
        return

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 500, [1, 2, 3]],
                               [0.71152183, 2.35624876, 0.57260784]),
                              ([60, 15, 100, 500, [2, 3, 4]],
                               [1.56028502, 3.43916636, 0.78235384]),
                              ([60, 15, 100, 1000, [1, 2, 3]],
                               [0.67796492, 2.08982134, 0.55860785]),
                              ([60, 15, 200, 500, [1, 2, 3]],
                               [0.72377397, 2.42737471, 0.59083726]),
                              ([60, 30, 100, 500, [1, 2, 3]],
                               [0.68626344, 2.37530133, 0.60060124]),
                              ([70, 15, 100, 500, [1, 2, 3]],
                               [0.72760378, 2.18082305, 0.29141979])])
    def test_map_E_to_height_scalar_location(self, in_args, test_mapped):
        """Test mapping of E-field to a specified height."""
        mapped = self.apex_out.map_E_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return

    @pytest.mark.parametrize('ivec', range(0, 5))
    def test_map_E_to_height_array_location(self, ivec):
        """Test mapping of E-field to a specified height with array input."""
        self._check_ev_array_location(
            "map_E_to_height", ivec, [0.71152183, 2.35624876, 0.57260784])
        return

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 500, [1, 2, 3]],
                               [0.81971957, 2.84512495, 0.69545001]),
                              ([60, 15, 100, 500, [2, 3, 4]],
                               [1.83027746, 4.14346436, 0.94764179]),
                              ([60, 15, 100, 1000, [1, 2, 3]],
                               [0.92457698, 3.14997661, 0.85135187]),
                              ([60, 15, 200, 500, [1, 2, 3]],
                               [0.80388262, 2.79321504, 0.68285158]),
                              ([60, 30, 100, 500, [1, 2, 3]],
                               [0.76141245, 2.87884673, 0.73655941]),
                              ([70, 15, 100, 500, [1, 2, 3]],
                               [0.84681866, 2.5925821, 0.34792655])])
    def test_map_V_to_height_scalar_location(self, in_args, test_mapped):
        """Test mapping of velocity to a specified height."""
        mapped = self.apex_out.map_V_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return

    @pytest.mark.parametrize('ivec', range(0, 5))
    def test_map_V_to_height_array_location(self, ivec):
        """Test mapping of velocity to a specified height with array input."""
        self._check_ev_array_location(
            "map_V_to_height", ivec, [0.81971957, 2.84512495, 0.69545001])
        return
891
892
893
class TestApexBasevectorMethods():
    """Test the Apex height base vector methods."""

    def setup_method(self):
        """Initialize all tests."""
        # `setup_method` replaces the nose-style `setup`, which pytest
        # deprecated in 7.2 and no longer calls as of 8.0
        self.apex_out = Apex(date=2000, refh=300)
        self.lat = 60
        self.lon = 15
        self.height = 100
        self.test_basevec = None

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.test_basevec, self.lat, self.lon, self.height

    def get_comparison_results(self, bv_coord, coords, precision):
        """Get the base vector results using the hidden function for comparison.

        Parameters
        ----------
        bv_coord : str
            Basevector coordinate scheme, expects one of 'apex', 'qd',
            or 'bvectors_apex'
        coords : str
            Expects one of 'geo', 'apex', or 'qd'
        precision : float
            Float specifying precision

        Notes
        -----
        The results are stored in `self.test_basevec`; nothing is returned.

        """
        if coords == "geo":
            glat = self.lat
            glon = self.lon
        else:
            apex_method = getattr(self.apex_out, "{:s}2geo".format(coords))
            glat, glon, _ = apex_method(self.lat, self.lon, self.height,
                                        precision=precision)

        if bv_coord == 'qd':
            self.test_basevec = self.apex_out._basevec(glat, glon, self.height)
        elif bv_coord == 'apex':
            # Use the shared test height rather than a duplicated literal
            (_, _, _, _, f1, f2, _, d1, d2, d3, _, e1, e2,
             e3) = self.apex_out._geo2apexall(glat, glon, self.height)
            self.test_basevec = (f1, f2, d1, d2, d3, e1, e2, e3)
        else:
            # These are set results that need to be updated with IGRF
            if coords == "geo":
                self.test_basevec = (
                    np.array([4.42368795e-05, 4.42368795e-05]),
                    np.array([[0.01047826, 0.01047826],
                              [0.33089194, 0.33089194],
                              [-1.04941, -1.04941]]),
                    np.array([5.3564698e-05, 5.3564698e-05]),
                    np.array([[0.00865356, 0.00865356],
                              [0.27327004, 0.27327004],
                              [-0.8666646, -0.8666646]]))
            elif coords == "apex":
                self.test_basevec = (
                    np.array([4.48672735e-05, 4.48672735e-05]),
                    np.array([[-0.12510721, -0.12510721],
                              [0.28945938, 0.28945938],
                              [-1.1505738, -1.1505738]]),
                    np.array([6.38577444e-05, 6.38577444e-05]),
                    np.array([[-0.08790194, -0.08790194],
                              [0.2033779, 0.2033779],
                              [-0.808408, -0.808408]]))
            else:
                self.test_basevec = (
                    np.array([4.46348578e-05, 4.46348578e-05]),
                    np.array([[-0.12642345, -0.12642345],
                              [0.29695055, 0.29695055],
                              [-1.1517885, -1.1517885]]),
                    np.array([6.38626285e-05, 6.38626285e-05]),
                    np.array([[-0.08835986, -0.08835986],
                              [0.20754464, 0.20754464],
                              [-0.8050078, -0.8050078]]))

        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    @pytest.mark.parametrize("coords,precision",
                             [("geo", 1e-10), ("apex", 1.0e-2), ("qd", 1.0e-2)])
    def test_basevectors_scalar(self, bv_coord, coords, precision):
        """Test the base vector calculations with scalars."""
        # Get the base vectors
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(self.lat, self.lon, self.height, coords=coords,
                              precision=precision)
        self.get_comparison_results(bv_coord, coords, precision)
        if bv_coord == "apex":
            # Not able to compare indices 2, 3, 4, and 5 (f3, g1, g2, g3)
            basevec = list(basevec)
            del basevec[2:6]

        # Test the results
        for i, vec in enumerate(basevec):
            np.testing.assert_allclose(vec, self.test_basevec[i])
        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    def test_basevectors_scalar_shape(self, bv_coord):
        """Test the shape of the scalar output."""
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(self.lat, self.lon, self.height)

        for i, vec in enumerate(basevec):
            if i < 2:
                assert vec.shape == (2,)
            else:
                assert vec.shape == (3,)
        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    @pytest.mark.parametrize("ivec", range(3))
    def test_basevectors_array(self, bv_coord, ivec):
        """Test the output shape and values for array inputs."""
        # Define the input arguments, vectorizing one of them to 4 points
        in_args = [self.lat, self.lon, self.height]
        in_args[ivec] = [in_args[ivec]] * 4

        # Get the basevectors
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(*in_args, coords='geo', precision=1e-10)
        self.get_comparison_results(bv_coord, "geo", 1e-10)
        if bv_coord == "apex":
            # Not able to compare indices 2, 3, 4, and 5 (f3, g1, g2, g3)
            basevec = list(basevec)
            del basevec[2:6]

        # Evaluate the shape and the values
        for i, vec in enumerate(basevec):
            idim = 2 if i < 2 else 3
            assert vec.shape == (idim, 4)

            # Every column repeats the same location, so each must match the
            # scalar comparison vector in all of its components
            expected = np.tile(np.reshape(self.test_basevec[i], (idim, 1)),
                               (1, 4))
            np.testing.assert_allclose(vec, expected)
        return

    @pytest.mark.parametrize("coords", ["geo", "apex", "qd"])
    def test_bvectors_apex(self, coords):
        """Test the bvectors_apex method."""
        in_args = [[self.lat, self.lat], [self.lon, self.lon],
                   [self.height, self.height]]
        self.get_comparison_results("bvectors_apex", coords, 1e-10)

        basevec = self.apex_out.bvectors_apex(*in_args, coords=coords,
                                              precision=1e-10)
        for i, vec in enumerate(basevec):
            np.testing.assert_array_almost_equal(vec, self.test_basevec[i],
                                                 decimal=5)
        return

    def test_basevectors_apex_extra_values(self):
        """Test specific values in the apex base vector output."""
        # Set the testing arrays
        self.test_basevec = [np.array([0.092637, -0.245951, 0.938848]),
                             np.array([0.939012, 0.073416, -0.07342]),
                             np.array([0.055389, 1.004155, 0.257594]),
                             np.array([0, 0, 1.065135])]

        # Get the desired output
        basevec = self.apex_out.basevectors_apex(0, 15, 100, coords='geo')

        # Test the values not covered by `test_basevectors_scalar`
        for itest, ibase in enumerate(range(2, 6)):
            np.testing.assert_allclose(basevec[ibase],
                                       self.test_basevec[itest], rtol=1e-4)
        return

    @pytest.mark.parametrize("lat", range(0, 90, 10))
    @pytest.mark.parametrize("lon", range(0, 360, 15))
    def test_basevectors_apex_delta(self, lat, lon):
        """Test that vectors are calculated correctly.

        Checks that f_i . g_j and d_i . e_j equal the Kronecker delta.
        """
        # Get the apex base vectors and sort them for easy testing
        (f1, f2, f3, g1, g2, g3, d1, d2, d3, e1, e2,
         e3) = self.apex_out.basevectors_apex(lat, lon, 500)
        # f1 and f2 are 2-D, so pad them to 3-D for the dot products
        fvec = [np.append(f1, 0), np.append(f2, 0), f3]
        gvec = [g1, g2, g3]
        dvec = [d1, d2, d3]
        evec = [e1, e2, e3]

        for idelta in range(3):
            for jdelta in range(3):
                delta = 1 if idelta == jdelta else 0
                np.testing.assert_allclose(np.sum(fvec[idelta] * gvec[jdelta]),
                                           delta, rtol=0, atol=1e-5)
                np.testing.assert_allclose(np.sum(dvec[idelta] * evec[jdelta]),
                                           delta, rtol=0, atol=1e-5)
        return

    def test_basevectors_apex_invalid_scalar(self):
        """Test warning and fill values for base vectors with bad inputs."""
        self.apex_out = Apex(date=2000, refh=10000)
        invalid = np.full(shape=(3,), fill_value=np.nan)

        # Get the output; `pytest.warns` records the warning regardless of the
        # active warning filters, unlike a bare `catch_warnings(record=True)`
        with pytest.warns(UserWarning, match='set to NaN where'):
            basevec = self.apex_out.basevectors_apex(0, 0, 0)

        for i, bvec in enumerate(basevec):
            if i < 2:
                assert not np.allclose(bvec, invalid[:2])
            else:
                np.testing.assert_allclose(bvec, invalid)

        return
1102
1103
1104
class TestApexGetMethods():
1105
    """Test the Apex `get` methods."""
1106
    def setup_method(self):
        """Initialize all tests."""
        # `setup_method` replaces the nose-style `setup`, which pytest
        # deprecated in 7.2 and no longer calls as of 8.0
        self.apex_out = Apex(date=2000, refh=300)
1109
1110
    def teardown_method(self):
        """Clean up after each test."""
        # `teardown_method` replaces the nose-style `teardown`, which pytest
        # deprecated in 7.2 and no longer calls as of 8.0
        del self.apex_out
1113
1114
    @pytest.mark.parametrize("alat, aheight", [(10, 507.409702543805),
                                               (60, 20313.026999999987)])
    def test_get_apex(self, alat, aheight):
        """Test that `get_apex` returns the expected apex height."""
        np.testing.assert_allclose(self.apex_out.get_apex(alat), aheight)
        return
1121
1122
    @pytest.mark.parametrize("glat,glon,height,test_bmag",
                             [([80], [100], [300], 5.100682377815247e-05),
                              (range(50, 90, 8), range(0, 360, 80), [300] * 5,
                               np.array([4.18657154e-05, 5.11118114e-05,
                                         4.91969854e-05, 5.10519207e-05,
                                         4.90054816e-05])),
                              (90.0, 0, 1000, 3.7834718823432923e-05)])
    def test_get_babs(self, glat, glon, height, test_bmag):
        """Test that `get_babs` returns the expected field magnitude."""
        np.testing.assert_allclose(
            self.apex_out.get_babs(glat, glon, height), test_bmag,
            rtol=0, atol=1e-5)
        return
1134
1135
    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_get_with_invalid_lat(self, bad_lat):
        """Test that `get_apex` rejects latitudes outside [-90, 90]."""
        with pytest.raises(ValueError):
            self.apex_out.get_apex(bad_lat)
        return
1142
1143
    @pytest.mark.parametrize("bound_lat", [(90), (-90)])
    def test_get_at_lat_boundary(self, bound_lat):
        """Test `get_apex` tolerates a tiny overshoot past the pole.

        A latitude 1e-5 degrees beyond the boundary must give the same apex
        height as the boundary itself.
        """
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        expected = self.apex_out.get_apex(bound_lat)
        actual = self.apex_out.get_apex(excess_lat)

        np.testing.assert_allclose(actual, expected, rtol=0, atol=1e-8)
        return
1156