Passed
Pull Request — develop (#76)
by Angeline
02:48 queued 01:30
created

test_Apex.TestApexInit.test_repr_eval()   A

Complexity

Conditions 1

Size

Total Lines 15
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 9
dl 0
loc 15
rs 9.95
c 0
b 0
f 0
cc 1
nop 1
1
# -*- coding: utf-8 -*-
2
"""Test the apexpy.Apex class
3
4
Notes
5
-----
6
Whenever function outputs are tested against hard-coded numbers, the test
7
results (numbers) were obtained by running the code that is tested.  Therefore,
8
these tests below only check that nothing changes when refactoring, etc., and
9
not if the results are actually correct.
10
11
These results are expected to change when IGRF is updated.
12
13
"""
14
15
import datetime as dt
16
import numpy as np
17
import os
18
import pytest
19
import warnings
20
21
import apexpy
22
23
24
@pytest.fixture()
def igrf_file():
    """Temporarily hide the IGRF coefficient file shipped with apexpy.

    Yields
    ------
    original_file : str
        Full path of the coefficient file, which is missing while the test
        using this fixture runs.

    Notes
    -----
    The file is renamed inside its own directory instead of being moved to
    the current working directory. This keeps ``os.rename`` on a single
    filesystem (a cross-device rename raises OSError) and avoids leaving a
    stray file in the caller's working directory. The file is restored in a
    ``finally`` clause so it is put back even if teardown is reached through
    an error.

    """
    # Ensure the coefficient file exists
    coeff_dir = os.path.dirname(apexpy.helpers.__file__)
    original_file = os.path.join(coeff_dir, 'igrf13coeffs.txt')
    tmp_file = os.path.join(coeff_dir, 'temp_coeff.txt')
    assert os.path.isfile(original_file)

    # Move the coefficient file out of the way
    os.rename(original_file, tmp_file)
    try:
        yield original_file
    finally:
        # Move the coefficient file back
        os.rename(tmp_file, original_file)
    return
40
41
42
def test_set_epoch_file_error(igrf_file):
    """Check that Apex raises OSError if the IGRF coefficient file is absent.

    Parameters
    ----------
    igrf_file : str
        Path of the coefficient file, hidden for this test by the fixture.

    """
    expected_start = "File {} does not exist".format(igrf_file)

    # Creating an Apex object must fail while the coefficient file is gone
    with pytest.raises(OSError) as err_info:
        apexpy.Apex()

    assert str(err_info.value).startswith(expected_start)
    return
50
51
52
class TestApexInit():
    """Test the Apex class initialization and string representations."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = None
        self.test_date = dt.datetime.utcnow()
        self.test_refh = 0
        self.bad_file = 'foo/path/to/datafile.blah'

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.test_date, self.test_refh, self.bad_file

    def eval_date(self):
        """Evaluate the times in self.test_date and self.apex_out."""
        # dt.datetime is a subclass of dt.date, so one check covers both
        if isinstance(self.test_date, dt.date):
            self.test_date = apexpy.helpers.toYearFraction(self.test_date)

        # Assert the times are the same on the order of tens of seconds.
        # Necessary to evaluate the current UTC
        np.testing.assert_almost_equal(self.test_date, self.apex_out.year, 6)
        return

    def eval_refh(self):
        """Compare self.test_refh to the reference height in self.apex_out."""
        eval_str = "".join(["expected reference height [",
                            "{:}] not equal to Apex ".format(self.test_refh),
                            "reference height ",
                            "[{:}]".format(self.apex_out.refh)])
        assert self.test_refh == self.apex_out.refh, eval_str
        return

    def test_init_defaults(self):
        """Test Apex class default initialization."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("in_date",
                             [2015, 2015.5, dt.date(2015, 1, 1),
                              dt.datetime(2015, 6, 1, 18, 23, 45)])
    def test_init_date(self, in_date):
        """Test Apex class with date initialization."""
        self.test_date = in_date
        self.apex_out = apexpy.Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("new_date", [2015, 2015.5])
    def test_set_epoch(self, new_date):
        """Test successful setting of Apex epoch after initialization."""
        # Evaluate the default initialization
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Update the epoch, keeping a copy of the original object
        ref_apex = eval(self.apex_out.__repr__())
        self.apex_out.set_epoch(new_date)
        assert ref_apex != self.apex_out
        self.test_date = new_date
        self.eval_date()
        return

    @pytest.mark.parametrize("in_refh", [0.0, 300.0, 30000.0, -1.0])
    def test_init_refh(self, in_refh):
        """Test Apex class with reference height initialization."""
        self.test_refh = in_refh
        self.apex_out = apexpy.Apex(refh=self.test_refh)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("new_refh", [0.0, 300.0, 30000.0, -1.0])
    def test_set_refh(self, new_refh):
        """Test the method used to set the reference height after the init."""
        # Verify the defaults are set
        self.apex_out = apexpy.Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()

        # Update to a new reference height and test
        ref_apex = eval(self.apex_out.__repr__())
        self.apex_out.set_refh(new_refh)

        if self.test_refh == new_refh:
            assert ref_apex == self.apex_out
        else:
            assert ref_apex != self.apex_out
            self.test_refh = new_refh
        self.eval_refh()
        return

    def test_init_with_bad_datafile(self):
        """Test raises IOError with non-existent datafile input."""
        with pytest.raises(IOError) as oerr:
            apexpy.Apex(datafile=self.bad_file)
        assert str(oerr.value).startswith('Data file does not exist')
        return

    def test_init_with_bad_fortranlib(self):
        """Test raises IOError with non-existent fortranlib input."""
        with pytest.raises(IOError) as oerr:
            apexpy.Apex(fortranlib=self.bad_file)
        assert str(oerr.value).startswith('Fortran library does not exist')
        return

    def test_repr_eval(self):
        """Test the Apex.__repr__ results."""
        # Initialize the apex object
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Get and test the repr string
        out_str = self.apex_out.__repr__()
        assert out_str.find("apexpy.Apex(") == 0

        # Test the ability to re-create the apex object from the repr string
        new_apex = eval(out_str)
        assert new_apex == self.apex_out
        return

    def test_str_eval(self):
        """Test the Apex.__str__ results."""
        # Initialize the apex object
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Get and test the printed string
        out_str = self.apex_out.__str__()
        assert out_str.find("Decimal year") > 0
        return
