Completed
Push — develop ( 00c12c...7b256e )
by Angeline
18s queued 12s
created

test_Apex.TestApexInit.test_eq_missing_attr()   A

Complexity

Conditions 1

Size

Total Lines 10
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 8
dl 0
loc 10
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
# -*- coding: utf-8 -*-
2
"""Test the apexpy.Apex class
3
4
Notes
5
-----
6
Whenever function outputs are tested against hard-coded numbers, the test
7
results (numbers) were obtained by running the code that is tested.  Therefore,
8
these tests below only check that nothing changes when refactoring, etc., and
9
not if the results are actually correct.
10
11
These results are expected to change when IGRF is updated.
12
13
"""
14
15
import datetime as dt
16
import numpy as np
17
import os
18
import pytest
19
import warnings
20
21
import apexpy
22
23
24
@pytest.fixture()
def igrf_file(tmp_path):
    """Temporarily hide the IGRF coefficient file shipped with apexpy.

    Parameters
    ----------
    tmp_path : pathlib.Path
        Built-in pytest fixture providing a per-test temporary directory,
        used to hold the coefficient file while the test runs.

    Yields
    ------
    original_file : str
        Full path of the coefficient file, which does not exist while the
        requesting test runs.

    """
    import shutil

    # Ensure the coefficient file exists
    original_file = os.path.join(os.path.dirname(apexpy.helpers.__file__),
                                 'igrf13coeffs.txt')
    assert os.path.isfile(original_file)

    # Move the coefficient file out of the package.  shutil.move (unlike
    # os.rename) also works when the temporary directory lives on a
    # different filesystem than the installed package.
    tmp_file = str(tmp_path / "temp_coeff.txt")
    shutil.move(original_file, tmp_file)
    try:
        yield original_file
    finally:
        # Always move the coefficient file back so later tests can use it
        shutil.move(tmp_file, original_file)
40
41
42
def test_set_epoch_file_error(igrf_file):
    """Check that Apex init raises OSError when the IGRF file is absent."""
    expected_start = "File {:} does not exist".format(igrf_file)

    # The fixture has hidden the coefficient file, so initialization fails
    with pytest.raises(OSError) as oerr:
        apexpy.Apex()

    message = str(oerr.value)
    assert message.startswith(expected_start)
    return
50
51
52
class TestApexInit():
    """Test the initialization, comparison, and printing of apexpy.Apex."""

    def setup_method(self):
        """Set the default test attributes before each test."""
        self.apex_out = None
        # Naive UTC time, equivalent to the deprecated ``utcnow`` result
        self.test_date = dt.datetime.now(dt.timezone.utc).replace(tzinfo=None)
        self.test_refh = 0
        self.bad_file = 'foo/path/to/datafile.blah'

    def teardown_method(self):
        """Remove the test attributes after each test."""
        del self.apex_out, self.test_date, self.test_refh, self.bad_file

    def eval_date(self):
        """Evaluate the times in self.test_date and self.apex_out.

        Date and datetime values in ``self.test_date`` are first converted
        to decimal years with ``apexpy.helpers.toYearFraction``.
        """
        # dt.datetime is a subclass of dt.date, so one check covers both
        if isinstance(self.test_date, dt.date):
            self.test_date = apexpy.helpers.toYearFraction(self.test_date)

        # Assert the times are the same on the order of tens of seconds.
        # Necessary to evaluate the current UTC
        np.testing.assert_almost_equal(self.test_date, self.apex_out.year, 6)
        return

    def eval_refh(self):
        """Evaluate the reference heights in self.test_refh and self.apex_out."""
        eval_str = "".join(["expected reference height [",
                            "{:}] not equal to Apex ".format(self.test_refh),
                            "reference height ",
                            "[{:}]".format(self.apex_out.refh)])
        assert self.test_refh == self.apex_out.refh, eval_str
        return

    def test_init_defaults(self):
        """Test Apex class default initialization."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("in_date",
                             [2015, 2015.5, dt.date(2015, 1, 1),
                              dt.datetime(2015, 6, 1, 18, 23, 45)])
    def test_init_date(self, in_date):
        """Test Apex class with date initialization."""
        self.test_date = in_date
        self.apex_out = apexpy.Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("new_date", [2015, 2015.5])
    def test_set_epoch(self, new_date):
        """Test successful setting of Apex epoch after initialization."""
        # Evaluate the default initialization
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Update the epoch; the repr round-trip gives an unchanged reference
        ref_apex = eval(repr(self.apex_out))
        self.apex_out.set_epoch(new_date)
        assert ref_apex != self.apex_out
        self.test_date = new_date
        self.eval_date()
        return

    @pytest.mark.parametrize("in_refh", [0.0, 300.0, 30000.0, -1.0])
    def test_init_refh(self, in_refh):
        """Test Apex class with reference height initialization."""
        self.test_refh = in_refh
        self.apex_out = apexpy.Apex(refh=self.test_refh)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("new_refh", [0.0, 300.0, 30000.0, -1.0])
    def test_set_refh(self, new_refh):
        """Test the method used to set the reference height after the init."""
        # Verify the defaults are set
        self.apex_out = apexpy.Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()

        # Update to a new reference height and test
        ref_apex = eval(repr(self.apex_out))
        self.apex_out.set_refh(new_refh)

        if self.test_refh == new_refh:
            assert ref_apex == self.apex_out
        else:
            assert ref_apex != self.apex_out
            self.test_refh = new_refh
        self.eval_refh()
        return

    def test_init_with_bad_datafile(self):
        """Test raises IOError with non-existent datafile input."""
        with pytest.raises(IOError) as oerr:
            apexpy.Apex(datafile=self.bad_file)
        assert str(oerr.value).startswith('Data file does not exist')
        return

    def test_init_with_bad_fortranlib(self):
        """Test raises IOError with non-existent fortranlib input."""
        with pytest.raises(IOError) as oerr:
            apexpy.Apex(fortranlib=self.bad_file)
        assert str(oerr.value).startswith('Fortran library does not exist')
        return

    def test_repr_eval(self):
        """Test the Apex.__repr__ results."""
        # Initialize the apex object
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Get and test the repr string
        out_str = repr(self.apex_out)
        assert out_str.startswith("apexpy.Apex(")

        # Test the ability to re-create the apex object from the repr string
        new_apex = eval(out_str)
        assert new_apex == self.apex_out
        return

    def test_ne_other_class(self):
        """Test Apex class inequality to a different class."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        assert self.apex_out != self.test_date
        return

    def test_ne_missing_attr(self):
        """Test Apex class inequality when attributes are missing from one."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        ref_apex = eval(repr(self.apex_out))
        del ref_apex.RE

        # Inequality must hold regardless of the comparison order
        assert ref_apex != self.apex_out
        assert self.apex_out != ref_apex
        return

    def test_eq_missing_attr(self):
        """Test Apex class equality when attributes are missing from both."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        ref_apex = eval(repr(self.apex_out))
        del ref_apex.RE, self.apex_out.RE

        assert ref_apex == self.apex_out
        return

    def test_str_eval(self):
        """Test the Apex.__str__ results."""
        # Initialize the apex object
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Get and test the printed string
        out_str = str(self.apex_out)
        assert out_str.find("Decimal year") > 0
        return
