Passed
Pull Request — develop (#86)
by Angeline
01:20
created

TestApexGetMethods.test_get_height_at_equator()   A

Complexity

Conditions 1

Size

Total Lines 13
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 4
dl 0
loc 13
rs 10
c 0
b 0
f 0
cc 1
nop 2
1
# -*- coding: utf-8 -*-
"""Test the apexpy.Apex class

Notes
-----
Whenever function outputs are tested against hard-coded numbers, the test
results (numbers) were obtained by running the code that is tested.  Therefore,
these tests below only check that nothing changes when refactoring, etc., and
not if the results are actually correct.

These results are expected to change when IGRF is updated.

"""
14
15
import datetime as dt
16
import numpy as np
17
import os
18
import pytest
19
import warnings
20
21
import apexpy
22
23
24
@pytest.fixture()
def igrf_file():
    """Temporarily hide the IGRF coefficient file from apexpy.

    The coefficient file located beside ``apexpy.helpers`` is renamed for the
    duration of the test and moved back afterwards.

    Yields
    ------
    original_file : str
        Path at which the coefficient file normally lives

    """
    # Ensure the coefficient file exists
    original_file = os.path.join(os.path.dirname(apexpy.helpers.__file__),
                                 'igrf13coeffs.txt')
    assert os.path.isfile(original_file)

    # Park the file beside the original so the rename does not depend on the
    # current working directory and never has to cross a filesystem boundary
    # (os.rename may fail when source and destination are on different ones)
    tmp_file = os.path.join(os.path.dirname(original_file), "temp_coeff.txt")

    # Move the coefficient file
    os.rename(original_file, tmp_file)
    try:
        yield original_file
    finally:
        # Move the coefficient file back, even if the generator is closed
        # before the test finishes normally
        os.rename(tmp_file, original_file)
40
41
42
def test_set_epoch_file_error(igrf_file):
    """Check that Apex initialization fails without the coefficient file.

    Parameters
    ----------
    igrf_file : str
        Path of the (currently missing) IGRF coefficient file

    """
    # The message is expected to lead with the missing file's path
    expected_start = "File {:} does not exist".format(igrf_file)

    with pytest.raises(OSError) as oerr:
        apexpy.Apex()

    message = str(oerr.value)
    assert message.startswith(expected_start)
50
51
52
class TestApexInit(object):
    """Test the initialization, comparison, and printing of the Apex class."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = None
        # Naive UTC "now"; same value as the deprecated dt.datetime.utcnow()
        self.test_date = dt.datetime.now(dt.timezone.utc).replace(tzinfo=None)
        self.test_refh = 0
        self.bad_file = 'foo/path/to/datafile.blah'

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.test_date, self.test_refh, self.bad_file

    def eval_date(self):
        """Evaluate the times in self.test_date and self.apex_out.

        If `self.test_date` is a date or datetime it is replaced, in place,
        by its decimal-year equivalent before the comparison.

        """
        # dt.datetime is a subclass of dt.date, so one check covers both
        if isinstance(self.test_date, dt.date):
            self.test_date = apexpy.helpers.toYearFraction(self.test_date)

        # Assert the times are the same on the order of tens of seconds.
        # Necessary to evaluate the current UTC
        np.testing.assert_almost_equal(self.test_date, self.apex_out.year, 6)
        return

    def eval_refh(self):
        """Evaluate the reference height in self.test_refh and self.apex_out."""
        eval_str = "".join(["expected reference height [",
                            "{:}] not equal to Apex ".format(self.test_refh),
                            "reference height ",
                            "[{:}]".format(self.apex_out.refh)])
        assert self.test_refh == self.apex_out.refh, eval_str
        return

    def test_init_defaults(self):
        """Test Apex class default initialization."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("in_date",
                             [2015, 2015.5, dt.date(2015, 1, 1),
                              dt.datetime(2015, 6, 1, 18, 23, 45)])
    def test_init_date(self, in_date):
        """Test Apex class with date initialization.

        Parameters
        ----------
        in_date : int, float, dt.date, or dt.datetime
            Date used to initialize the Apex object

        """
        self.test_date = in_date
        self.apex_out = apexpy.Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("new_date", [2015, 2015.5])
    def test_set_epoch(self, new_date):
        """Test successful setting of Apex epoch after initialization.

        Parameters
        ----------
        new_date : int or float
            Decimal year passed to `set_epoch`

        """
        # Evaluate the default initialization
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Update the epoch; the copy rebuilt from the repr keeps the old one
        ref_apex = eval(repr(self.apex_out))
        self.apex_out.set_epoch(new_date)
        assert ref_apex != self.apex_out
        self.test_date = new_date
        self.eval_date()
        return

    @pytest.mark.parametrize("in_refh", [0.0, 300.0, 30000.0, -1.0])
    def test_init_refh(self, in_refh):
        """Test Apex class with reference height initialization.

        Parameters
        ----------
        in_refh : float
            Reference height used to initialize the Apex object

        """
        self.test_refh = in_refh
        self.apex_out = apexpy.Apex(refh=self.test_refh)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("new_refh", [0.0, 300.0, 30000.0, -1.0])
    def test_set_refh(self, new_refh):
        """Test the method used to set the reference height after the init.

        Parameters
        ----------
        new_refh : float
            Reference height passed to `set_refh`

        """
        # Verify the defaults are set
        self.apex_out = apexpy.Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()

        # Update to a new reference height and test
        ref_apex = eval(repr(self.apex_out))
        self.apex_out.set_refh(new_refh)

        if self.test_refh == new_refh:
            # Setting the height it already had must leave the object equal
            assert ref_apex == self.apex_out
        else:
            assert ref_apex != self.apex_out
            self.test_refh = new_refh
        self.eval_refh()
        return

    def test_init_with_bad_datafile(self):
        """Test raises IOError with non-existent datafile input."""
        with pytest.raises(IOError) as oerr:
            apexpy.Apex(datafile=self.bad_file)
        assert str(oerr.value).startswith('Data file does not exist')
        return

    def test_init_with_bad_fortranlib(self):
        """Test raises IOError with non-existent fortranlib input."""
        with pytest.raises(IOError) as oerr:
            apexpy.Apex(fortranlib=self.bad_file)
        assert str(oerr.value).startswith('Fortran library does not exist')
        return

    def test_repr_eval(self):
        """Test the Apex.__repr__ results."""
        # Initialize the apex object
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Get and test the repr string
        out_str = repr(self.apex_out)
        assert out_str.startswith("apexpy.Apex(")

        # Test the ability to re-create the apex object from the repr string.
        # eval is acceptable here: the string is this object's own repr and
        # round-tripping it is exactly what is being tested
        new_apex = eval(out_str)
        assert new_apex == self.apex_out
        return

    def test_ne_other_class(self):
        """Test Apex class inequality to a different class."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # After eval_date, self.test_date is no longer an Apex-like object
        assert self.apex_out != self.test_date
        return

    def test_ne_missing_attr(self):
        """Test Apex class inequality when attributes are missing from one."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        ref_apex = eval(repr(self.apex_out))
        del ref_apex.RE

        # Check both operand orders
        assert ref_apex != self.apex_out
        assert self.apex_out != ref_apex
        return

    def test_eq_missing_attr(self):
        """Test Apex class equality when attributes are missing from both."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        ref_apex = eval(repr(self.apex_out))
        del ref_apex.RE, self.apex_out.RE

        assert ref_apex == self.apex_out
        return

    def test_str_eval(self):
        """Test the Apex.__str__ results."""
        # Initialize the apex object
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Get and test the printed string
        out_str = str(self.apex_out)
        assert "Decimal year" in out_str
        return
218
219
220
class TestApexMethod(object):
    """Test the Apex methods."""

    # Apex method / Fortran routine pairs compared by the interface tests,
    # each with the slice applied to the Fortran output before comparing
    _interface_cases = [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                        ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                        ("_qd2geo", "apxq2g", slice(None)),
                        ("_basevec", "apxg2q", slice(2, 4, 1))]

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.in_lat = 60
        self.in_lon = 15
        self.in_alt = 100

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.in_lat, self.in_lon, self.in_alt

    def get_input_args(self, method_name, precision=0.0):
        """Set the input arguments for the different Apex methods.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method
        precision : float
            Value for the precision (default=0.0)

        Returns
        -------
        in_args : list
            List of the appropriate input arguments

        """
        in_args = [self.in_lat, self.in_lon, self.in_alt]

        # Add precision, if needed
        if method_name in ["_qd2geo", "apxq2g", "apex2geo", "qd2geo",
                           "_apex2geo"]:
            in_args.append(precision)

        # Add a reference height, if needed
        if method_name in ["apxg2all"]:
            in_args.append(300)

        # Add a vector flag, if needed
        if method_name in ["apxg2all", "apxg2q"]:
            in_args.append(1)

        return in_args

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             _interface_cases)
    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon", [-179, -90, 0, 90, 180])
    def test_fortran_scalar_input(self, apex_method, fortran_method, fslice,
                                  lat, lon):
        """Tests Apex/fortran interface consistency for scalars.

        Parameters
        ----------
        apex_method : str
            Name of the Apex class method
        fortran_method : str
            Name of the `apexpy.fortranapex` function
        fslice : slice
            Slice applied to the Fortran output before comparison
        lat : int
            Input latitude
        lon : int
            Input longitude

        """
        # Set the input coordinates
        self.in_lat = lat
        self.in_lon = lon

        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Get the appropriate input arguments
        apex_args = self.get_input_args(apex_method)
        fortran_args = self.get_input_args(fortran_method)

        # Evaluate the equivalent function calls
        np.testing.assert_allclose(apex_func(*apex_args),
                                   fortran_func(*fortran_args)[fslice])
        return

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             _interface_cases)
    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon1,lon2", [(180, 180), (-180, -180),
                                           (180, -180), (-180, 180),
                                           (-345, 15), (375, 15)])
    def test_fortran_longitude_rollover(self, apex_method, fortran_method,
                                        fslice, lat, lon1, lon2):
        """Tests Apex/fortran interface consistency for longitude rollover.

        Parameters
        ----------
        apex_method : str
            Name of the Apex class method
        fortran_method : str
            Name of the `apexpy.fortranapex` function
        fslice : slice
            Slice applied to the Fortran output before comparison
        lat : int
            Input latitude
        lon1 : int
            Longitude given to the Apex method
        lon2 : int
            Longitude given to the Fortran function

        """
        # Set the fixed input coordinate
        self.in_lat = lat

        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Get the appropriate input arguments, one longitude per interface
        self.in_lon = lon1
        apex_args = self.get_input_args(apex_method)

        self.in_lon = lon2
        fortran_args = self.get_input_args(fortran_method)

        # Evaluate the equivalent function calls
        np.testing.assert_allclose(apex_func(*apex_args),
                                   fortran_func(*fortran_args)[fslice])
        return

    @pytest.mark.parametrize("arr_shape", [(2, 2), (4,), (1, 4)])
    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             _interface_cases)
    def test_fortran_array_input(self, arr_shape, apex_method, fortran_method,
                                 fslice):
        """Tests Apex/fortran interface consistency for array input.

        Parameters
        ----------
        arr_shape : tuple
            Shape given to the four-element latitude and altitude inputs
        apex_method : str
            Name of the Apex class method
        fortran_method : str
            Name of the `apexpy.fortranapex` function
        fslice : slice
            Slice applied to the Fortran output before comparison

        """
        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Set up the input arrays
        ref_lat = np.array([0, 30, 60, 90])
        ref_alt = np.array([100, 200, 300, 400])
        self.in_lat = ref_lat.reshape(arr_shape)
        self.in_alt = ref_alt.reshape(arr_shape)
        apex_args = self.get_input_args(apex_method)

        # Get the Apex class results
        aret = apex_func(*apex_args)

        # Get the fortran function results, one lat/alt pair at a time
        flats = list()
        flons = list()

        for lat, alt in zip(ref_lat, ref_alt):
            self.in_lat = lat
            self.in_alt = alt
            fortran_args = self.get_input_args(fortran_method)
            fret = fortran_func(*fortran_args)[fslice]
            flats.append(fret[0])
            flons.append(fret[1])

        flats = np.array(flats)
        flons = np.array(flons)

        # Evaluate results
        try:
            # This returned value is array of floats
            np.testing.assert_allclose(aret[0].astype(float),
                                       flats.reshape(arr_shape).astype(float))
            np.testing.assert_allclose(aret[1].astype(float),
                                       flons.reshape(arr_shape).astype(float))
        except ValueError:
            # This returned value is array of arrays
            alats = aret[0].reshape((4,))
            alons = aret[1].reshape((4,))
            for alat, alon, flat, flon in zip(alats, alons, flats, flons):
                np.testing.assert_array_almost_equal(alat, flat, 2)
                np.testing.assert_array_almost_equal(alon, flon, 2)

        return

    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon", [-179, -90, 0, 90, 180])
    def test_geo2apexall_scalar(self, lat, lon):
        """Test Apex/fortran geo2apexall interface consistency for scalars.

        Parameters
        ----------
        lat : int
            Input latitude
        lon : int
            Input longitude

        """
        # Get the Apex and Fortran results
        aret = self.apex_out._geo2apexall(lat, lon, self.in_alt)
        fret = apexpy.fortranapex.apxg2all(lat, lon, self.in_alt, 300, 1)

        # Evaluate each element in the results
        for aval, fval in zip(aret, fret):
            np.testing.assert_allclose(aval, fval)

    @pytest.mark.parametrize("arr_shape", [(2, 2), (4,), (1, 4)])
    def test_geo2apexall_array(self, arr_shape):
        """Test Apex/fortran geo2apexall interface consistency for arrays.

        Parameters
        ----------
        arr_shape : tuple
            Shape given to the four-element latitude and altitude inputs

        """
        # Set the input
        self.in_lat = np.array([0, 30, 60, 90])
        self.in_alt = np.array([100, 200, 300, 400])

        # Get the Apex class results
        aret = self.apex_out._geo2apexall(self.in_lat.reshape(arr_shape),
                                          self.in_lon,
                                          self.in_alt.reshape(arr_shape))

        # For each lat/alt pair, get the Fortran results
        fret = list()
        for i, lat in enumerate(self.in_lat):
            fret.append(apexpy.fortranapex.apxg2all(lat, self.in_lon,
                                                    self.in_alt[i], 300, 1))

        # Cycle through all returned values
        for i, ret in enumerate(aret):
            try:
                # This returned value is array of floats
                fret_test = np.array([fret[0][i], fret[1][i], fret[2][i],
                                      fret[3][i]]).reshape(arr_shape)
                np.testing.assert_allclose(ret.astype(float),
                                           fret_test.astype(float))
            except ValueError:
                # This returned value is array of arrays
                ret = ret.reshape((4,))
                for j, single_fret in enumerate(fret):
                    np.testing.assert_allclose(ret[j], single_fret[i])
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_consistency(self, in_coord, out_coord):
        """Test the self-consistency of the Apex convert method.

        Parameters
        ----------
        in_coord : str
            Name of the source coordinate system
        out_coord : str
            Name of the destination coordinate system

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Define the method name
        method_name = "2".join([in_coord, out_coord])

        # Get the method and method inputs
        convert_kwargs = {'height': self.in_alt, 'precision': 0.0}
        apex_args = self.get_input_args(method_name)
        apex_method = getattr(self.apex_out, method_name)

        # Define the slice needed to get equivalent output from the named
        # method; the last element is dropped when converting to geographic
        mslice = slice(0, -1, 1) if out_coord == "geo" else slice(None)

        # Get output using convert and named method
        convert_out = self.apex_out.convert(self.in_lat, self.in_lon, in_coord,
                                            out_coord, **convert_kwargs)
        method_out = apex_method(*apex_args)[mslice]

        # Compare both outputs, should be identical
        np.testing.assert_allclose(convert_out, method_out)
        return

    @pytest.mark.parametrize("bound_lat", [90, -90])
    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_at_lat_boundary(self, bound_lat, in_coord, out_coord):
        """Test the conversion at the latitude boundary, with allowed excess.

        Parameters
        ----------
        bound_lat : int
            Latitude at the edge of the valid range
        in_coord : str
            Name of the source coordinate system
        out_coord : str
            Name of the destination coordinate system

        """
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = self.apex_out.convert(bound_lat, 0, in_coord, out_coord)
        excess_out = self.apex_out.convert(excess_lat, 0, in_coord, out_coord)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    def test_convert_qd2apex_at_equator(self):
        """Test the quasi-dipole to apex conversion at the magnetic equator."""
        eq_out = self.apex_out.convert(lat=0.0, lon=0, source='qd', dest='apex',
                                       height=320.0)
        close_out = self.apex_out.convert(lat=0.001, lon=0, source='qd',
                                          dest='apex', height=320.0)
        np.testing.assert_allclose(eq_out, close_out, atol=1e-4)
        return

    @pytest.mark.parametrize("src", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("dest", ["geo", "apex", "qd"])
    def test_convert_withnan(self, src, dest):
        """Test Apex.convert success with NaN input.

        Parameters
        ----------
        src : str
            Name of the source coordinate system
        dest : str
            Name of the destination coordinate system

        """
        if src == dest:
            pytest.skip("Test not needed for same src and dest coordinates")

        # First half of the input is NaN, second half is finite
        num_nans = 5
        in_loc = np.arange(0, 10, dtype=float)
        in_loc[:num_nans] = np.nan

        out_loc = self.apex_out.convert(in_loc, in_loc, src, dest, height=320)

        for out in out_loc:
            assert np.all(np.isnan(out[:num_nans])), "NaN output expected"
            assert np.all(np.isfinite(out[num_nans:])), "Finite output expected"

        return

    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_convert_invalid_lat(self, bad_lat):
        """Test convert raises ValueError for invalid latitudes.

        Parameters
        ----------
        bad_lat : int
            Latitude outside of the valid range

        """
        with pytest.raises(ValueError) as verr:
            self.apex_out.convert(bad_lat, 0, 'geo', 'geo')

        assert "must be in [-90, 90]" in str(verr.value)
        return

    @pytest.mark.parametrize("coords", [("foobar", "geo"), ("geo", "foobar"),
                                        ("geo", "mlt")])
    def test_convert_invalid_transformation(self, coords):
        """Test raises ValueError for bad coordinates or a missing datetime.

        Parameters
        ----------
        coords : tuple
            Source and destination coordinate names handed to `convert`

        """
        if "mlt" in coords:
            estr = "datetime must be given for MLT calculations"
        else:
            estr = "Unknown coordinate transformation"

        with pytest.raises(ValueError) as verr:
            self.apex_out.convert(0, 0, *coords)

        # Inspect the raised exception itself, not pytest's ExceptionInfo
        assert estr in str(verr.value)
        return

    @pytest.mark.parametrize("method_name, out_comp",
                             [("geo2apex",
                               (55.94841766357422, 94.10684204101562)),
                              ("apex2geo",
                               (51.476322174072266, -66.22817993164062,
                                5.727287771151168e-06)),
                              ("geo2qd",
                               (56.531288146972656, 94.10684204101562)),
                              ("apex2qd", (60.498401178276744, 15.0)),
                              ("qd2apex", (59.49138097045895, 15.0))])
    def test_method_scalar_input(self, method_name, out_comp):
        """Test the user method against set values with scalars.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method
        out_comp : tuple
            Expected output values

        """
        # Get the desired methods
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(self.in_lat, self.in_lon, self.in_alt)

        # Evaluate the user output
        np.testing.assert_allclose(user_out, out_comp)

        for out_val in user_out:
            assert np.asarray(out_val).shape == (), "output is not a scalar"
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("method_args, out_shape",
                             [([[60, 60], 15, 100], (2,)),
                              ([60, [15, 15], 100], (2,)),
                              ([60, 15, [100, 100]], (2,)),
                              ([[50, 60], [15, 16], [100, 200]], (2,))])
    def test_method_broadcast_input(self, in_coord, out_coord, method_args,
                                    out_shape):
        """Test the user method with inputs that require some broadcasting.

        Parameters
        ----------
        in_coord : str
            Name of the source coordinate system
        out_coord : str
            Name of the destination coordinate system
        method_args : list
            Latitude, longitude, and height inputs of mixed shape
        out_shape : tuple
            Expected shape of every output value

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(*method_args)

        # Evaluate the user output
        for out_val in user_out:
            assert hasattr(out_val, 'shape'), "output coordinate isn't np.array"
            assert out_val.shape == out_shape
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_method_invalid_lat(self, in_coord, out_coord, bad_lat):
        """Test user methods raise ValueError for invalid latitudes.

        Parameters
        ----------
        in_coord : str
            Name of the source coordinate system
        out_coord : str
            Name of the destination coordinate system
        bad_lat : int
            Latitude outside of the valid range

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        with pytest.raises(ValueError) as verr:
            user_method(bad_lat, 15, 100)

        assert "must be in [-90, 90]" in str(verr.value)
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("bound_lat", [90, -90])
    def test_method_at_lat_boundary(self, in_coord, out_coord, bound_lat):
        """Test user methods at the latitude boundary, with allowed excess.

        Parameters
        ----------
        in_coord : str
            Name of the source coordinate system
        out_coord : str
            Name of the destination coordinate system
        bound_lat : int
            Latitude at the edge of the valid range

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        # Get a latitude just beyond the limit
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = user_method(bound_lat, 0, 100)
        excess_out = user_method(excess_lat, 0, 100)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    def test_geo2apex_undefined_warning(self):
        """Test geo2apex warning and fill values for an undefined location."""

        # Update the apex object
        self.apex_out = apexpy.Apex(date=2000, refh=10000)

        # Get the output and the warnings
        with warnings.catch_warnings(record=True) as warn_rec:
            user_lat, user_lon = self.apex_out.geo2apex(0, 0, 0)

        assert np.isnan(user_lat)
        assert np.isfinite(user_lon)
        assert len(warn_rec) == 1
        assert issubclass(warn_rec[-1].category, UserWarning)
        assert 'latitude set to NaN where' in str(warn_rec[-1].message)
        return

    @pytest.mark.parametrize("method_name", ["apex2qd", "qd2apex"])
    @pytest.mark.parametrize("delta_h", [1.0e-6, -1.0e-6])
    def test_quasidipole_apexheight_close(self, method_name, delta_h):
        """Test quasi-dipole success with a height close to the reference.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method
        delta_h : float
            Offset added to the reference height

        """
        qd_method = getattr(self.apex_out, method_name)
        in_args = [0, 15, self.apex_out.refh + delta_h]
        out_coords = qd_method(*in_args)

        # The outputs are expected to match the leading inputs
        for i, out_val in enumerate(out_coords):
            np.testing.assert_almost_equal(out_val, in_args[i], decimal=3)
        return

    @pytest.mark.parametrize("method_name, hinc, msg",
                             [("apex2qd", 1.0, "is > apex height"),
                              ("qd2apex", -1.0, "is < reference height")])
    def test_quasidipole_raises_apexheight(self, method_name, hinc, msg):
        """Test quasi-dipole raises ApexHeightError for a disallowed height.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method
        hinc : float
            Offset added to the reference height to form the input height
        msg : str
            Text expected within the error message

        """
        qd_method = getattr(self.apex_out, method_name)

        with pytest.raises(apexpy.ApexHeightError) as aerr:
            qd_method(0, 15, self.apex_out.refh + hinc)

        # Inspect the raised exception itself, not pytest's ExceptionInfo
        assert msg in str(aerr.value)
        return
652
653
654
class TestApexMLTMethods(object):
    """Test the Apex Magnetic Local Time (MLT) methods."""

    def setup_method(self):
        """Initialize all tests.

        Uses the xunit-style `setup_method` hook, since the nose-style
        `setup` hook is no longer run by recent pytest releases.

        """
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.in_time = dt.datetime(2000, 2, 3, 4, 5, 6)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.in_time

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    def test_convert_to_mlt(self, in_coord):
        """Test the conversions to MLT using Apex convert.

        Parameters
        ----------
        in_coord : str
            Input coordinate system passed to `convert`

        """
        # Get the magnetic longitude from the appropriate method; for the
        # non-geographic inputs the input longitude is used directly
        if in_coord == "geo":
            mlon = self.apex_out.geo2apex(60, 15, 100)[1]
        else:
            mlon = 15

        # Get the output MLT values
        convert_mlt = self.apex_out.convert(60, 15, in_coord, 'mlt',
                                            height=100, ssheight=2e5,
                                            datetime=self.in_time)[1]
        method_mlt = self.apex_out.mlon2mlt(mlon, self.in_time, ssheight=2e5)

        # Test the outputs
        np.testing.assert_allclose(convert_mlt, method_mlt)
        return

    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_mlt_to_lon(self, out_coord):
        """Test the conversions from MLT using Apex convert.

        Parameters
        ----------
        out_coord : str
            Output coordinate system passed to `convert`

        """
        # Get the output longitudes
        convert_out = self.apex_out.convert(60, 15, 'mlt', out_coord,
                                            height=100, ssheight=2e5,
                                            datetime=self.in_time,
                                            precision=1e-2)
        mlon = self.apex_out.mlt2mlon(15, self.in_time, ssheight=2e5)

        # Build the comparison values from the individual methods
        if out_coord == "geo":
            # Drop the last element of the apex2geo output before comparing
            method_out = self.apex_out.apex2geo(60, mlon, 100,
                                                precision=1e-2)[:-1]
        elif out_coord == "qd":
            method_out = self.apex_out.apex2qd(60, mlon, 100)
        else:
            method_out = (60, mlon)

        # Evaluate the outputs
        np.testing.assert_allclose(convert_out, method_out)
        return

    def test_convert_geo2mlt_nodate(self):
        """Test convert from geo to MLT raises ValueError with no datetime."""
        with pytest.raises(ValueError):
            self.apex_out.convert(60, 15, 'geo', 'mlt')
        return

    @pytest.mark.parametrize("mlon_kwargs,test_mlt",
                             [({}, 23.019629923502603),
                              ({"ssheight": 100000}, 23.026712036132814)])
    def test_mlon2mlt_scalar_inputs(self, mlon_kwargs, test_mlt):
        """Test mlon2mlt with scalar inputs.

        Parameters
        ----------
        mlon_kwargs : dict
            Keyword arguments passed to `mlon2mlt`
        test_mlt : float
            Expected MLT output

        """
        mlt = self.apex_out.mlon2mlt(0, self.in_time, **mlon_kwargs)

        np.testing.assert_allclose(mlt, test_mlt)
        assert np.asarray(mlt).shape == ()
        return

    @pytest.mark.parametrize("mlt_kwargs,test_mlon",
                             [({}, 14.705535888671875),
                              ({"ssheight": 100000}, 14.599319458007812)])
    def test_mlt2mlon_scalar_inputs(self, mlt_kwargs, test_mlon):
        """Test mlt2mlon with scalar inputs.

        Parameters
        ----------
        mlt_kwargs : dict
            Keyword arguments passed to `mlt2mlon`
        test_mlon : float
            Expected magnetic longitude output

        """
        mlon = self.apex_out.mlt2mlon(0, self.in_time, **mlt_kwargs)

        np.testing.assert_allclose(mlon, test_mlon)
        assert np.asarray(mlon).shape == ()
        return

    @pytest.mark.parametrize("mlon,test_mlt",
                             [([0, 180], [23.019261, 11.019261]),
                              (np.array([0, 180]), [23.019261, 11.019261]),
                              (np.array([[0], [180]]),
                               np.array([[23.019261], [11.019261]])),
                              ([[0, 180], [0, 180]], [[23.019261, 11.019261],
                                                      [23.019261, 11.019261]]),
                              (range(0, 361, 30),
                               [23.01963, 1.01963, 3.01963, 5.01963, 7.01963,
                                9.01963, 11.01963, 13.01963, 15.01963, 17.01963,
                                19.01963, 21.01963, 23.01963])])
    def test_mlon2mlt_array(self, mlon, test_mlt):
        """Test mlon2mlt with array inputs.

        Parameters
        ----------
        mlon : array-like
            Input magnetic longitudes
        test_mlt : array-like
            Expected MLT output, with the same shape as the input

        """
        mlt = self.apex_out.mlon2mlt(mlon, self.in_time)

        assert mlt.shape == np.asarray(test_mlt).shape
        np.testing.assert_allclose(mlt, test_mlt, rtol=1e-4)
        return

    @pytest.mark.parametrize("mlt,test_mlon",
                             [([0, 12], [14.705551, 194.705551]),
                              (np.array([0, 12]), [14.705551, 194.705551]),
                              (np.array([[0], [12]]),
                               np.array([[14.705551], [194.705551]])),
                              ([[0, 12], [0, 12]], [[14.705551, 194.705551],
                                                    [14.705551, 194.705551]]),
                              (range(0, 25, 2),
                               [14.705551, 44.705551, 74.705551, 104.705551,
                                134.705551, 164.705551, 194.705551, 224.705551,
                                254.705551, 284.705551, 314.705551, 344.705551,
                                14.705551])])
    def test_mlt2mlon_array(self, mlt, test_mlon):
        """Test mlt2mlon with array inputs.

        Parameters
        ----------
        mlt : array-like
            Input magnetic local times
        test_mlon : array-like
            Expected magnetic longitude output, with the same shape as the
            input

        """
        mlon = self.apex_out.mlt2mlon(mlt, self.in_time)

        assert mlon.shape == np.asarray(test_mlon).shape
        np.testing.assert_allclose(mlon, test_mlon, rtol=1e-4)
        return

    @pytest.mark.parametrize("method_name", ["mlon2mlt", "mlt2mlon"])
    def test_mlon2mlt_diffdates(self, method_name):
        """Test that MLT varies with universal time.

        Parameters
        ----------
        method_name : str
            Name of the Apex conversion method to call

        """
        apex_method = getattr(self.apex_out, method_name)
        mlt1 = apex_method(0, self.in_time)
        mlt2 = apex_method(0, self.in_time + dt.timedelta(hours=1))

        assert mlt1 != mlt2
        return

    @pytest.mark.parametrize("mlt_offset", [1.0, 10.0])
    def test_mlon2mlt_offset(self, mlt_offset):
        """Test the time wrapping logic for the MLT.

        Parameters
        ----------
        mlt_offset : float
            MLT offset; the input longitude is shifted by -15 times this value

        """
        mlt1 = self.apex_out.mlon2mlt(0.0, self.in_time)
        mlt2 = self.apex_out.mlon2mlt(-15.0 * mlt_offset,
                                      self.in_time) + mlt_offset

        np.testing.assert_allclose(mlt1, mlt2)
        return

    @pytest.mark.parametrize("mlon_offset", [15.0, 150.0])
    def test_mlt2mlon_offset(self, mlon_offset):
        """Test the time wrapping logic for the magnetic longitude.

        Parameters
        ----------
        mlon_offset : float
            Longitude offset; the input MLT is this value divided by 15

        """
        mlon1 = self.apex_out.mlt2mlon(0, self.in_time)
        mlon2 = self.apex_out.mlt2mlon(mlon_offset / 15.0,
                                       self.in_time) - mlon_offset

        np.testing.assert_allclose(mlon1, mlon2)
        return

    @pytest.mark.parametrize("order", [["mlt", "mlon"], ["mlon", "mlt"]])
    @pytest.mark.parametrize("start_val", [0, 6, 12, 18, 22])
    def test_convert_and_return(self, order, start_val):
        """Test the conversion to magnetic longitude or MLT and back again.

        Parameters
        ----------
        order : list
            Two coordinate names; joined with '2' to form the name of the
            first method, and in reverse order for the second method
        start_val : int
            Value supplied to the first conversion

        """
        first_method = getattr(self.apex_out, "2".join(order))
        second_method = getattr(self.apex_out, "2".join(reversed(order)))

        middle_val = first_method(start_val, self.in_time)
        end_val = second_method(middle_val, self.in_time)

        np.testing.assert_allclose(start_val, end_val)
        return
817
818
819
class TestApexMapMethods(object):
    """Test the Apex height mapping methods."""

    def setup_method(self):
        """Initialize all tests.

        Uses the xunit-style `setup_method` hook, since the nose-style
        `setup` hook is no longer run by recent pytest releases.

        """
        self.apex_out = apexpy.Apex(date=2000, refh=300)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 10000],
                               [31.841466903686523, 17.916635513305664,
                                1.7075473124350538e-6]),
                              ([30, 170, 100, 500, False, 1e-2],
                               [25.727270126342773, 169.60546875,
                                0.00017573432705830783]),
                              ([60, 15, 100, 10000, True],
                               [-25.424888610839844, 27.310426712036133,
                                1.2074182222931995e-6]),
                              ([30, 170, 100, 500, True, 1e-2],
                               [-13.76642894744873, 164.24259948730469,
                                0.00056820799363777041])])
    def test_map_to_height(self, in_args, test_mapped):
        """Test the map_to_height function.

        Parameters
        ----------
        in_args : list
            Positional arguments passed to `map_to_height`
        test_mapped : list
            Expected output values

        """
        mapped = self.apex_out.map_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, atol=1e-6)
        return

    def test_map_to_height_same_height(self):
        """Test the map_to_height function when mapping to same height."""
        mapped = self.apex_out.map_to_height(60, 15, 100, 100, conjugate=False,
                                             precision=1e-10)
        np.testing.assert_allclose(mapped, (60.0, 15.000003814697266, 0.0),
                                   rtol=1e-5)
        return

    @pytest.mark.parametrize('arr_shape', [(2,), (2, 2), (1, 4)])
    @pytest.mark.parametrize('ivec', range(0, 4))
    def test_map_to_height_array_location(self, arr_shape, ivec):
        """Test map_to_height with array input.

        Parameters
        ----------
        arr_shape : tuple
            Shape of the array-like input
        ivec : int
            Index of the positional argument that is turned into an array

        """
        # Set the base input and output values
        in_args = [60, 15, 100, 100]
        test_mapped = [60, 15.00000381, 0.0]

        # Update inputs for one vectorized value
        in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Calculate and test function
        mapped = self.apex_out.map_to_height(*in_args)
        for i, test_val in enumerate(test_mapped):
            assert mapped[i].shape == arr_shape
            np.testing.assert_allclose(mapped[i], test_val, rtol=1e-5)
        return

    @pytest.mark.parametrize("method_name,in_args",
                             [("map_to_height", [0, 15, 100, 10000]),
                              ("map_E_to_height",
                               [0, 15, 100, 10000, [1, 2, 3]]),
                              ("map_V_to_height",
                               [0, 15, 100, 10000, [1, 2, 3]])])
    def test_mapping_height_raises_ApexHeightError(self, method_name, in_args):
        """Test the height mapping methods raise ApexHeightError.

        Parameters
        ----------
        method_name : str
            Name of the Apex mapping method to call
        in_args : list
            Positional arguments passed to the mapping method

        """
        apex_method = getattr(self.apex_out, method_name)

        with pytest.raises(apexpy.ApexHeightError) as aerr:
            apex_method(*in_args)

        assert aerr.match("is > apex height")
        return

    @pytest.mark.parametrize("method_name",
                             ["map_E_to_height", "map_V_to_height"])
    @pytest.mark.parametrize("ev_input", [([1, 2, 3, 4, 5]),
                                          ([[1, 2], [3, 4], [5, 6], [7, 8]])])
    def test_mapping_EV_bad_shape(self, method_name, ev_input):
        """Test height mapping of E/V with badly shaped input raises Error.

        Parameters
        ----------
        method_name : str
            Name of the Apex mapping method to call
        ev_input : list
            Badly shaped electric field or drift input

        """
        apex_method = getattr(self.apex_out, method_name)
        in_args = [60, 15, 100, 500, ev_input]
        with pytest.raises(ValueError) as verr:
            apex_method(*in_args)

        assert "must be (3, N) or (3,) ndarray" in str(verr.value)
        return

    def test_mapping_EV_bad_flag(self):
        """Test _map_EV_to_height raises error for bad data type flag."""
        with pytest.raises(ValueError) as verr:
            self.apex_out._map_EV_to_height(60, 15, 100, 500, [1, 2, 3], "P")

        assert "unknown electric field/drift flag" in str(verr.value)
        return

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 500, [1, 2, 3]],
                               [0.71152183, 2.35624876, 0.57260784]),
                              ([60, 15, 100, 500, [2, 3, 4]],
                               [1.56028502, 3.43916636, 0.78235384]),
                              ([60, 15, 100, 1000, [1, 2, 3]],
                               [0.67796492, 2.08982134, 0.55860785]),
                              ([60, 15, 200, 500, [1, 2, 3]],
                               [0.72377397, 2.42737471, 0.59083726]),
                              ([60, 30, 100, 500, [1, 2, 3]],
                               [0.68626344, 2.37530133, 0.60060124]),
                              ([70, 15, 100, 500, [1, 2, 3]],
                               [0.72760378, 2.18082305, 0.29141979])])
    def test_map_E_to_height_scalar_location(self, in_args, test_mapped):
        """Test mapping of E-field to a specified height.

        Parameters
        ----------
        in_args : list
            Positional arguments passed to `map_E_to_height`
        test_mapped : list
            Expected mapped output

        """
        mapped = self.apex_out.map_E_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return

    @pytest.mark.parametrize('ev_flag, test_mapped',
                             [('E', [0.71152183, 2.35624876, 0.57260784]),
                              ('V', [0.81971957, 2.84512495, 0.69545001])])
    @pytest.mark.parametrize('arr_shape', [(2,), (5,)])
    @pytest.mark.parametrize('ivec', range(0, 5))
    def test_map_EV_to_height_array_location(self, ev_flag, test_mapped,
                                             arr_shape, ivec):
        """Test mapping of E-field/drift to a specified height with arrays.

        Parameters
        ----------
        ev_flag : str
            'E' or 'V', used to build the mapping method name
        test_mapped : list
            Expected mapped output for a single location
        arr_shape : tuple
            Shape of the array-like location input
        ivec : int
            Index of the positional argument that is turned into an array;
            an index of 4 leaves all location inputs as scalars

        """
        # Set the base input and output values.  `np.prod` is used because
        # the `np.product` alias is not available in NumPy 2.0 and later.
        eshape = (3,) + tuple(arr_shape)
        edata = np.array([[1, 2, 3]] * int(np.prod(arr_shape))).transpose()
        in_args = [60, 15, 100, 500, edata.reshape(eshape)]

        # Update inputs for one vectorized value if this is a location input
        if ivec < 4:
            in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Get the mapped output
        apex_method = getattr(self.apex_out,
                              "map_{:s}_to_height".format(ev_flag))
        mapped = apex_method(*in_args)

        # Test the results
        for i, test_val in enumerate(test_mapped):
            assert mapped[i].shape == arr_shape
            np.testing.assert_allclose(mapped[i], test_val, rtol=1e-5)
        return

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 500, [1, 2, 3]],
                               [0.81971957, 2.84512495, 0.69545001]),
                              ([60, 15, 100, 500, [2, 3, 4]],
                               [1.83027746, 4.14346436, 0.94764179]),
                              ([60, 15, 100, 1000, [1, 2, 3]],
                               [0.92457698, 3.14997661, 0.85135187]),
                              ([60, 15, 200, 500, [1, 2, 3]],
                               [0.80388262, 2.79321504, 0.68285158]),
                              ([60, 30, 100, 500, [1, 2, 3]],
                               [0.76141245, 2.87884673, 0.73655941]),
                              ([70, 15, 100, 500, [1, 2, 3]],
                               [0.84681866, 2.5925821, 0.34792655])])
    def test_map_V_to_height_scalar_location(self, in_args, test_mapped):
        """Test mapping of velocity to a specified height.

        Parameters
        ----------
        in_args : list
            Positional arguments passed to `map_V_to_height`
        test_mapped : list
            Expected mapped output

        """
        mapped = self.apex_out.map_V_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return
978
979
980
class TestApexBasevectorMethods(object):
    """Test the Apex height base vector methods."""

    def setup_method(self):
        """Initialize all tests.

        Uses the xunit-style `setup_method` hook, since the nose-style
        `setup` hook is no longer run by recent pytest releases.

        """
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.lat = 60
        self.lon = 15
        self.height = 100
        self.test_basevec = None

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.test_basevec, self.lat, self.lon, self.height

    def get_comparison_results(self, bv_coord, coords, precision):
        """Get the base vector results using the hidden function for comparison.

        Parameters
        ----------
        bv_coord : str
            Basevector coordinate scheme, expects one of 'apex', 'qd',
            or 'bvectors_apex'
        coords : str
            Expects one of 'geo', 'apex', or 'qd'
        precision : float
            Float specifying precision

        Notes
        -----
        The comparison values are stored in `self.test_basevec`; nothing is
        returned.

        """
        # Express the test location in geographic coordinates
        if coords == "geo":
            glat = self.lat
            glon = self.lon
        else:
            apex_method = getattr(self.apex_out, "{:s}2geo".format(coords))
            glat, glon, _ = apex_method(self.lat, self.lon, self.height,
                                        precision=precision)

        if bv_coord == 'qd':
            self.test_basevec = self.apex_out._basevec(glat, glon, self.height)
        elif bv_coord == 'apex':
            # Use the same height as the other branches (set in setup_method)
            (_, _, _, _, f1, f2, _, d1, d2, d3, _, e1, e2,
             e3) = self.apex_out._geo2apexall(glat, glon, self.height)
            self.test_basevec = (f1, f2, d1, d2, d3, e1, e2, e3)
        else:
            # These are set results that need to be updated with IGRF
            if coords == "geo":
                self.test_basevec = (
                    np.array([4.42368795e-05, 4.42368795e-05]),
                    np.array([[0.01047826, 0.01047826],
                              [0.33089194, 0.33089194],
                              [-1.04941, -1.04941]]),
                    np.array([5.3564698e-05, 5.3564698e-05]),
                    np.array([[0.00865356, 0.00865356],
                              [0.27327004, 0.27327004],
                              [-0.8666646, -0.8666646]]))
            elif coords == "apex":
                self.test_basevec = (
                    np.array([4.48672735e-05, 4.48672735e-05]),
                    np.array([[-0.12510721, -0.12510721],
                              [0.28945938, 0.28945938],
                              [-1.1505738, -1.1505738]]),
                    np.array([6.38577444e-05, 6.38577444e-05]),
                    np.array([[-0.08790194, -0.08790194],
                              [0.2033779, 0.2033779],
                              [-0.808408, -0.808408]]))
            else:
                self.test_basevec = (
                    np.array([4.46348578e-05, 4.46348578e-05]),
                    np.array([[-0.12642345, -0.12642345],
                              [0.29695055, 0.29695055],
                              [-1.1517885, -1.1517885]]),
                    np.array([6.38626285e-05, 6.38626285e-05]),
                    np.array([[-0.08835986, -0.08835986],
                              [0.20754464, 0.20754464],
                              [-0.8050078, -0.8050078]]))

        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    @pytest.mark.parametrize("coords,precision",
                             [("geo", 1e-10), ("apex", 1.0e-2), ("qd", 1.0e-2)])
    def test_basevectors_scalar(self, bv_coord, coords, precision):
        """Test the base vector calculations with scalars.

        Parameters
        ----------
        bv_coord : str
            Basevector coordinate scheme, one of 'qd' or 'apex'
        coords : str
            Coordinate system of the input location
        precision : float
            Precision passed to the base vector method

        """
        # Get the base vectors
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(self.lat, self.lon, self.height, coords=coords,
                              precision=precision)
        self.get_comparison_results(bv_coord, coords, precision)
        if bv_coord == "apex":
            # Not able to compare indices 2, 3, 4, and 5
            basevec = list(basevec)
            del basevec[2:6]

        # Test the results
        for i, vec in enumerate(basevec):
            np.testing.assert_allclose(vec, self.test_basevec[i])
        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    def test_basevectors_scalar_shape(self, bv_coord):
        """Test the shape of the scalar output.

        Parameters
        ----------
        bv_coord : str
            Basevector coordinate scheme, one of 'qd' or 'apex'

        """
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(self.lat, self.lon, self.height)

        # The first two outputs have two components, the rest have three
        for i, vec in enumerate(basevec):
            if i < 2:
                assert vec.shape == (2,)
            else:
                assert vec.shape == (3,)
        return

    @pytest.mark.parametrize('arr_shape', [(2,), (5,)])
    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    @pytest.mark.parametrize("ivec", range(3))
    def test_basevectors_array(self, arr_shape, bv_coord, ivec):
        """Test the output shape for array inputs.

        Parameters
        ----------
        arr_shape : tuple
            Shape of the array-like input
        bv_coord : str
            Basevector coordinate scheme, one of 'qd' or 'apex'
        ivec : int
            Index of the positional argument that is turned into an array

        """
        # Define the input arguments
        in_args = [self.lat, self.lon, self.height]
        in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Get the basevectors
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(*in_args, coords='geo', precision=1e-10)
        self.get_comparison_results(bv_coord, "geo", 1e-10)
        if bv_coord == "apex":
            # Not able to compare indices 2, 3, 4, and 5
            basevec = list(basevec)
            del basevec[2:6]

        # Evaluate the shape and the values
        for i, vec in enumerate(basevec):
            test_shape = (2 if i < 2 else 3,) + tuple(arr_shape)
            assert vec.shape == test_shape
            assert np.all(self.test_basevec[i][0] == vec[0])
            assert np.all(self.test_basevec[i][1] == vec[1])
        return

    @pytest.mark.parametrize("coords", ["geo", "apex", "qd"])
    def test_bvectors_apex(self, coords):
        """Test the bvectors_apex method.

        Parameters
        ----------
        coords : str
            Coordinate system of the input location

        """
        in_args = [[self.lat, self.lat], [self.lon, self.lon],
                   [self.height, self.height]]
        self.get_comparison_results("bvectors_apex", coords, 1e-10)

        basevec = self.apex_out.bvectors_apex(*in_args, coords=coords,
                                              precision=1e-10)
        for i, vec in enumerate(basevec):
            np.testing.assert_array_almost_equal(vec, self.test_basevec[i],
                                                 decimal=5)
        return

    def test_basevectors_apex_extra_values(self):
        """Test specific values in the apex base vector output."""
        # Set the testing arrays
        self.test_basevec = [np.array([0.092637, -0.245951, 0.938848]),
                             np.array([0.939012, 0.073416, -0.07342]),
                             np.array([0.055389, 1.004155, 0.257594]),
                             np.array([0, 0, 1.065135])]

        # Get the desired output
        basevec = self.apex_out.basevectors_apex(0, 15, 100, coords='geo')

        # Test the values not covered by `test_basevectors_scalar`
        for itest, ibase in enumerate(range(2, 6)):
            np.testing.assert_allclose(basevec[ibase],
                                       self.test_basevec[itest], rtol=1e-4)
        return

    @pytest.mark.parametrize("lat", range(0, 90, 10))
    @pytest.mark.parametrize("lon", range(0, 360, 15))
    def test_basevectors_apex_delta(self, lat, lon):
        """Test that vectors are calculated correctly.

        Parameters
        ----------
        lat : int
            Input latitude
        lon : int
            Input longitude

        Notes
        -----
        The summed products of the f and g vectors, and of the d and e
        vectors, are expected to be one for matching indices and zero
        otherwise.

        """
        # Get the apex base vectors and sort them for easy testing
        (f1, f2, f3, g1, g2, g3, d1, d2, d3, e1, e2,
         e3) = self.apex_out.basevectors_apex(lat, lon, 500)
        fvec = [np.append(f1, 0), np.append(f2, 0), f3]
        gvec = [g1, g2, g3]
        dvec = [d1, d2, d3]
        evec = [e1, e2, e3]

        for idelta in range(3):
            for jdelta in range(3):
                delta = 1 if idelta == jdelta else 0
                np.testing.assert_allclose(
                    np.sum(fvec[idelta] * gvec[jdelta]), delta, rtol=0,
                    atol=1e-5)
                np.testing.assert_allclose(
                    np.sum(dvec[idelta] * evec[jdelta]), delta, rtol=0,
                    atol=1e-5)
        return

    def test_basevectors_apex_invalid_scalar(self):
        """Test warning and fill values for base vectors with bad inputs."""
        self.apex_out = apexpy.Apex(date=2000, refh=10000)
        invalid = np.full(shape=(3,), fill_value=np.nan)

        # Get the output and the warnings
        with warnings.catch_warnings(record=True) as warn_rec:
            basevec = self.apex_out.basevectors_apex(0, 0, 0)

        for i, bvec in enumerate(basevec):
            if i < 2:
                # NOTE(review): `np.allclose` never treats NaN as equal by
                # default, so this assertion cannot fail; confirm the
                # intended check for the first two vectors.
                assert not np.allclose(bvec, invalid[:2])
            else:
                np.testing.assert_allclose(bvec, invalid)

        # Fail clearly if nothing was recorded, instead of an IndexError
        assert len(warn_rec) > 0, "no warning was raised"
        assert issubclass(warn_rec[-1].category, UserWarning)
        assert 'set to NaN where' in str(warn_rec[-1].message)
        return
1191
1192
1193
class TestApexGetMethods(object):
    """Test the Apex `get` methods."""

    def setup_method(self):
        """Initialize all tests.

        Uses the xunit-style `setup_method` hook, since the nose-style
        `setup` hook is no longer run by recent pytest releases.

        """
        self.apex_out = apexpy.Apex(date=2000, refh=300)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out

    @pytest.mark.parametrize("alat, aheight",
                             [(10, 507.409702543805),
                              (60, 20313.026999999987),
                              ([10, 60],
                               [507.409702543805, 20313.026999999987]),
                              ([[10], [60]],
                               [[507.409702543805], [20313.026999999987]])])
    def test_get_apex(self, alat, aheight):
        """Test the apex height retrieval results.

        Parameters
        ----------
        alat : int or list
            Input latitude(s) passed to `get_apex`
        aheight : float or list
            Expected apex height(s)

        """
        alt = self.apex_out.get_apex(alat)
        np.testing.assert_allclose(alt, aheight)
        return

    @pytest.mark.parametrize("glat,glon,height,test_bmag",
                             [([80], [100], [300], 5.100682377815247e-05),
                              ([80, 80], [100], [300],
                               [5.100682377815247e-05, 5.100682377815247e-05]),
                              ([[80], [80]], [100], [300],
                               [[5.100682377815247e-05],
                                [5.100682377815247e-05]]),
                              (range(50, 90, 8), range(0, 360, 80), [300] * 5,
                               np.array([4.18657154e-05, 5.11118114e-05,
                                         4.91969854e-05, 5.10519207e-05,
                                         4.90054816e-05])),
                              (90.0, 0, 1000, 3.7834718823432923e-05)])
    def test_get_babs(self, glat, glon, height, test_bmag):
        """Test the method to get the magnitude of the magnetic field.

        Parameters
        ----------
        glat : array-like or float
            Input latitude(s)
        glon : array-like or float
            Input longitude(s)
        height : array-like or float
            Input height(s)
        test_bmag : array-like or float
            Expected magnetic field magnitude(s)

        """
        bmag = self.apex_out.get_babs(glat, glon, height)

        # NOTE(review): the absolute tolerance is of the same order as the
        # expected values (~5e-5), so this comparison is loose; consider a
        # relative tolerance once the achievable precision is confirmed.
        np.testing.assert_allclose(bmag, test_bmag, rtol=0, atol=1e-5)
        return

    @pytest.mark.parametrize("bad_lat", [(91), (-91)])
    def test_get_apex_with_invalid_lat(self, bad_lat):
        """Test `get_apex` raises ValueError for invalid latitudes.

        Parameters
        ----------
        bad_lat : int
            Latitude outside of the allowed range

        """
        with pytest.raises(ValueError) as verr:
            self.apex_out.get_apex(bad_lat)

        assert "must be in [-90, 90]" in str(verr.value)
        return

    @pytest.mark.parametrize("bad_lat", [(91), (-91)])
    def test_get_babs_with_invalid_lat(self, bad_lat):
        """Test `get_babs` raises ValueError for invalid latitudes.

        Parameters
        ----------
        bad_lat : int
            Latitude outside of the allowed range

        """
        with pytest.raises(ValueError) as verr:
            self.apex_out.get_babs(bad_lat, 15, 100)

        assert "must be in [-90, 90]" in str(verr.value)
        return

    @pytest.mark.parametrize("bound_lat", [(90), (-90)])
    def test_get_at_lat_boundary(self, bound_lat):
        """Test get methods at the latitude boundary, with allowed excess.

        Parameters
        ----------
        bound_lat : int
            Latitude at the boundary of the allowed range

        """
        # Get a latitude just beyond the limit
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = self.apex_out.get_apex(bound_lat)
        excess_out = self.apex_out.get_apex(excess_lat)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    @pytest.mark.parametrize("apex_height", [-100, 0, 300, 10000])
    def test_get_height_at_equator(self, apex_height):
        """Test that `get_height` returns apex height at equator.

        Parameters
        ----------
        apex_height : float
            Apex height

        """
        eq_height = self.apex_out.get_height(0.0, apex_height)

        # Compare with a tight absolute tolerance instead of exact equality,
        # so floating point round-off cannot cause a spurious failure
        np.testing.assert_allclose(eq_height, apex_height, rtol=0, atol=1e-7)
        return

    @pytest.mark.parametrize("lat, height", [
        (-90, -6371.009), (-80, -6088.438503309167), (-70, -5274.8091854339655),
        (-60, -4028.256749999999), (-50, -2499.1338178752017),
        (-40, -871.8751821247979), (-30, 657.2477500000014),
        (-20, 1903.8001854339655), (-10, 2717.4295033091657), (0, 3000.0),
        (10, 2717.4295033091657), (20, 1903.8001854339655),
        (30, 657.2477500000014), (40, -871.8751821247979),
        (50, -2499.1338178752017), (60, -4028.256749999999),
        (70, -5274.8091854339655), (80, -6088.438503309167)])
    def test_get_height_along_fieldline(self, lat, height):
        """Test that `get_height` returns expected height of field line.

        Parameters
        ----------
        lat : float
            Input latitude
        height : float
            Output field-line height for line with apex of 3000 km

        """
        fheight = self.apex_out.get_height(lat, 3000.0)

        # Compare with a tight absolute tolerance instead of exact equality,
        # so floating point round-off cannot cause a spurious failure
        np.testing.assert_allclose(fheight, height, rtol=0, atol=1e-7)
        return
1305
1306
    
1307