186
187
188
class TestApexMethod():
    """Test the Apex methods."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.in_lat = 60
        self.in_lon = 15
        self.in_alt = 100

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.in_lat, self.in_lon, self.in_alt

    def get_input_args(self, method_name, precision=0.0):
        """Set the input arguments for the different Apex methods.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method or fortranapex function
        precision : float
            Value for the precision (default=0.0)

        Returns
        -------
        in_args : list
            List of the appropriate input arguments

        """
        in_args = [self.in_lat, self.in_lon, self.in_alt]

        # Add precision, if needed
        if method_name in ["_qd2geo", "apxq2g", "apex2geo", "qd2geo",
                           "_apex2geo"]:
            in_args.append(precision)

        # Add a reference height, if needed (same refh as in setup_method)
        if method_name in ["apxg2all"]:
            in_args.append(300)

        # Add a vector flag, if needed
        if method_name in ["apxg2all", "apxg2q"]:
            in_args.append(1)

        return in_args

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                              ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                              ("_qd2geo", "apxq2g", slice(None)),
                              ("_basevec", "apxg2q", slice(2, 4, 1))])
    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon", [-179, -90, 0, 90, 180])
    def test_fortran_scalar_input(self, apex_method, fortran_method, fslice,
                                  lat, lon):
        """Tests Apex/fortran interface consistency for scalars."""
        # Set the input coordinates
        self.in_lat = lat
        self.in_lon = lon

        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Get the appropriate input arguments
        apex_args = self.get_input_args(apex_method)
        fortran_args = self.get_input_args(fortran_method)

        # Evaluate the equivalent function calls
        np.testing.assert_allclose(apex_func(*apex_args),
                                   fortran_func(*fortran_args)[fslice])
        return

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                              ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                              ("_qd2geo", "apxq2g", slice(None)),
                              ("_basevec", "apxg2q", slice(2, 4, 1))])
    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon1,lon2", [(180, 180), (-180, -180),
                                           (180, -180), (-180, 180),
                                           (-345, 15), (375, 15)])
    def test_fortran_longitude_rollover(self, apex_method, fortran_method,
                                        fslice, lat, lon1, lon2):
        """Tests Apex/fortran interface consistency for longitude rollover."""
        # Set the fixed input coordinate
        self.in_lat = lat

        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Get the appropriate input arguments; the Apex call uses lon1 and
        # the fortran call uses the equivalent longitude lon2
        self.in_lon = lon1
        apex_args = self.get_input_args(apex_method)

        self.in_lon = lon2
        fortran_args = self.get_input_args(fortran_method)

        # Evaluate the equivalent function calls
        np.testing.assert_allclose(apex_func(*apex_args),
                                   fortran_func(*fortran_args)[fslice])
        return

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                              ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                              ("_qd2geo", "apxq2g", slice(None)),
                              ("_basevec", "apxg2q", slice(2, 4, 1))])
    def test_fortran_array_input(self, apex_method, fortran_method, fslice):
        """Tests Apex/fortran interface consistency for array input."""
        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Set up the input arrays
        ref_lat = np.array([0, 30, 60, 90])
        ref_alt = np.array([100, 200, 300, 400])
        self.in_lat = ref_lat.reshape((2, 2))
        self.in_alt = ref_alt.reshape((2, 2))
        apex_args = self.get_input_args(apex_method)

        # Get the Apex class results
        aret = apex_func(*apex_args)

        # Get the fortran function results, one lat/alt pair at a time
        flats = list()
        flons = list()

        for lat, alt in zip(ref_lat, ref_alt):
            self.in_lat = lat
            self.in_alt = alt
            fortran_args = self.get_input_args(fortran_method)
            fret = fortran_func(*fortran_args)[fslice]
            flats.append(fret[0])
            flons.append(fret[1])

        flats = np.array(flats)
        flons = np.array(flons)

        # Evaluate results
        try:
            # This returned value is array of floats
            np.testing.assert_allclose(aret[0].astype(float),
                                       flats.reshape((2, 2)).astype(float))
            np.testing.assert_allclose(aret[1].astype(float),
                                       flons.reshape((2, 2)).astype(float))
        except ValueError:
            # This returned value is array of arrays
            alats = aret[0].reshape((4,))
            alons = aret[1].reshape((4,))
            for alat, alon, flat, flon in zip(alats, alons, flats, flons):
                np.testing.assert_array_almost_equal(alat, flat, 2)
                np.testing.assert_array_almost_equal(alon, flon, 2)

        return

    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon", [-179, -90, 0, 90, 180])
    def test_geo2apexall_scalar(self, lat, lon):
        """Test Apex/fortran geo2apexall interface consistency for scalars."""
        # Get the Apex and Fortran results
        aret = self.apex_out._geo2apexall(lat, lon, self.in_alt)
        fret = apexpy.fortranapex.apxg2all(lat, lon, self.in_alt, 300, 1)

        # Evaluate each element in the results
        for aval, fval in zip(aret, fret):
            np.testing.assert_allclose(aval, fval)
        return

    def test_geo2apexall_array(self):
        """Test Apex/fortran geo2apexall interface consistency for arrays."""
        # Set the input
        self.in_lat = np.array([0, 30, 60, 90])
        self.in_alt = np.array([100, 200, 300, 400])

        # Get the Apex class results
        aret = self.apex_out._geo2apexall(self.in_lat.reshape((2, 2)),
                                          self.in_lon,
                                          self.in_alt.reshape((2, 2)))

        # For each lat/alt pair, get the Fortran results
        fret = list()
        for lat, alt in zip(self.in_lat, self.in_alt):
            fret.append(apexpy.fortranapex.apxg2all(lat, self.in_lon, alt,
                                                    300, 1))

        # Cycle through all returned values
        for i, ret in enumerate(aret):
            try:
                # This returned value is array of floats
                np.testing.assert_allclose(ret.astype(float),
                                           np.array([[fret[0][i], fret[1][i]],
                                                     [fret[2][i], fret[3][i]]],
                                                    dtype=float))
            except ValueError:
                # This returned value is array of arrays
                ret = ret.reshape((4,))
                for j, single_fret in enumerate(fret):
                    np.testing.assert_allclose(ret[j], single_fret[i])
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_consistency(self, in_coord, out_coord):
        """Test the self-consistency of the Apex convert method."""
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Define the method name
        method_name = "2".join([in_coord, out_coord])

        # Get the method and method inputs
        convert_kwargs = {'height': self.in_alt, 'precision': 0.0}
        apex_args = self.get_input_args(method_name)
        apex_method = getattr(self.apex_out, method_name)

        # Define the slice needed to get equivalent output from the named
        # method; methods returning geo coordinates have an extra last value
        mslice = slice(0, -1, 1) if out_coord == "geo" else slice(None)

        # Get output using convert and named method
        convert_out = self.apex_out.convert(self.in_lat, self.in_lon, in_coord,
                                            out_coord, **convert_kwargs)
        method_out = apex_method(*apex_args)[mslice]

        # Compare both outputs, should be identical
        np.testing.assert_allclose(convert_out, method_out)
        return

    @pytest.mark.parametrize("bound_lat", [90, -90])
    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_at_lat_boundary(self, bound_lat, in_coord, out_coord):
        """Test the conversion at the latitude boundary, with allowed excess."""
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = self.apex_out.convert(bound_lat, 0, in_coord, out_coord)
        excess_out = self.apex_out.convert(excess_lat, 0, in_coord, out_coord)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    def test_convert_qd2apex_at_equator(self):
        """Test the quasi-dipole to apex conversion at the magnetic equator."""
        eq_out = self.apex_out.convert(lat=0.0, lon=0, source='qd', dest='apex',
                                       height=320.0)
        close_out = self.apex_out.convert(lat=0.001, lon=0, source='qd',
                                          dest='apex', height=320.0)
        np.testing.assert_allclose(eq_out, close_out, atol=1e-4)
        return

    @pytest.mark.parametrize("src", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("dest", ["geo", "apex", "qd"])
    def test_convert_withnan(self, src, dest):
        """Test Apex.convert success with NaN input."""
        if src == dest:
            pytest.skip("Test not needed for same src and dest coordinates")

        num_nans = 5
        in_loc = np.arange(0, 10, dtype=float)
        in_loc[:num_nans] = np.nan

        out_loc = self.apex_out.convert(in_loc, in_loc, src, dest, height=320)

        for out in out_loc:
            assert np.all(np.isnan(out[:num_nans])), "NaN output expected"
            assert np.all(np.isfinite(out[num_nans:])), "Finite output expected"

        return

    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_convert_invalid_lat(self, bad_lat):
        """Test convert raises ValueError for invalid latitudes."""

        with pytest.raises(ValueError) as verr:
            self.apex_out.convert(bad_lat, 0, 'geo', 'geo')

        assert "must be in [-90, 90]" in str(verr.value)
        return

    @pytest.mark.parametrize("coords", [("foobar", "geo"), ("geo", "foobar"),
                                        ("geo", "mlt")])
    def test_convert_invalid_transformation(self, coords):
        """Test raises ValueError for bad coordinates."""
        if "mlt" in coords:
            estr = "datetime must be given for MLT calculations"
        else:
            estr = "Unknown coordinate transformation"

        with pytest.raises(ValueError) as verr:
            self.apex_out.convert(0, 0, *coords)

        # Inspect the raised exception itself, not the ExceptionInfo wrapper
        assert estr in str(verr.value)
        return

    @pytest.mark.parametrize("method_name, out_comp",
                             [("geo2apex",
                               (55.94841766357422, 94.10684204101562)),
                              ("apex2geo",
                               (51.476322174072266, -66.22817993164062,
                                5.727287771151168e-06)),
                              ("geo2qd",
                               (56.531288146972656, 94.10684204101562)),
                              ("apex2qd", (60.498401178276744, 15.0)),
                              ("qd2apex", (59.49138097045895, 15.0))])
    def test_method_scalar_input(self, method_name, out_comp):
        """Test the user method against set values with scalars."""
        # Get the desired methods
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(self.in_lat, self.in_lon, self.in_alt)

        # Evaluate the user output
        np.testing.assert_allclose(user_out, out_comp)

        for out_val in user_out:
            assert np.asarray(out_val).shape == (), "output is not a scalar"
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("method_args, out_shape",
                             [([[60, 60], 15, 100], (2,)),
                              ([60, [15, 15], 100], (2,)),
                              ([60, 15, [100, 100]], (2,)),
                              ([[50, 60], [15, 16], [100, 200]], (2,))])
    def test_method_broadcast_input(self, in_coord, out_coord, method_args,
                                    out_shape):
        """Test the user method with inputs that require some broadcasting."""
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(*method_args)

        # Evaluate the user output
        for out_val in user_out:
            assert hasattr(out_val, 'shape'), "output coordinate isn't np.array"
            assert out_val.shape == out_shape
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_method_invalid_lat(self, in_coord, out_coord, bad_lat):
        """Test user methods raise ValueError for invalid latitudes."""
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        with pytest.raises(ValueError) as verr:
            user_method(bad_lat, 15, 100)

        assert "must be in [-90, 90]" in str(verr.value)
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("bound_lat", [90, -90])
    def test_method_at_lat_boundary(self, in_coord, out_coord, bound_lat):
        """Test user methods at the latitude boundary, with allowed excess."""
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        # Get a latitude just beyond the limit
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = user_method(bound_lat, 0, 100)
        excess_out = user_method(excess_lat, 0, 100)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    def test_geo2apex_undefined_warning(self):
        """Test geo2apex warning and fill values for an undefined location."""

        # Update the apex object
        self.apex_out = apexpy.Apex(date=2000, refh=10000)

        # Get the output and the warnings
        with warnings.catch_warnings(record=True) as warn_rec:
            # Record every UserWarning, so an identical warning emitted
            # earlier in the session cannot be suppressed by the registry
            warnings.simplefilter("always", UserWarning)
            user_lat, user_lon = self.apex_out.geo2apex(0, 0, 0)

        assert np.isnan(user_lat)
        assert np.isfinite(user_lon)
        assert len(warn_rec) == 1
        assert issubclass(warn_rec[-1].category, UserWarning)
        assert 'latitude set to NaN where' in str(warn_rec[-1].message)
        return

    @pytest.mark.parametrize("method_name", ["apex2qd", "qd2apex"])
    @pytest.mark.parametrize("delta_h", [1.0e-6, -1.0e-6])
    def test_quasidipole_apexheight_close(self, method_name, delta_h):
        """Test quasi-dipole success with a height close to the reference."""
        qd_method = getattr(self.apex_out, method_name)
        in_args = [0, 15, self.apex_out.refh + delta_h]
        out_coords = qd_method(*in_args)

        for out_val, in_val in zip(out_coords, in_args):
            np.testing.assert_almost_equal(out_val, in_val, decimal=3)
        return

    @pytest.mark.parametrize("method_name, hinc, msg",
                             [("apex2qd", 1.0, "is > apex height"),
                              ("qd2apex", -1.0, "is < reference height")])
    def test_quasidipole_raises_apexheight(self, method_name, hinc, msg):
        """Quasi-dipole raises ApexHeightError when height above reference."""
        qd_method = getattr(self.apex_out, method_name)

        with pytest.raises(apexpy.ApexHeightError) as aerr:
            qd_method(0, 15, self.apex_out.refh + hinc)

        # Inspect the raised exception itself, not the ExceptionInfo wrapper
        assert msg in str(aerr.value)
        return
617
618
619
class TestApexMLTMethods():
    """Test the Apex Magnetic Local Time (MLT) methods."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.in_time = dt.datetime(2000, 2, 3, 4, 5, 6)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.in_time

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    def test_convert_to_mlt(self, in_coord):
        """Test the conversions to MLT using Apex convert."""

        # Get the magnetic longitude: geographic input must be converted,
        # while for apex and qd input the test uses the longitude (15) as-is
        if in_coord == "geo":
            mlon = self.apex_out.geo2apex(60, 15, 100)[1]
        else:
            mlon = 15

        # Get the output MLT values
        convert_mlt = self.apex_out.convert(60, 15, in_coord, 'mlt',
                                            height=100, ssheight=2e5,
                                            datetime=self.in_time)[1]
        method_mlt = self.apex_out.mlon2mlt(mlon, self.in_time, ssheight=2e5)

        # Test the outputs
        np.testing.assert_allclose(convert_mlt, method_mlt)
        return

    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_mlt_to_lon(self, out_coord):
        """Test the conversions from MLT using Apex convert."""
        # Get the output longitudes
        convert_out = self.apex_out.convert(60, 15, 'mlt', out_coord,
                                            height=100, ssheight=2e5,
                                            datetime=self.in_time,
                                            precision=1e-2)
        mlon = self.apex_out.mlt2mlon(15, self.in_time, ssheight=2e5)

        if out_coord == "geo":
            method_out = self.apex_out.apex2geo(60, mlon, 100,
                                                precision=1e-2)[:-1]
        elif out_coord == "qd":
            method_out = self.apex_out.apex2qd(60, mlon, 100)
        else:
            method_out = (60, mlon)

        # Evaluate the outputs
        np.testing.assert_allclose(convert_out, method_out)
        return

    def test_convert_geo2mlt_nodate(self):
        """Test convert from geo to MLT raises ValueError with no datetime."""
        with pytest.raises(ValueError):
            self.apex_out.convert(60, 15, 'geo', 'mlt')
        return

    @pytest.mark.parametrize("mlon_kwargs,test_mlt",
                             [({}, 23.019629923502603),
                              ({"ssheight": 100000}, 23.026712036132814)])
    def test_mlon2mlt_scalar_inputs(self, mlon_kwargs, test_mlt):
        """Test mlon2mlt with scalar inputs."""
        mlt = self.apex_out.mlon2mlt(0, self.in_time, **mlon_kwargs)

        np.testing.assert_allclose(mlt, test_mlt)
        assert np.asarray(mlt).shape == ()
        return

    @pytest.mark.parametrize("mlt_kwargs,test_mlon",
                             [({}, 14.705535888671875),
                              ({"ssheight": 100000}, 14.599319458007812)])
    def test_mlt2mlon_scalar_inputs(self, mlt_kwargs, test_mlon):
        """Test mlt2mlon with scalar inputs."""
        mlon = self.apex_out.mlt2mlon(0, self.in_time, **mlt_kwargs)

        np.testing.assert_allclose(mlon, test_mlon)
        assert np.asarray(mlon).shape == ()
        return

    @pytest.mark.parametrize("mlon,test_mlt",
                             [([0, 180], [23.019261, 11.019261]),
                              (np.array([0, 180]), [23.019261, 11.019261]),
                              ([[0, 180], [0, 180]], [[23.019261, 11.019261],
                                                      [23.019261, 11.019261]]),
                              (range(0, 361, 30),
                               [23.01963, 1.01963, 3.01963, 5.01963, 7.01963,
                                9.01963, 11.01963, 13.01963, 15.01963, 17.01963,
                                19.01963, 21.01963, 23.01963])])
    def test_mlon2mlt_array(self, mlon, test_mlt):
        """Test mlon2mlt with array inputs."""
        mlt = self.apex_out.mlon2mlt(mlon, self.in_time)

        assert mlt.shape == np.asarray(test_mlt).shape
        np.testing.assert_allclose(mlt, test_mlt, rtol=1e-4)
        return

    @pytest.mark.parametrize("mlt,test_mlon",
                             [([0, 12], [14.705551, 194.705551]),
                              (np.array([0, 12]), [14.705551, 194.705551]),
                              ([[0, 12], [0, 12]], [[14.705551, 194.705551],
                                                    [14.705551, 194.705551]]),
                              (range(0, 25, 2),
                               [14.705551, 44.705551, 74.705551, 104.705551,
                                134.705551, 164.705551, 194.705551, 224.705551,
                                254.705551, 284.705551, 314.705551, 344.705551,
                                14.705551])])
    def test_mlt2mlon_array(self, mlt, test_mlon):
        """Test mlt2mlon with array inputs."""
        mlon = self.apex_out.mlt2mlon(mlt, self.in_time)

        assert mlon.shape == np.asarray(test_mlon).shape
        np.testing.assert_allclose(mlon, test_mlon, rtol=1e-4)
        return

    @pytest.mark.parametrize("method_name", ["mlon2mlt", "mlt2mlon"])
    def test_mlon2mlt_diffdates(self, method_name):
        """Test that MLT/magnetic longitude conversions vary with UT."""
        apex_method = getattr(self.apex_out, method_name)
        out_first = apex_method(0, self.in_time)
        out_later = apex_method(0, self.in_time + dt.timedelta(hours=1))

        assert out_first != out_later
        return

    @pytest.mark.parametrize("mlt_offset", [1.0, 10.0])
    def test_mlon2mlt_offset(self, mlt_offset):
        """Test the time wrapping logic for the MLT."""
        mlt1 = self.apex_out.mlon2mlt(0.0, self.in_time)
        mlt2 = self.apex_out.mlon2mlt(-15.0 * mlt_offset,
                                      self.in_time) + mlt_offset

        np.testing.assert_allclose(mlt1, mlt2)
        return

    @pytest.mark.parametrize("mlon_offset", [15.0, 150.0])
    def test_mlt2mlon_offset(self, mlon_offset):
        """Test the time wrapping logic for the magnetic longitude."""
        mlon1 = self.apex_out.mlt2mlon(0, self.in_time)
        mlon2 = self.apex_out.mlt2mlon(mlon_offset / 15.0,
                                       self.in_time) - mlon_offset

        np.testing.assert_allclose(mlon1, mlon2)
        return

    @pytest.mark.parametrize("order", [["mlt", "mlon"], ["mlon", "mlt"]])
    @pytest.mark.parametrize("start_val", [0, 6, 12, 18, 22])
    def test_convert_and_return(self, order, start_val):
        """Test the conversion to magnetic longitude or MLT and back again."""
        first_method = getattr(self.apex_out, "2".join(order))
        second_method = getattr(self.apex_out, "2".join([order[1], order[0]]))

        middle_val = first_method(start_val, self.in_time)
        end_val = second_method(middle_val, self.in_time)

        np.testing.assert_allclose(start_val, end_val)
        return