218
219
220
class TestApexMethod():
    """Test the Apex methods."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.in_lat = 60
        self.in_lon = 15
        self.in_alt = 100

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.in_lat, self.in_lon, self.in_alt

    def get_input_args(self, method_name, precision=0.0):
        """Set the input arguments for the different Apex methods.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method
        precision : float
            Value for the precision (default=0.0)

        Returns
        -------
        in_args : list
            List of the appropriate input arguments

        """
        in_args = [self.in_lat, self.in_lon, self.in_alt]

        # Add precision, if needed
        if method_name in ["_qd2geo", "apxq2g", "apex2geo", "qd2geo",
                           "_apex2geo"]:
            in_args.append(precision)

        # Add a reference height, if needed (matches refh in setup_method)
        if method_name in ["apxg2all"]:
            in_args.append(300)

        # Add a vector flag, if needed
        if method_name in ["apxg2all", "apxg2q"]:
            in_args.append(1)

        return in_args

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                              ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                              ("_qd2geo", "apxq2g", slice(None)),
                              ("_basevec", "apxg2q", slice(2, 4, 1))])
    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon", [-179, -90, 0, 90, 180])
    def test_fortran_scalar_input(self, apex_method, fortran_method, fslice,
                                  lat, lon):
        """Tests Apex/fortran interface consistency for scalars."""
        # Set the input coordinates
        self.in_lat = lat
        self.in_lon = lon

        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Get the appropriate input arguments
        apex_args = self.get_input_args(apex_method)
        fortran_args = self.get_input_args(fortran_method)

        # Evaluate the equivalent function calls
        np.testing.assert_allclose(apex_func(*apex_args),
                                   fortran_func(*fortran_args)[fslice])
        return

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                              ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                              ("_qd2geo", "apxq2g", slice(None)),
                              ("_basevec", "apxg2q", slice(2, 4, 1))])
    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon1,lon2", [(180, 180), (-180, -180),
                                           (180, -180), (-180, 180),
                                           (-345, 15), (375, 15)])
    def test_fortran_longitude_rollover(self, apex_method, fortran_method,
                                        fslice, lat, lon1, lon2):
        """Tests Apex/fortran interface consistency for longitude rollover."""
        # Set the fixed input coordinate
        self.in_lat = lat

        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Get the appropriate input arguments
        self.in_lon = lon1
        apex_args = self.get_input_args(apex_method)

        self.in_lon = lon2
        fortran_args = self.get_input_args(fortran_method)

        # Evaluate the equivalent function calls
        np.testing.assert_allclose(apex_func(*apex_args),
                                   fortran_func(*fortran_args)[fslice])
        return

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                              ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                              ("_qd2geo", "apxq2g", slice(None)),
                              ("_basevec", "apxg2q", slice(2, 4, 1))])
    def test_fortran_array_input(self, apex_method, fortran_method, fslice):
        """Tests Apex/fortran interface consistency for array input."""
        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Set up the input arrays
        ref_lat = np.array([0, 30, 60, 90])
        ref_alt = np.array([100, 200, 300, 400])
        self.in_lat = ref_lat.reshape((2, 2))
        self.in_alt = ref_alt.reshape((2, 2))
        apex_args = self.get_input_args(apex_method)

        # Get the Apex class results
        aret = apex_func(*apex_args)

        # Get the fortran function results, one location at a time
        flats = list()
        flons = list()

        for i, lat in enumerate(ref_lat):
            self.in_lat = lat
            self.in_alt = ref_alt[i]
            fortran_args = self.get_input_args(fortran_method)
            fret = fortran_func(*fortran_args)[fslice]
            flats.append(fret[0])
            flons.append(fret[1])

        flats = np.array(flats)
        flons = np.array(flons)

        # Evaluate results
        try:
            # This returned value is array of floats
            np.testing.assert_allclose(aret[0].astype(float),
                                       flats.reshape((2, 2)).astype(float))
            np.testing.assert_allclose(aret[1].astype(float),
                                       flons.reshape((2, 2)).astype(float))
        except ValueError:
            # This returned value is array of arrays
            alats = aret[0].reshape((4,))
            alons = aret[1].reshape((4,))
            for i, flat in enumerate(flats):
                np.testing.assert_array_almost_equal(alats[i], flat, 2)
                np.testing.assert_array_almost_equal(alons[i], flons[i], 2)

        return

    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon", [-179, -90, 0, 90, 180])
    def test_geo2apexall_scalar(self, lat, lon):
        """Test Apex/fortran geo2apexall interface consistency for scalars."""
        # Get the Apex and Fortran results
        aret = self.apex_out._geo2apexall(lat, lon, self.in_alt)
        fret = apexpy.fortranapex.apxg2all(lat, lon, self.in_alt, 300, 1)

        # Evaluate each element in the results
        for aval, fval in zip(aret, fret):
            np.testing.assert_allclose(aval, fval)
        return

    def test_geo2apexall_array(self):
        """Test Apex/fortran geo2apexall interface consistency for arrays."""
        # Set the input
        self.in_lat = np.array([0, 30, 60, 90])
        self.in_alt = np.array([100, 200, 300, 400])

        # Get the Apex class results
        aret = self.apex_out._geo2apexall(self.in_lat.reshape((2, 2)),
                                          self.in_lon,
                                          self.in_alt.reshape((2, 2)))

        # For each lat/alt pair, get the Fortran results
        fret = list()
        for i, lat in enumerate(self.in_lat):
            fret.append(apexpy.fortranapex.apxg2all(lat, self.in_lon,
                                                    self.in_alt[i], 300, 1))

        # Cycle through all returned values
        for i, ret in enumerate(aret):
            try:
                # This returned value is array of floats
                np.testing.assert_allclose(ret.astype(float),
                                           np.array([[fret[0][i], fret[1][i]],
                                                     [fret[2][i], fret[3][i]]],
                                                    dtype=float))
            except ValueError:
                # This returned value is array of arrays
                ret = ret.reshape((4,))
                for j, single_fret in enumerate(fret):
                    np.testing.assert_allclose(ret[j], single_fret[i])
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_consistency(self, in_coord, out_coord):
        """Test the self-consistency of the Apex convert method."""
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Define the method name
        method_name = "2".join([in_coord, out_coord])

        # Get the method and method inputs
        convert_kwargs = {'height': self.in_alt, 'precision': 0.0}
        apex_args = self.get_input_args(method_name)
        apex_method = getattr(self.apex_out, method_name)

        # Define the slice needed to get equivalent output from the named method
        mslice = slice(0, -1, 1) if out_coord == "geo" else slice(None)

        # Get output using convert and named method
        convert_out = self.apex_out.convert(self.in_lat, self.in_lon, in_coord,
                                            out_coord, **convert_kwargs)
        method_out = apex_method(*apex_args)[mslice]

        # Compare both outputs, should be identical
        np.testing.assert_allclose(convert_out, method_out)
        return

    @pytest.mark.parametrize("bound_lat", [90, -90])
    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_at_lat_boundary(self, bound_lat, in_coord, out_coord):
        """Test the conversion at the latitude boundary, with allowed excess."""
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = self.apex_out.convert(bound_lat, 0, in_coord, out_coord)
        excess_out = self.apex_out.convert(excess_lat, 0, in_coord, out_coord)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    def test_convert_qd2apex_at_equator(self):
        """Test the quasi-dipole to apex conversion at the magnetic equator."""
        eq_out = self.apex_out.convert(lat=0.0, lon=0, source='qd', dest='apex',
                                       height=320.0)
        close_out = self.apex_out.convert(lat=0.001, lon=0, source='qd',
                                          dest='apex', height=320.0)
        np.testing.assert_allclose(eq_out, close_out, atol=1e-4)
        return

    @pytest.mark.parametrize("src", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("dest", ["geo", "apex", "qd"])
    def test_convert_withnan(self, src, dest):
        """Test Apex.convert success with NaN input."""
        if src == dest:
            pytest.skip("Test not needed for same src and dest coordinates")

        num_nans = 5
        in_loc = np.arange(0, 10, dtype=float)
        in_loc[:num_nans] = np.nan

        out_loc = self.apex_out.convert(in_loc, in_loc, src, dest, height=320)

        for out in out_loc:
            assert np.all(np.isnan(out[:num_nans])), "NaN output expected"
            assert np.all(np.isfinite(out[num_nans:])), "Finite output expected"

        return

    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_convert_invalid_lat(self, bad_lat):
        """Test convert raises ValueError for invalid latitudes."""

        with pytest.raises(ValueError) as verr:
            self.apex_out.convert(bad_lat, 0, 'geo', 'geo')

        assert str(verr.value).find("must be in [-90, 90]") > 0
        return

    @pytest.mark.parametrize("coords", [("foobar", "geo"), ("geo", "foobar"),
                                        ("geo", "mlt")])
    def test_convert_invalid_transformation(self, coords):
        """Test raises ValueError for bad coordinates."""
        if "mlt" in coords:
            estr = "datetime must be given for MLT calculations"
        else:
            estr = "Unknown coordinate transformation"

        with pytest.raises(ValueError) as verr:
            self.apex_out.convert(0, 0, *coords)

        # Check the exception message itself, not the ExceptionInfo repr
        assert estr in str(verr.value)
        return

    @pytest.mark.parametrize("method_name, out_comp",
                             [("geo2apex",
                               (55.94841766357422, 94.10684204101562)),
                              ("apex2geo",
                               (51.476322174072266, -66.22817993164062,
                                5.727287771151168e-06)),
                              ("geo2qd",
                               (56.531288146972656, 94.10684204101562)),
                              ("apex2qd", (60.498401178276744, 15.0)),
                              ("qd2apex", (59.49138097045895, 15.0))])
    def test_method_scalar_input(self, method_name, out_comp):
        """Test the user method against set values with scalars."""
        # Get the desired methods
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(self.in_lat, self.in_lon, self.in_alt)

        # Evaluate the user output
        np.testing.assert_allclose(user_out, out_comp)

        for out_val in user_out:
            assert np.asarray(out_val).shape == (), "output is not a scalar"
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("method_args, out_shape",
                             [([[60, 60], 15, 100], (2,)),
                              ([60, [15, 15], 100], (2,)),
                              ([60, 15, [100, 100]], (2,)),
                              ([[50, 60], [15, 16], [100, 200]], (2,))])
    def test_method_broadcast_input(self, in_coord, out_coord, method_args,
                                    out_shape):
        """Test the user method with inputs that require some broadcasting."""
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(*method_args)

        # Evaluate the user output
        for out_val in user_out:
            assert hasattr(out_val, 'shape'), "output coordinate isn't np.array"
            assert out_val.shape == out_shape
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_method_invalid_lat(self, in_coord, out_coord, bad_lat):
        """Test convert raises ValueError for invalid latitudes."""
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        with pytest.raises(ValueError) as verr:
            user_method(bad_lat, 15, 100)

        assert str(verr.value).find("must be in [-90, 90]") > 0
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("bound_lat", [90, -90])
    def test_method_at_lat_boundary(self, in_coord, out_coord, bound_lat):
        """Test user methods at the latitude boundary, with allowed excess."""
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        # Get a latitude just beyond the limit
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = user_method(bound_lat, 0, 100)
        excess_out = user_method(excess_lat, 0, 100)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    def test_geo2apex_undefined_warning(self):
        """Test geo2apex warning and fill values for an undefined location."""

        # Update the apex object
        self.apex_out = apexpy.Apex(date=2000, refh=10000)

        # Get the output and the warnings
        with warnings.catch_warnings(record=True) as warn_rec:
            user_lat, user_lon = self.apex_out.geo2apex(0, 0, 0)

        assert np.isnan(user_lat)
        assert np.isfinite(user_lon)
        assert len(warn_rec) == 1
        assert issubclass(warn_rec[-1].category, UserWarning)
        assert 'latitude set to NaN where' in str(warn_rec[-1].message)
        return

    @pytest.mark.parametrize("method_name", ["apex2qd", "qd2apex"])
    @pytest.mark.parametrize("delta_h", [1.0e-6, -1.0e-6])
    def test_quasidipole_apexheight_close(self, method_name, delta_h):
        """Test quasi-dipole success with a height close to the reference."""
        qd_method = getattr(self.apex_out, method_name)
        in_args = [0, 15, self.apex_out.refh + delta_h]
        out_coords = qd_method(*in_args)

        for i, out_val in enumerate(out_coords):
            np.testing.assert_almost_equal(out_val, in_args[i], decimal=3)
        return

    @pytest.mark.parametrize("method_name, hinc, msg",
                             [("apex2qd", 1.0, "is > apex height"),
                              ("qd2apex", -1.0, "is < reference height")])
    def test_quasidipole_raises_apexheight(self, method_name, hinc, msg):
        """Quasi-dipole raises ApexHeightError for heights out of range."""
        qd_method = getattr(self.apex_out, method_name)

        with pytest.raises(apexpy.ApexHeightError) as aerr:
            qd_method(0, 15, self.apex_out.refh + hinc)

        # Check the exception message itself, not the ExceptionInfo repr
        assert msg in str(aerr.value)
        return
649
650
651
class TestApexMLTMethods():
    """Test the Apex Magnetic Local Time (MLT) methods."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.in_time = dt.datetime(2000, 2, 3, 4, 5, 6)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.in_time

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    def test_convert_to_mlt(self, in_coord):
        """Test the conversions to MLT using Apex convert."""

        # Get the magnetic longitude; apex and qd inputs already supply it,
        # so only geographic input needs converting first
        if in_coord == "geo":
            mlon = self.apex_out.geo2apex(60, 15, 100)[1]
        else:
            mlon = 15

        # Get the output MLT values
        convert_mlt = self.apex_out.convert(60, 15, in_coord, 'mlt',
                                            height=100, ssheight=2e5,
                                            datetime=self.in_time)[1]
        method_mlt = self.apex_out.mlon2mlt(mlon, self.in_time, ssheight=2e5)

        # Test the outputs
        np.testing.assert_allclose(convert_mlt, method_mlt)
        return

    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_mlt_to_lon(self, out_coord):
        """Test the conversions from MLT using Apex convert."""
        # Get the output longitudes
        convert_out = self.apex_out.convert(60, 15, 'mlt', out_coord,
                                            height=100, ssheight=2e5,
                                            datetime=self.in_time,
                                            precision=1e-2)
        mlon = self.apex_out.mlt2mlon(15, self.in_time, ssheight=2e5)

        if out_coord == "geo":
            method_out = self.apex_out.apex2geo(60, mlon, 100,
                                                precision=1e-2)[:-1]
        elif out_coord == "qd":
            method_out = self.apex_out.apex2qd(60, mlon, 100)
        else:
            method_out = (60, mlon)

        # Evaluate the outputs
        np.testing.assert_allclose(convert_out, method_out)
        return

    def test_convert_geo2mlt_nodate(self):
        """Test convert from geo to MLT raises ValueError with no datetime."""
        with pytest.raises(ValueError):
            self.apex_out.convert(60, 15, 'geo', 'mlt')
        return

    @pytest.mark.parametrize("mlon_kwargs,test_mlt",
                             [({}, 23.019629923502603),
                              ({"ssheight": 100000}, 23.026712036132814)])
    def test_mlon2mlt_scalar_inputs(self, mlon_kwargs, test_mlt):
        """Test mlon2mlt with scalar inputs."""
        mlt = self.apex_out.mlon2mlt(0, self.in_time, **mlon_kwargs)

        np.testing.assert_allclose(mlt, test_mlt)
        assert np.asarray(mlt).shape == ()
        return

    @pytest.mark.parametrize("mlt_kwargs,test_mlon",
                             [({}, 14.705535888671875),
                              ({"ssheight": 100000}, 14.599319458007812)])
    def test_mlt2mlon_scalar_inputs(self, mlt_kwargs, test_mlon):
        """Test mlt2mlon with scalar inputs."""
        mlon = self.apex_out.mlt2mlon(0, self.in_time, **mlt_kwargs)

        np.testing.assert_allclose(mlon, test_mlon)
        assert np.asarray(mlon).shape == ()
        return

    @pytest.mark.parametrize("mlon,test_mlt",
                             [([0, 180], [23.019261, 11.019261]),
                              (np.array([0, 180]), [23.019261, 11.019261]),
                              ([[0, 180], [0, 180]], [[23.019261, 11.019261],
                                                      [23.019261, 11.019261]]),
                              (range(0, 361, 30),
                               [23.01963, 1.01963, 3.01963, 5.01963, 7.01963,
                                9.01963, 11.01963, 13.01963, 15.01963, 17.01963,
                                19.01963, 21.01963, 23.01963])])
    def test_mlon2mlt_array(self, mlon, test_mlt):
        """Test mlon2mlt with array inputs."""
        mlt = self.apex_out.mlon2mlt(mlon, self.in_time)

        assert mlt.shape == np.asarray(test_mlt).shape
        np.testing.assert_allclose(mlt, test_mlt, rtol=1e-4)
        return

    @pytest.mark.parametrize("mlt,test_mlon",
                             [([0, 12], [14.705551, 194.705551]),
                              (np.array([0, 12]), [14.705551, 194.705551]),
                              ([[0, 12], [0, 12]], [[14.705551, 194.705551],
                                                    [14.705551, 194.705551]]),
                              (range(0, 25, 2),
                               [14.705551, 44.705551, 74.705551, 104.705551,
                                134.705551, 164.705551, 194.705551, 224.705551,
                                254.705551, 284.705551, 314.705551, 344.705551,
                                14.705551])])
    def test_mlt2mlon_array(self, mlt, test_mlon):
        """Test mlt2mlon with array inputs."""
        mlon = self.apex_out.mlt2mlon(mlt, self.in_time)

        assert mlon.shape == np.asarray(test_mlon).shape
        np.testing.assert_allclose(mlon, test_mlon, rtol=1e-4)
        return

    @pytest.mark.parametrize("method_name", ["mlon2mlt", "mlt2mlon"])
    def test_mlon2mlt_diffdates(self, method_name):
        """Test that the MLT/longitude conversions vary with universal time."""
        apex_method = getattr(self.apex_out, method_name)
        out1 = apex_method(0, self.in_time)
        out2 = apex_method(0, self.in_time + dt.timedelta(hours=1))

        assert out1 != out2
        return

    @pytest.mark.parametrize("mlt_offset", [1.0, 10.0])
    def test_mlon2mlt_offset(self, mlt_offset):
        """Test the time wrapping logic for the MLT."""
        mlt1 = self.apex_out.mlon2mlt(0.0, self.in_time)
        mlt2 = self.apex_out.mlon2mlt(-15.0 * mlt_offset,
                                      self.in_time) + mlt_offset

        np.testing.assert_allclose(mlt1, mlt2)
        return

    @pytest.mark.parametrize("mlon_offset", [15.0, 150.0])
    def test_mlt2mlon_offset(self, mlon_offset):
        """Test the time wrapping logic for the magnetic longitude."""
        mlon1 = self.apex_out.mlt2mlon(0, self.in_time)
        mlon2 = self.apex_out.mlt2mlon(mlon_offset / 15.0,
                                       self.in_time) - mlon_offset

        np.testing.assert_allclose(mlon1, mlon2)
        return

    @pytest.mark.parametrize("order", [["mlt", "mlon"], ["mlon", "mlt"]])
    @pytest.mark.parametrize("start_val", [0, 6, 12, 18, 22])
    def test_convert_and_return(self, order, start_val):
        """Test the conversion to magnetic longitude or MLT and back again."""
        first_method = getattr(self.apex_out, "2".join(order))
        second_method = getattr(self.apex_out, "2".join([order[1], order[0]]))

        middle_val = first_method(start_val, self.in_time)
        end_val = second_method(middle_val, self.in_time)

        np.testing.assert_allclose(start_val, end_val)
        return