778
779
780
class TestApexMapMethods():
    """Test the Apex height mapping methods."""

    def setup_method(self):
        """Initialize all tests."""
        # pytest-native per-test hook; the nose-style `setup` name is
        # deprecated in pytest 7.2 and no longer called from pytest 8.0
        self.apex_out = apexpy.Apex(date=2000, refh=300)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 10000],
                               [31.841466903686523, 17.916635513305664,
                                1.7075473124350538e-6]),
                              ([30, 170, 100, 500, False, 1e-2],
                               [25.727270126342773, 169.60546875,
                                0.00017573432705830783]),
                              ([60, 15, 100, 10000, True],
                               [-25.424888610839844, 27.310426712036133,
                                1.2074182222931995e-6]),
                              ([30, 170, 100, 500, True, 1e-2],
                               [-13.76642894744873, 164.24259948730469,
                                0.00056820799363777041])])
    def test_map_to_height(self, in_args, test_mapped):
        """Test the map_to_height function."""
        mapped = self.apex_out.map_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, atol=1e-6)
        return

    def test_map_to_height_same_height(self):
        """Test the map_to_height function when mapping to same height."""
        mapped = self.apex_out.map_to_height(60, 15, 100, 100, conjugate=False,
                                             precision=1e-10)
        np.testing.assert_allclose(mapped, (60.0, 15.000003814697266, 0.0),
                                   rtol=1e-5)
        return

    @pytest.mark.parametrize('ivec', range(0, 4))
    def test_map_to_height_array_location(self, ivec):
        """Test map_to_height with array input."""
        # Set the base input and output values
        in_args = [60, 15, 100, 100]
        test_mapped = np.full(shape=(2, 3),
                              fill_value=[60, 15.00000381, 0.0]).transpose()

        # Update inputs for one vectorized value
        in_args[ivec] = [in_args[ivec], in_args[ivec]]

        # Calculate and test function
        mapped = self.apex_out.map_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return

    @pytest.mark.parametrize("method_name,in_args",
                             [("map_to_height", [0, 15, 100, 10000]),
                              ("map_E_to_height",
                               [0, 15, 100, 10000, [1, 2, 3]]),
                              ("map_V_to_height",
                               [0, 15, 100, 10000, [1, 2, 3]])])
    def test_mapping_height_raises_ApexHeightError(self, method_name, in_args):
        """Test height mapping methods raise ApexHeightError.

        Each method is asked to map to a height of 10000, which these inputs
        reject as being above the apex height.
        """
        apex_method = getattr(self.apex_out, method_name)

        with pytest.raises(apexpy.ApexHeightError) as aerr:
            apex_method(*in_args)

        assert aerr.match("is > apex height")
        return

    @pytest.mark.parametrize("method_name",
                             ["map_E_to_height", "map_V_to_height"])
    @pytest.mark.parametrize("ev_input", [[1, 2, 3, 4, 5],
                                          [[1, 2], [3, 4], [5, 6], [7, 8]]])
    def test_mapping_EV_bad_shape(self, method_name, ev_input):
        """Test height mapping of E/V with badly shaped input raises Error."""
        apex_method = getattr(self.apex_out, method_name)
        in_args = [60, 15, 100, 500, ev_input]
        with pytest.raises(ValueError) as verr:
            apex_method(*in_args)

        assert "must be (3, N) or (3,) ndarray" in str(verr.value)
        return

    def test_mapping_EV_bad_flag(self):
        """Test _map_EV_to_height raises error for bad data type flag."""
        with pytest.raises(ValueError) as verr:
            self.apex_out._map_EV_to_height(60, 15, 100, 500, [1, 2, 3], "P")

        assert "unknown electric field/drift flag" in str(verr.value)
        return

    def _eval_ev_array_mapping(self, method_name, ivec, test_vals):
        """Check E-field or drift height mapping with vectorized input.

        Parameters
        ----------
        method_name : str
            Apex mapping method, 'map_E_to_height' or 'map_V_to_height'
        ivec : int
            Index (0-3) of the location argument to vectorize; any value of
            4 or more vectorizes only the E-field/drift input
        test_vals : list-like
            Expected mapped output (3 values) for a single location

        """
        # The E-field/drift input is always a (3, 2) array of two columns
        ev_input = np.array([[1, 2, 3]] * 2).transpose()
        in_args = [60, 15, 100, 500, ev_input]
        test_mapped = np.full(shape=(2, 3), fill_value=test_vals).transpose()

        # Update inputs for one vectorized value if this is a location input
        if ivec < 4:
            in_args[ivec] = [in_args[ivec], in_args[ivec]]

        # Get the mapped output and test the results
        mapped = getattr(self.apex_out, method_name)(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 500, [1, 2, 3]],
                               [0.71152183, 2.35624876, 0.57260784]),
                              ([60, 15, 100, 500, [2, 3, 4]],
                               [1.56028502, 3.43916636, 0.78235384]),
                              ([60, 15, 100, 1000, [1, 2, 3]],
                               [0.67796492, 2.08982134, 0.55860785]),
                              ([60, 15, 200, 500, [1, 2, 3]],
                               [0.72377397, 2.42737471, 0.59083726]),
                              ([60, 30, 100, 500, [1, 2, 3]],
                               [0.68626344, 2.37530133, 0.60060124]),
                              ([70, 15, 100, 500, [1, 2, 3]],
                               [0.72760378, 2.18082305, 0.29141979])])
    def test_map_E_to_height_scalar_location(self, in_args, test_mapped):
        """Test mapping of E-field to a specified height."""
        mapped = self.apex_out.map_E_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return

    @pytest.mark.parametrize('ivec', range(0, 5))
    def test_map_E_to_height_array_location(self, ivec):
        """Test mapping of E-field to a specified height with array input."""
        self._eval_ev_array_mapping("map_E_to_height", ivec,
                                    [0.71152183, 2.35624876, 0.57260784])
        return

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 500, [1, 2, 3]],
                               [0.81971957, 2.84512495, 0.69545001]),
                              ([60, 15, 100, 500, [2, 3, 4]],
                               [1.83027746, 4.14346436, 0.94764179]),
                              ([60, 15, 100, 1000, [1, 2, 3]],
                               [0.92457698, 3.14997661, 0.85135187]),
                              ([60, 15, 200, 500, [1, 2, 3]],
                               [0.80388262, 2.79321504, 0.68285158]),
                              ([60, 30, 100, 500, [1, 2, 3]],
                               [0.76141245, 2.87884673, 0.73655941]),
                              ([70, 15, 100, 500, [1, 2, 3]],
                               [0.84681866, 2.5925821, 0.34792655])])
    def test_map_V_to_height_scalar_location(self, in_args, test_mapped):
        """Test mapping of velocity to a specified height."""
        mapped = self.apex_out.map_V_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return

    @pytest.mark.parametrize('ivec', range(0, 5))
    def test_map_V_to_height_array_location(self, ivec):
        """Test mapping of velocity to a specified height with array input."""
        self._eval_ev_array_mapping("map_V_to_height", ivec,
                                    [0.81971957, 2.84512495, 0.69545001])
        return
946
947
948
class TestApexBasevectorMethods():
    """Test the Apex height base vector methods."""

    def setup_method(self):
        """Initialize all tests."""
        # pytest-native per-test hook; the nose-style `setup` name is
        # deprecated in pytest 7.2 and no longer called from pytest 8.0
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.lat = 60
        self.lon = 15
        self.height = 100
        self.test_basevec = None

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.test_basevec, self.lat, self.lon, self.height

    def get_comparison_results(self, bv_coord, coords, precision):
        """Get the base vector results using the hidden function for comparison.

        Parameters
        ----------
        bv_coord : str
            Basevector coordinate scheme, expects one of 'apex', 'qd',
            or 'bvectors_apex'
        coords : str
            Expects one of 'geo', 'apex', or 'qd'
        precision : float
            Float specifying precision

        Notes
        -----
        The result is stored in `self.test_basevec`.  For 'bvectors_apex'
        the values are hard-coded and must be updated with IGRF.

        """
        if coords == "geo":
            glat = self.lat
            glon = self.lon
        else:
            apex_method = getattr(self.apex_out, "{:s}2geo".format(coords))
            glat, glon, _ = apex_method(self.lat, self.lon, self.height,
                                        precision=precision)

        if bv_coord == 'qd':
            self.test_basevec = self.apex_out._basevec(glat, glon, self.height)
        elif bv_coord == 'apex':
            # Use the test height rather than a duplicated literal so the
            # comparison stays in sync with the tested call
            (_, _, _, _, f1, f2, _, d1, d2, d3, _, e1, e2,
             e3) = self.apex_out._geo2apexall(glat, glon, self.height)
            self.test_basevec = (f1, f2, d1, d2, d3, e1, e2, e3)
        else:
            # These are set results that need to be updated with IGRF
            if coords == "geo":
                self.test_basevec = (
                    np.array([4.42368795e-05, 4.42368795e-05]),
                    np.array([[0.01047826, 0.01047826],
                              [0.33089194, 0.33089194],
                              [-1.04941, -1.04941]]),
                    np.array([5.3564698e-05, 5.3564698e-05]),
                    np.array([[0.00865356, 0.00865356],
                              [0.27327004, 0.27327004],
                              [-0.8666646, -0.8666646]]))
            elif coords == "apex":
                self.test_basevec = (
                    np.array([4.48672735e-05, 4.48672735e-05]),
                    np.array([[-0.12510721, -0.12510721],
                              [0.28945938, 0.28945938],
                              [-1.1505738, -1.1505738]]),
                    np.array([6.38577444e-05, 6.38577444e-05]),
                    np.array([[-0.08790194, -0.08790194],
                              [0.2033779, 0.2033779],
                              [-0.808408, -0.808408]]))
            else:
                self.test_basevec = (
                    np.array([4.46348578e-05, 4.46348578e-05]),
                    np.array([[-0.12642345, -0.12642345],
                              [0.29695055, 0.29695055],
                              [-1.1517885, -1.1517885]]),
                    np.array([6.38626285e-05, 6.38626285e-05]),
                    np.array([[-0.08835986, -0.08835986],
                              [0.20754464, 0.20754464],
                              [-0.8050078, -0.8050078]]))

        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    @pytest.mark.parametrize("coords,precision",
                             [("geo", 1e-10), ("apex", 1.0e-2), ("qd", 1.0e-2)])
    def test_basevectors_scalar(self, bv_coord, coords, precision):
        """Test the base vector calculations with scalars."""
        # Get the base vectors
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(self.lat, self.lon, self.height, coords=coords,
                              precision=precision)
        self.get_comparison_results(bv_coord, coords, precision)
        if bv_coord == "apex":
            # Indices 2-5 (f3, g1, g2, g3) have no counterpart in the
            # comparison output, so drop them
            basevec = list(basevec)
            del basevec[2:6]

        # Test the results
        for i, vec in enumerate(basevec):
            np.testing.assert_allclose(vec, self.test_basevec[i])
        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    def test_basevectors_scalar_shape(self, bv_coord):
        """Test the shape of the scalar output."""
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(self.lat, self.lon, self.height)

        for i, vec in enumerate(basevec):
            if i < 2:
                assert vec.shape == (2,)
            else:
                assert vec.shape == (3,)
        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    @pytest.mark.parametrize("ivec", range(3))
    def test_basevectors_array(self, bv_coord, ivec):
        """Test the output shape for array inputs."""
        # Define the input arguments, vectorizing one of them to length 4
        in_args = [self.lat, self.lon, self.height]
        in_args[ivec] = [in_args[ivec]] * 4

        # Get the basevectors
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(*in_args, coords='geo', precision=1e-10)
        self.get_comparison_results(bv_coord, "geo", 1e-10)
        if bv_coord == "apex":
            # Indices 2-5 (f3, g1, g2, g3) have no counterpart in the
            # comparison output, so drop them
            basevec = list(basevec)
            del basevec[2:6]

        # Evaluate the shape and the values
        for i, vec in enumerate(basevec):
            idim = 2 if i < 2 else 3
            assert vec.shape == (idim, 4)
            assert np.all(self.test_basevec[i][0] == vec[0])
            assert np.all(self.test_basevec[i][1] == vec[1])
        return

    @pytest.mark.parametrize("coords", ["geo", "apex", "qd"])
    def test_bvectors_apex(self, coords):
        """Test the bvectors_apex method."""
        in_args = [[self.lat, self.lat], [self.lon, self.lon],
                   [self.height, self.height]]
        self.get_comparison_results("bvectors_apex", coords, 1e-10)

        basevec = self.apex_out.bvectors_apex(*in_args, coords=coords,
                                              precision=1e-10)
        for i, vec in enumerate(basevec):
            np.testing.assert_array_almost_equal(vec, self.test_basevec[i],
                                                 decimal=5)
        return

    def test_basevectors_apex_extra_values(self):
        """Test specific values in the apex base vector output."""
        # Set the testing arrays
        self.test_basevec = [np.array([0.092637, -0.245951, 0.938848]),
                             np.array([0.939012, 0.073416, -0.07342]),
                             np.array([0.055389, 1.004155, 0.257594]),
                             np.array([0, 0, 1.065135])]

        # Get the desired output
        basevec = self.apex_out.basevectors_apex(0, 15, 100, coords='geo')

        # Test the values not covered by `test_basevectors_scalar`
        for itest, ibase in enumerate(range(2, 6)):
            np.testing.assert_allclose(basevec[ibase],
                                       self.test_basevec[itest], rtol=1e-4)
        return

    @pytest.mark.parametrize("lat", range(0, 90, 10))
    @pytest.mark.parametrize("lon", range(0, 360, 15))
    def test_basevectors_apex_delta(self, lat, lon):
        """Test that vectors are calculated correctly."""
        # Get the apex base vectors and sort them for easy testing
        (f1, f2, f3, g1, g2, g3, d1, d2, d3, e1, e2,
         e3) = self.apex_out.basevectors_apex(lat, lon, 500)
        fvec = [np.append(f1, 0), np.append(f2, 0), f3]
        gvec = [g1, g2, g3]
        dvec = [d1, d2, d3]
        evec = [e1, e2, e3]

        # f_i . g_j and d_i . e_j must equal the Kronecker delta
        for idelta, jdelta in [(i, j) for i in range(3) for j in range(3)]:
            delta = 1 if idelta == jdelta else 0
            np.testing.assert_allclose(np.sum(fvec[idelta] * gvec[jdelta]),
                                       delta, rtol=0, atol=1e-5)
            np.testing.assert_allclose(np.sum(dvec[idelta] * evec[jdelta]),
                                       delta, rtol=0, atol=1e-5)
        return

    def test_basevectors_apex_invalid_scalar(self):
        """Test warning and fill values for base vectors with bad inputs."""
        self.apex_out = apexpy.Apex(date=2000, refh=10000)
        invalid = np.full(shape=(3,), fill_value=np.nan)

        # Get the output and the warnings; 'always' makes sure the warning is
        # recorded even if it was already issued from the same location
        with warnings.catch_warnings(record=True) as warn_rec:
            warnings.simplefilter("always")
            basevec = self.apex_out.basevectors_apex(0, 0, 0)

        for i, bvec in enumerate(basevec):
            if i < 2:
                assert not np.allclose(bvec, invalid[:2])
            else:
                np.testing.assert_allclose(bvec, invalid)

        # Look for the NaN-fill warning explicitly instead of assuming it is
        # the last record (and failing with IndexError if none were raised)
        nan_warns = [wrec for wrec in warn_rec
                     if 'set to NaN where' in str(wrec.message)]
        assert len(nan_warns) > 0, "expected NaN-fill warning not raised"
        assert issubclass(nan_warns[-1].category, UserWarning)
        return
        return