810
811
812
class TestApexMapMethods():
813
    """Test the Apex height mapping methods."""
814
    def setup(self):
815
        """Initialize all tests."""
816
        self.apex_out = apexpy.Apex(date=2000, refh=300)
817
818
    def teardown(self):
819
        """Clean up after each test."""
820
        del self.apex_out
821
822
    @pytest.mark.parametrize("in_args,test_mapped",
823
                             [([60, 15, 100, 10000],
824
                               [31.841466903686523, 17.916635513305664,
825
                                1.7075473124350538e-6]),
826
                              ([30, 170, 100, 500, False, 1e-2],
827
                               [25.727270126342773, 169.60546875,
828
                                0.00017573432705830783]),
829
                              ([60, 15, 100, 10000, True],
830
                               [-25.424888610839844, 27.310426712036133,
831
                                1.2074182222931995e-6]),
832
                              ([30, 170, 100, 500, True, 1e-2],
833
                               [-13.76642894744873, 164.24259948730469,
834
                                0.00056820799363777041])])
835
    def test_map_to_height(self, in_args, test_mapped):
836
        """Test the map_to_height function."""
837
        mapped = self.apex_out.map_to_height(*in_args)
838
        np.testing.assert_allclose(mapped, test_mapped, atol=1e-6)
839
        return
840
841
    def test_map_to_height_same_height(self):
842
        """Test the map_to_height function when mapping to same height."""
843
        mapped = self.apex_out.map_to_height(60, 15, 100, 100, conjugate=False,
844
                                             precision=1e-10)
845
        np.testing.assert_allclose(mapped, (60.0, 15.000003814697266, 0.0),
846
                                   rtol=1e-5)
847
        return
848
849
    @pytest.mark.parametrize('ivec', range(0, 4))
850
    def test_map_to_height_array_location(self, ivec):
851
        """Test map_to_height with array input."""
852
        # Set the base input and output values
853
        in_args = [60, 15, 100, 100]
854
        test_mapped = np.full(shape=(2, 3),
855
                              fill_value=[60, 15.00000381, 0.0]).transpose()
856
857
        # Update inputs for one vectorized value
858
        in_args[ivec] = [in_args[ivec], in_args[ivec]]
859
860
        # Calculate and test function
861
        mapped = self.apex_out.map_to_height(*in_args)
862
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
863
        return
864
865
    @pytest.mark.parametrize("method_name,in_args",
866
                             [("map_to_height", [0, 15, 100, 10000]),
867
                              ("map_E_to_height",
868
                               [0, 15, 100, 10000, [1, 2, 3]]),
869
                              ("map_V_to_height",
870
                               [0, 15, 100, 10000, [1, 2, 3]])])