1157
1158
1159
class TestApexGetMethods():
1160
    """Test the Apex `get` methods."""
1161
    def setup_method(self):
        """Initialize all tests."""
        # pytest-native per-test hook; the nose-style `setup` name is
        # deprecated in pytest 7.2 and no longer called from pytest 8.0
        self.apex_out = apexpy.Apex(date=2000, refh=300)
1164
1165
    def teardown_method(self):
        """Clean up after each test."""
        # pytest-native counterpart of the deprecated nose-style `teardown`
        del self.apex_out
1168
1169
    @pytest.mark.parametrize("alat, aheight", [(10, 507.409702543805),
                                               (60, 20313.026999999987)])
    def test_get_apex(self, alat, aheight):
        """Test that get_apex returns the expected apex height."""
        np.testing.assert_allclose(self.apex_out.get_apex(alat), aheight)
        return
1176
1177
    @pytest.mark.parametrize("glat,glon,height,test_bmag",
                             [([80], [100], [300], 5.100682377815247e-05),
                              (range(50, 90, 8), range(0, 360, 80), [300] * 5,
                               np.array([4.18657154e-05, 5.11118114e-05,
                                         4.91969854e-05, 5.10519207e-05,
                                         4.90054816e-05])),
                              (90.0, 0, 1000, 3.7834718823432923e-05)])
    def test_get_babs(self, glat, glon, height, test_bmag):
        """Test the field magnitude for list, range, and scalar locations."""
        # NOTE(review): atol=1e-5 is comparable to the expected magnitudes
        # (~4e-5 to 5e-5), so this check is loose; consider an rtol instead
        np.testing.assert_allclose(
            self.apex_out.get_babs(glat, glon, height), test_bmag,
            rtol=0, atol=1e-5)
        return