871
    def test_mapping_height_raises_ApexHeightError(self, method_name, in_args):
872
        """Test map_to_height raises ApexHeightError."""
873
        apex_method = getattr(self.apex_out, method_name)
874
875
        with pytest.raises(apexpy.ApexHeightError) as aerr:
876
            apex_method(*in_args)
877
878
        assert aerr.match("is > apex height")
879
        return
880
881
    @pytest.mark.parametrize("method_name",
882
                             ["map_E_to_height", "map_V_to_height"])
883
    @pytest.mark.parametrize("ev_input", [([1, 2, 3, 4, 5]),
884
                                          ([[1, 2], [3, 4], [5, 6], [7, 8]])])
885
    def test_mapping_EV_bad_shape(self, method_name, ev_input):
886
        """Test height mapping of E/V with baddly shaped input raises Error."""
887
        apex_method = getattr(self.apex_out, method_name)
888
        in_args = [60, 15, 100, 500, ev_input]
889
        with pytest.raises(ValueError) as verr:
890
            apex_method(*in_args)
891
892
        assert str(verr.value).find("must be (3, N) or (3,) ndarray") >= 0
893
        return
894
895
    def test_mapping_EV_bad_flag(self):
896
        """Test _map_EV_to_height raises error for bad data type flag."""
897
        with pytest.raises(ValueError) as verr:
898
            self.apex_out._map_EV_to_height(60, 15, 100, 500, [1, 2, 3], "P")
899
900
        assert str(verr.value).find("unknown electric field/drift flag") >= 0
901
        return
902
903
    @pytest.mark.parametrize("in_args,test_mapped",
904
                             [([60, 15, 100, 500, [1, 2, 3]],
905
                               [0.71152183, 2.35624876, 0.57260784]),
906
                              ([60, 15, 100, 500, [2, 3, 4]],
907
                               [1.56028502, 3.43916636, 0.78235384]),
908
                              ([60, 15, 100, 1000, [1, 2, 3]],
909
                               [0.67796492, 2.08982134, 0.55860785]),
910
                              ([60, 15, 200, 500, [1, 2, 3]],
911
                               [0.72377397, 2.42737471, 0.59083726]),
912
                              ([60, 30, 100, 500, [1, 2, 3]],
913
                               [0.68626344, 2.37530133, 0.60060124]),
914
                              ([70, 15, 100, 500, [1, 2, 3]],
915
                               [0.72760378, 2.18082305, 0.29141979])])
916
    def test_map_E_to_height_scalar_location(self, in_args, test_mapped):
917
        """Test mapping of E-field to a specified height."""
918
        mapped = self.apex_out.map_E_to_height(*in_args)
919
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
920
        return
921
922 View Code Duplication
    @pytest.mark.parametrize('ivec', range(0, 5))
0 ignored issues
show
Duplication introduced by
This code seems to be duplicated in your project.
Loading history...
923
    def test_map_E_to_height_array_location(self, ivec):
924
        """Test mapping of E-field to a specified height with array input."""
925
        # Set the base input and output values
926
        efield = np.array([[1, 2, 3]] * 2).transpose()
927
        in_args = [60, 15, 100, 500, efield]
928
        test_mapped = np.full(shape=(2, 3),
929
                              fill_value=[0.71152183, 2.35624876,
930
                                          0.57260784]).transpose()
931
932
        # Update inputs for one vectorized value if this is a location input
933
        if ivec < 4:
934
            in_args[ivec] = [in_args[ivec], in_args[ivec]]
935
936
        # Get the mapped output and test the results
937
        mapped = self.apex_out.map_E_to_height(*in_args)
938
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
939
        return
940
941
    @pytest.mark.parametrize("in_args,test_mapped",
942
                             [([60, 15, 100, 500, [1, 2, 3]],
943
                               [0.81971957, 2.84512495, 0.69545001]),
944
                              ([60, 15, 100, 500, [2, 3, 4]],
945
                               [1.83027746, 4.14346436, 0.94764179]),
946
                              ([60, 15, 100, 1000, [1, 2, 3]],
947
                               [0.92457698, 3.14997661, 0.85135187]),
948
                              ([60, 15, 200, 500, [1, 2, 3]],
949
                               [0.80388262, 2.79321504, 0.68285158]),
950
                              ([60, 30, 100, 500, [1, 2, 3]],
951
                               [0.76141245, 2.87884673, 0.73655941]),
952
                              ([70, 15, 100, 500, [1, 2, 3]],
953
                               [0.84681866, 2.5925821,  0.34792655])])
954
    def test_map_V_to_height_scalar_location(self, in_args, test_mapped):
955
        """Test mapping of velocity to a specified height."""
956
        mapped = self.apex_out.map_V_to_height(*in_args)
957
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
958
        return
959
960 View Code Duplication
    @pytest.mark.parametrize('ivec', range(0, 5))
0 ignored issues
show
Duplication introduced by
This code seems to be duplicated in your project.
Loading history...
961
    def test_map_V_to_height_array_location(self, ivec):
962
        """Test mapping of velocity to a specified height with array input."""
963
        # Set the base input and output values
964
        evel = np.array([[1, 2, 3]] * 2).transpose()
965
        in_args = [60, 15, 100, 500, evel]
966
        test_mapped = np.full(shape=(2, 3),
967
                              fill_value=[0.81971957, 2.84512495,
968
                                          0.69545001]).transpose()
969
970
        # Update inputs for one vectorized value if this is a location input
971
        if ivec < 4:
972
            in_args[ivec] = [in_args[ivec], in_args[ivec]]
973
974
        # Get the mapped output and test the results
975
        mapped = self.apex_out.map_V_to_height(*in_args)
976
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
977
        return
978
979
980
class TestApexBasevectorMethods():
981
    """Test the Apex height base vector methods."""
982
    def setup(self):
983
        """Initialize all tests."""
984
        self.apex_out = apexpy.Apex(date=2000, refh=300)
985
        self.lat = 60
986
        self.lon = 15
987
        self.height = 100
988
        self.test_basevec = None
989
990
    def teardown(self):
991
        """Clean up after each test."""
992
        del self.apex_out, self.test_basevec, self.lat, self.lon, self.height
993
994
    def get_comparison_results(self, bv_coord, coords, precision):
995
        """Get the base vector results using the hidden function for comparison.
996
997
        Parameters
998
        ----------
999
        bv_coord : str
1000
            Basevector coordinate scheme, expects on of 'apex', 'qd',
1001
            or 'bvectors_apex'
1002
        coords : str
1003
            Expects one of 'geo', 'apex', or 'qd'
1004
        precision : float
1005
            Float specifiying precision
1006
1007
        """