1189
1190
    @pytest.mark.parametrize("bad_lat", [(91), (-91)])
    def test_get_apex_with_invalid_lat(self, bad_lat):
        """Test get_apex raises ValueError for invalid latitudes."""
        # Test names must be unique within the class: a repeated name
        # silently replaces the earlier definition and it is never collected
        with pytest.raises(ValueError) as verr:
            self.apex_out.get_apex(bad_lat)

        # Substring membership also passes if the message starts with the text
        assert "must be in [-90, 90]" in str(verr.value)
        return
1199
1200
    @pytest.mark.parametrize("bad_lat", [(91), (-91)])
    def test_get_babs_with_invalid_lat(self, bad_lat):
        """Test get_babs raises ValueError for invalid latitudes."""
        # Test names must be unique within the class: a repeated name
        # silently replaces the earlier definition and it is never collected
        with pytest.raises(ValueError) as verr:
            self.apex_out.get_babs(bad_lat, 15, 100)

        # Substring membership also passes if the message starts with the text
        assert "must be in [-90, 90]" in str(verr.value)
        return
1209
1210
    @pytest.mark.parametrize("bound_lat", [(90), (-90)])
    def test_get_at_lat_boundary(self, bound_lat):
        """Test get_apex tolerates latitudes slightly past +/-90.

        A latitude 1e-5 beyond the boundary must give the same apex height
        as the boundary latitude itself.
        """
        bound_out = self.apex_out.get_apex(bound_lat)

        # Step just past the boundary while keeping the hemisphere's sign
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        np.testing.assert_allclose(self.apex_out.get_apex(excess_lat),
                                   bound_out, rtol=0, atol=1e-8)
        return
1223