1008
        if coords == "geo":
1009
            glat = self.lat
1010
            glon = self.lon
1011
        else:
1012
            apex_method = getattr(self.apex_out, "{:s}2geo".format(coords))
1013
            glat, glon, _ = apex_method(self.lat, self.lon, self.height,
1014
                                        precision=precision)
1015
1016
        if bv_coord == 'qd':
1017
            self.test_basevec = self.apex_out._basevec(glat, glon, self.height)
1018
        elif bv_coord == 'apex':
1019
            (_, _, _, _, f1, f2, _, d1, d2, d3, _, e1, e2,
1020
             e3) = self.apex_out._geo2apexall(glat, glon, 100)
1021
            self.test_basevec = (f1, f2, d1, d2, d3, e1, e2, e3)
1022
        else:
1023
            # These are set results that need to be updated with IGRF
1024
            if coords == "geo":
1025
                self.test_basevec = (
1026
                    np.array([4.42368795e-05, 4.42368795e-05]),
1027
                    np.array([[0.01047826, 0.01047826],
1028
                              [0.33089194, 0.33089194],
1029
                              [-1.04941, -1.04941]]),
1030
                    np.array([5.3564698e-05, 5.3564698e-05]),
1031
                    np.array([[0.00865356, 0.00865356],
1032
                              [0.27327004, 0.27327004],
1033
                              [-0.8666646, -0.8666646]]))
1034
            elif coords == "apex":
1035
                self.test_basevec = (
1036
                    np.array([4.48672735e-05, 4.48672735e-05]),
1037
                    np.array([[-0.12510721, -0.12510721],
1038
                              [0.28945938, 0.28945938],
1039
                              [-1.1505738, -1.1505738]]),
1040
                    np.array([6.38577444e-05, 6.38577444e-05]),
1041
                    np.array([[-0.08790194, -0.08790194],
1042
                              [0.2033779, 0.2033779],
1043
                              [-0.808408, -0.808408]]))
1044
            else:
1045
                self.test_basevec = (
1046
                    np.array([4.46348578e-05, 4.46348578e-05]),
1047
                    np.array([[-0.12642345, -0.12642345],
1048
                              [0.29695055, 0.29695055],
1049
                              [-1.1517885, -1.1517885]]),
1050
                    np.array([6.38626285e-05, 6.38626285e-05]),
1051
                    np.array([[-0.08835986, -0.08835986],
1052
                              [0.20754464, 0.20754464],
1053
                              [-0.8050078, -0.8050078]]))
1054
1055
        return
1056
1057
    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
1058
    @pytest.mark.parametrize("coords,precision",
1059
                             [("geo", 1e-10), ("apex", 1.0e-2), ("qd", 1.0e-2)])
1060
    def test_basevectors_scalar(self, bv_coord, coords, precision):
1061
        """Test the base vector calculations with scalars."""
1062
        # Get the base vectors
1063
        base_method = getattr(self.apex_out,
1064
                              "basevectors_{:s}".format(bv_coord))
1065
        basevec = base_method(self.lat, self.lon, self.height, coords=coords,
1066
                              precision=precision)
1067
        self.get_comparison_results(bv_coord, coords, precision)
1068
        if bv_coord == "apex":
1069
            basevec = list(basevec)
1070
            for i in range(4):
1071
                # Not able to compare indices 2, 3, 4, and 5
1072
                basevec.pop(2)
1073
1074
        # Test the results
1075
        for i, vec in enumerate(basevec):
1076
            np.testing.assert_allclose(vec, self.test_basevec[i])
1077
        return
1078
1079
    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
1080
    def test_basevectors_scalar_shape(self, bv_coord):
1081
        """Test the shape of the scalar output."""
1082
        base_method = getattr(self.apex_out,
1083
                              "basevectors_{:s}".format(bv_coord))
1084
        basevec = base_method(self.lat, self.lon, self.height)
1085
1086
        for i, vec in enumerate(basevec):
1087
            if i < 2:
1088
                assert vec.shape == (2,)
1089
            else:
1090
                assert vec.shape == (3,)
1091
        return
1092
1093
    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
1094
    @pytest.mark.parametrize("ivec", range(3))
1095
    def test_basevectors_array(self, bv_coord, ivec):
1096
        """Test the output shape for array inputs."""
1097
        # Define the input arguments
1098
        in_args = [self.lat, self.lon, self.height]
1099
        in_args[ivec] = [in_args[ivec] for i in range(4)]
1100
1101
        # Get the basevectors
1102
        base_method = getattr(self.apex_out,
1103
                              "basevectors_{:s}".format(bv_coord))
1104
        basevec = base_method(*in_args, coords='geo', precision=1e-10)
1105
        self.get_comparison_results(bv_coord, "geo", 1e-10)
1106
        if bv_coord == "apex":
1107
            basevec = list(basevec)
1108
            for i in range(4):
1109
                # Not able to compare indices 2, 3, 4, and 5
1110
                basevec.pop(2)
1111
1112
        # Evaluate the shape and the values
1113
        for i, vec in enumerate(basevec):
1114
            idim = 2 if i < 2 else 3
1115
            assert vec.shape == (idim, 4)
1116
            assert np.all(self.test_basevec[i][0] == vec[0])
1117
            assert np.all(self.test_basevec[i][1] == vec[1])
1118
        return
1119
1120
    @pytest.mark.parametrize("coords", ["geo", "apex", "qd"])
1121
    def test_bvectors_apex(self, coords):
1122
        """Test the bvectors_apex method."""
1123
        in_args = [[self.lat, self.lat], [self.lon, self.lon],
1124
                   [self.height, self.height]]
1125
        self.get_comparison_results("bvectors_apex", coords, 1e-10)
1126
1127
        basevec = self.apex_out.bvectors_apex(*in_args, coords=coords,
1128
                                              precision=1e-10)
1129
        for i, vec in enumerate(basevec):
1130
            np.testing.assert_array_almost_equal(vec, self.test_basevec[i],
1131
                                                 decimal=5)
1132
        return
1133
1134
    def test_basevectors_apex_extra_values(self):
1135
        """Test specific values in the apex base vector output."""
1136
        # Set the testing arrays
1137
        self.test_basevec = [np.array([0.092637, -0.245951, 0.938848]),
1138
                             np.array([0.939012, 0.073416, -0.07342]),
1139
                             np.array([0.055389, 1.004155, 0.257594]),
1140
                             np.array([0, 0, 1.065135])]
1141
1142
        # Get the desired output
1143
        basevec = self.apex_out.basevectors_apex(0, 15, 100, coords='geo')
1144
1145
        # Test the values not covered by `test_basevectors_scalar`
1146
        for itest, ibase in enumerate(np.arange(2, 6, 1)):
1147
            np.testing.assert_allclose(basevec[ibase],
1148
                                       self.test_basevec[itest], rtol=1e-4)
1149
        return
1150
1151
    @pytest.mark.parametrize("lat", range(0, 90, 10))
1152
    @pytest.mark.parametrize("lon", range(0, 360, 15))
1153
    def test_basevectors_apex_delta(self, lat, lon):
1154
        """Test that vectors are calculated correctly."""
1155
        # Get the apex base vectors and sort them for easy testing
1156
        (f1, f2, f3, g1, g2, g3, d1, d2, d3, e1, e2,
1157
         e3) = self.apex_out.basevectors_apex(lat, lon, 500)
1158
        fvec = [np.append(f1, 0), np.append(f2, 0), f3]
1159
        gvec = [g1, g2, g3]
1160
        dvec = [d1, d2, d3]
1161
        evec = [e1, e2, e3]
1162
1163
        for idelta, jdelta in [(i, j) for i in range(3) for j in range(3)]:
1164
            delta = 1 if idelta == jdelta else 0
1165
            np.testing.assert_allclose(np.sum(fvec[idelta] * gvec[jdelta]),
1166
                                       delta, rtol=0, atol=1e-5)
1167
            np.testing.assert_allclose(np.sum(dvec[idelta] * evec[jdelta]),
1168
                                       delta, rtol=0, atol=1e-5)
1169
        return
1170
1171
    def test_basevectors_apex_invalid_scalar(self):
1172
        """Test warning and fill values for base vectors with bad inputs."""
1173
        self.apex_out = apexpy.Apex(date=2000, refh=10000)
1174
        invalid = np.full(shape=(3,), fill_value=np.nan)
1175
1176
        # Get the output and the warnings
1177
        with warnings.catch_warnings(record=True) as warn_rec:
1178
            basevec = self.apex_out.basevectors_apex(0, 0, 0)
1179
1180
        for i, bvec in enumerate(basevec):
1181
            if i < 2:
1182
                assert not np.allclose(bvec, invalid[:2])
1183
            else:
1184
                np.testing.assert_allclose(bvec, invalid)
1185
1186
        assert issubclass(warn_rec[-1].category, UserWarning)
1187
        assert 'set to NaN where' in str(warn_rec[-1].message)
1188
        return
1189
1190
1191
class TestApexGetMethods():
1192
    """Test the Apex `get` methods."""
1193
    def setup(self):
1194
        """Initialize all tests."""
1195
        self.apex_out = apexpy.Apex(date=2000, refh=300)
1196
1197
    def teardown(self):
1198
        """Clean up after each test."""
1199
        del self.apex_out
1200
1201
    @pytest.mark.parametrize("alat, aheight", [(10, 507.409702543805),
1202
                                               (60, 20313.026999999987)])
1203
    def test_get_apex(self, alat, aheight):
1204
        """Test the apex height retrieval results."""
1205
        alt = self.apex_out.get_apex(alat)
1206
        np.testing.assert_allclose(alt, aheight)
1207
        return
1208
1209
    @pytest.mark.parametrize("glat,glon,height,test_bmag",
1210
                             [([80], [100], [300], 5.100682377815247e-05),
1211
                              (range(50, 90, 8), range(0, 360, 80), [300] * 5,
1212
                               np.array([4.18657154e-05, 5.11118114e-05,
1213
                                         4.91969854e-05, 5.10519207e-05,
1214
                                         4.90054816e-05])),
1215
                              (90.0, 0, 1000, 3.7834718823432923e-05)])
1216
    def test_get_babs(self, glat, glon, height, test_bmag):
1217
        """Test the method to get the magnitude of the magnetic field."""
1218
        bmag = self.apex_out.get_babs(glat, glon, height)
1219
        np.testing.assert_allclose(bmag, test_bmag, rtol=0, atol=1e-5)
1220
        return
1221
1222
    @pytest.mark.parametrize("bad_lat", [(91), (-91)])
1223
    def test_get_apex_with_invalid_lat(self, bad_lat):
1224
        """Test get methods raise ValueError for invalid latitudes."""
1225
1226
        with pytest.raises(ValueError) as verr:
1227
            self.apex_out.get_apex(bad_lat)
1228
1229
        assert str(verr.value).find("must be in [-90, 90]") > 0
1230
        return
1231
1232
    @pytest.mark.parametrize("bad_lat", [(91), (-91)])
1233
    def test_get_babs_with_invalid_lat(self, bad_lat):
1234
        """Test get methods raise ValueError for invalid latitudes."""
1235
1236
        with pytest.raises(ValueError) as verr:
1237
            self.apex_out.get_babs(bad_lat, 15, 100)
1238
1239
        assert str(verr.value).find("must be in [-90, 90]") > 0
1240
        return
1241
1242
    @pytest.mark.parametrize("bound_lat", [(90), (-90)])
1243
    def test_get_at_lat_boundary(self, bound_lat):
1244
        """Test get methods at the latitude boundary, with allowed excess."""
1245
        # Get a latitude just beyond the limit
1246
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)
1247
1248
        # Get the two outputs, slight tolerance outside of boundary allowed
1249
        bound_out = self.apex_out.get_apex(bound_lat)
1250
        excess_out = self.apex_out.get_apex(excess_lat)
1251
1252
        # Test the outputs
1253
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
1254
        return
1255