Passed
Pull Request — develop (#86)
by Angeline
01:29
created

TestApexMethod.test_quasidipole_apexheight_close()   A

Complexity

Conditions 2

Size

Total Lines 11
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 9
dl 0
loc 11
rs 9.95
c 0
b 0
f 0
cc 2
nop 3
1
# -*- coding: utf-8 -*-
2
"""Test the apexpy.Apex class
3
4
Notes
5
-----
6
Whenever function outputs are tested against hard-coded numbers, the test
7
results (numbers) were obtained by running the code that is tested.  Therefore,
8
these tests below only check that nothing changes when refactoring, etc., and
9
not if the results are actually correct.
10
11
These results are expected to change when IGRF is updated.
12
13
"""
14
15
import datetime as dt
16
import numpy as np
17
import os
18
import pytest
19
import shutil
20
import warnings
21
22
import apexpy
23
24
25
def _move_with_retries(src, dst, max_attempts):
    """Move a file, retrying when the operating system refuses the move.

    Parameters
    ----------
    src : str
        Path of the file to move
    dst : str
        Destination path
    max_attempts : int
        Maximum number of move attempts

    Returns
    -------
    bool
        True if the file was moved, False if every attempt failed

    Notes
    -----
    Repeated attempts are needed on Windows, where a move can fail
    transiently.  Only `OSError` (which includes `shutil.Error`) is retried,
    so programming errors still surface immediately.

    """
    for _ in range(max_attempts):
        try:
            shutil.move(src, dst)
        except OSError:
            continue
        return True
    return False


@pytest.fixture()
def igrf_file(max_attempts=100):
    """A fixture for handling the coefficient file.

    Temporarily moves the IGRF coefficient file out of the apexpy package so
    a test can exercise the missing-file error path, then moves it back.

    Parameters
    ----------
    max_attempts : int
        Maximum rename attempts, needed for Windows (default=100)

    Yields
    ------
    original_file : str
        Path at which apexpy expects the coefficient file

    Raises
    ------
    Failed
        Via `pytest.fail`, if the file cannot be moved away or restored
        within `max_attempts` tries

    """
    # Ensure the coefficient file exists
    original_file = os.path.join(os.path.dirname(apexpy.helpers.__file__),
                                 'igrf13coeffs.txt')
    tmp_file = "temp_coeff.txt"
    assert os.path.isfile(original_file)

    # Move the coefficient file; if this silently failed the test would run
    # with the file still present and fail for a misleading reason
    if not _move_with_retries(original_file, tmp_file, max_attempts):
        pytest.fail("unable to move {:} to {:} after {:} attempts".format(
            original_file, tmp_file, max_attempts))

    yield original_file

    # Move the coefficient file back; failing loudly matters here because a
    # missing coefficient file makes every later `apexpy.Apex()` raise
    if not _move_with_retries(tmp_file, original_file, max_attempts):
        pytest.fail("unable to restore {:} from {:} after {:} "
                    "attempts".format(original_file, tmp_file, max_attempts))
    return
58
59
60
def test_set_epoch_file_error(igrf_file):
    """Check that Apex raises OSError if the IGRF coefficient file is gone."""
    # `igrf_file` is the path the fixture has just moved the coefficients from
    expected_start = "File {:} does not exist".format(igrf_file)

    with pytest.raises(OSError) as raised:
        apexpy.Apex()

    assert str(raised.value).startswith(expected_start)
    return
68
69
70
class TestApexInit(object):
    """Test Apex initialization, epoch/reference-height setters and dunders."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = None
        self.test_date = dt.datetime.utcnow()
        self.test_refh = 0
        self.bad_file = 'foo/path/to/datafile.blah'

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.test_date, self.test_refh, self.bad_file

    def eval_date(self):
        """Evaluate the times in self.test_date and self.apex_out."""
        # `dt.datetime` is a subclass of `dt.date`, so one check covers both
        if isinstance(self.test_date, dt.date):
            self.test_date = apexpy.helpers.toYearFraction(self.test_date)

        # Assert the times are the same on the order of tens of seconds.
        # Necessary to evaluate the current UTC
        np.testing.assert_almost_equal(self.test_date, self.apex_out.year, 6)
        return

    def eval_refh(self):
        """Evaluate the reference height in self.test_refh and self.apex_out."""
        eval_str = "".join(["expected reference height [",
                            "{:}] not equal to Apex ".format(self.test_refh),
                            "reference height ",
                            "[{:}]".format(self.apex_out.refh)])
        assert self.test_refh == self.apex_out.refh, eval_str
        return

    def test_init_defaults(self):
        """Test Apex class default initialization."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("in_date",
                             [2015, 2015.5, dt.date(2015, 1, 1),
                              dt.datetime(2015, 6, 1, 18, 23, 45)])
    def test_init_date(self, in_date):
        """Test Apex class with date initialization."""
        self.test_date = in_date
        self.apex_out = apexpy.Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("new_date", [2015, 2015.5])
    def test_set_epoch(self, new_date):
        """Test successful setting of Apex epoch after initialization."""
        # Evaluate the default initialization
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Update the epoch; the repr is trusted test data, so eval is safe
        ref_apex = eval(repr(self.apex_out))
        self.apex_out.set_epoch(new_date)
        assert ref_apex != self.apex_out
        self.test_date = new_date
        self.eval_date()
        return

    @pytest.mark.parametrize("in_refh", [0.0, 300.0, 30000.0, -1.0])
    def test_init_refh(self, in_refh):
        """Test Apex class with reference height initialization."""
        self.test_refh = in_refh
        self.apex_out = apexpy.Apex(refh=self.test_refh)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("new_refh", [0.0, 300.0, 30000.0, -1.0])
    def test_set_refh(self, new_refh):
        """Test the method used to set the reference height after the init."""
        # Verify the defaults are set
        self.apex_out = apexpy.Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()

        # Update to a new reference height and test
        ref_apex = eval(repr(self.apex_out))
        self.apex_out.set_refh(new_refh)

        if self.test_refh == new_refh:
            assert ref_apex == self.apex_out
        else:
            assert ref_apex != self.apex_out
            self.test_refh = new_refh
        self.eval_refh()
        return

    def test_init_with_bad_datafile(self):
        """Test raises IOError with non-existent datafile input."""
        with pytest.raises(IOError) as oerr:
            apexpy.Apex(datafile=self.bad_file)
        assert str(oerr.value).startswith('Data file does not exist')
        return

    def test_init_with_bad_fortranlib(self):
        """Test raises IOError with non-existent fortranlib input."""
        with pytest.raises(IOError) as oerr:
            apexpy.Apex(fortranlib=self.bad_file)
        assert str(oerr.value).startswith('Fortran library does not exist')
        return

    def test_repr_eval(self):
        """Test the Apex.__repr__ results."""
        # Initialize the apex object
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Get and test the repr string
        out_str = repr(self.apex_out)
        assert out_str.find("apexpy.Apex(") == 0

        # Test the ability to re-create the apex object from the repr string
        new_apex = eval(out_str)
        assert new_apex == self.apex_out
        return

    def test_ne_other_class(self):
        """Test Apex class inequality to a different class."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        assert self.apex_out != self.test_date
        return

    def test_ne_missing_attr(self):
        """Test Apex class inequality when attributes are missing from one."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        ref_apex = eval(repr(self.apex_out))
        del ref_apex.RE

        # Inequality must hold regardless of which operand lacks the attribute
        assert ref_apex != self.apex_out
        assert self.apex_out != ref_apex
        return

    def test_eq_missing_attr(self):
        """Test Apex class equality when attributes are missing from both."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        ref_apex = eval(repr(self.apex_out))
        del ref_apex.RE, self.apex_out.RE

        assert ref_apex == self.apex_out
        return

    def test_str_eval(self):
        """Test the Apex.__str__ results."""
        # Initialize the apex object
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Get and test the printed string
        out_str = str(self.apex_out)
        assert out_str.find("Decimal year") > 0
        return
236
237
238
class TestApexMethod(object):
    """Test the Apex methods."""

    # Apex private method, equivalent Fortran routine, and the slice of the
    # Fortran output that corresponds to the Apex method output
    _fortran_equivalents = [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                            ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                            ("_qd2geo", "apxq2g", slice(None)),
                            ("_basevec", "apxg2q", slice(2, 4, 1))]

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.in_lat = 60
        self.in_lon = 15
        self.in_alt = 100

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.in_lat, self.in_lon, self.in_alt

    def get_input_args(self, method_name, precision=0.0):
        """Set the input arguments for the different Apex methods.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method
        precision : float
            Value for the precision (default=0.0)

        Returns
        -------
        in_args : list
            List of the appropriate input arguments

        """
        in_args = [self.in_lat, self.in_lon, self.in_alt]

        # Add precision, if needed
        if method_name in ["_qd2geo", "apxq2g", "apex2geo", "qd2geo",
                           "_apex2geo"]:
            in_args.append(precision)

        # Add a reference height, if needed; matches refh in `setup_method`
        if method_name in ["apxg2all"]:
            in_args.append(300)

        # Add a vector flag, if needed
        if method_name in ["apxg2all", "apxg2q"]:
            in_args.append(1)

        return in_args

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             _fortran_equivalents)
    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon", [-179, -90, 0, 90, 180])
    def test_fortran_scalar_input(self, apex_method, fortran_method, fslice,
                                  lat, lon):
        """Tests Apex/fortran interface consistency for scalars."""
        # Set the input coordinates
        self.in_lat = lat
        self.in_lon = lon

        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Get the appropriate input arguments
        apex_args = self.get_input_args(apex_method)
        fortran_args = self.get_input_args(fortran_method)

        # Evaluate the equivalent function calls
        np.testing.assert_allclose(apex_func(*apex_args),
                                   fortran_func(*fortran_args)[fslice])
        return

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             _fortran_equivalents)
    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon1,lon2", [(180, 180), (-180, -180),
                                           (180, -180), (-180, 180),
                                           (-345, 15), (375, 15)])
    def test_fortran_longitude_rollover(self, apex_method, fortran_method,
                                        fslice, lat, lon1, lon2):
        """Tests Apex/fortran interface consistency for longitude rollover."""
        # Set the fixed input coordinate
        self.in_lat = lat

        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Get the appropriate input arguments; lon1 and lon2 are equivalent
        # longitudes expressed in different ranges
        self.in_lon = lon1
        apex_args = self.get_input_args(apex_method)

        self.in_lon = lon2
        fortran_args = self.get_input_args(fortran_method)

        # Evaluate the equivalent function calls
        np.testing.assert_allclose(apex_func(*apex_args),
                                   fortran_func(*fortran_args)[fslice])
        return

    @pytest.mark.parametrize("arr_shape", [(2, 2), (4,), (1, 4)])
    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             _fortran_equivalents)
    def test_fortran_array_input(self, arr_shape, apex_method, fortran_method,
                                 fslice):
        """Tests Apex/fortran interface consistency for array input."""
        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Set up the input arrays
        ref_lat = np.array([0, 30, 60, 90])
        ref_alt = np.array([100, 200, 300, 400])
        self.in_lat = ref_lat.reshape(arr_shape)
        self.in_alt = ref_alt.reshape(arr_shape)
        apex_args = self.get_input_args(apex_method)

        # Get the Apex class results
        aret = apex_func(*apex_args)

        # Get the fortran function results, one scalar call per location
        flats = list()
        flons = list()
        for lat, alt in zip(ref_lat, ref_alt):
            self.in_lat = lat
            self.in_alt = alt
            fortran_args = self.get_input_args(fortran_method)
            fret = fortran_func(*fortran_args)[fslice]
            flats.append(fret[0])
            flons.append(fret[1])

        flats = np.array(flats)
        flons = np.array(flons)

        # Evaluate results
        try:
            # This returned value is array of floats
            np.testing.assert_allclose(aret[0].astype(float),
                                       flats.reshape(arr_shape).astype(float))
            np.testing.assert_allclose(aret[1].astype(float),
                                       flons.reshape(arr_shape).astype(float))
        except ValueError:
            # This returned value is array of arrays
            alats = aret[0].reshape((4,))
            alons = aret[1].reshape((4,))
            for alat, alon, flat, flon in zip(alats, alons, flats, flons):
                np.testing.assert_array_almost_equal(alat, flat, 2)
                np.testing.assert_array_almost_equal(alon, flon, 2)

        return

    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon", [-179, -90, 0, 90, 180])
    def test_geo2apexall_scalar(self, lat, lon):
        """Test Apex/fortran geo2apexall interface consistency for scalars."""
        # Get the Apex and Fortran results
        aret = self.apex_out._geo2apexall(lat, lon, self.in_alt)
        fret = apexpy.fortranapex.apxg2all(lat, lon, self.in_alt, 300, 1)

        # Evaluate each element in the results
        for aval, fval in zip(aret, fret):
            np.testing.assert_allclose(aval, fval)
        return

    @pytest.mark.parametrize("arr_shape", [(2, 2), (4,), (1, 4)])
    def test_geo2apexall_array(self, arr_shape):
        """Test Apex/fortran geo2apexall interface consistency for arrays."""
        # Set the input
        self.in_lat = np.array([0, 30, 60, 90])
        self.in_alt = np.array([100, 200, 300, 400])

        # Get the Apex class results
        aret = self.apex_out._geo2apexall(self.in_lat.reshape(arr_shape),
                                          self.in_lon,
                                          self.in_alt.reshape(arr_shape))

        # For each lat/alt pair, get the Fortran results
        fret = [apexpy.fortranapex.apxg2all(lat, self.in_lon, alt, 300, 1)
                for lat, alt in zip(self.in_lat, self.in_alt)]

        # Cycle through all returned values
        for i, ret in enumerate(aret):
            try:
                # This returned value is array of floats
                fret_test = np.array([fret[0][i], fret[1][i], fret[2][i],
                                      fret[3][i]]).reshape(arr_shape)
                np.testing.assert_allclose(ret.astype(float),
                                           fret_test.astype(float))
            except ValueError:
                # This returned value is array of arrays
                ret = ret.reshape((4,))
                for j, single_fret in enumerate(fret):
                    np.testing.assert_allclose(ret[j], single_fret[i])
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_consistency(self, in_coord, out_coord):
        """Test the self-consistency of the Apex convert method."""
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Define the method name
        method_name = "2".join([in_coord, out_coord])

        # Get the method and method inputs
        convert_kwargs = {'height': self.in_alt, 'precision': 0.0}
        apex_args = self.get_input_args(method_name)
        apex_method = getattr(self.apex_out, method_name)

        # Define the slice needed to get equivalent output from the named
        # method; conversions to geo return an extra trailing value
        mslice = slice(0, -1, 1) if out_coord == "geo" else slice(None)

        # Get output using convert and named method
        convert_out = self.apex_out.convert(self.in_lat, self.in_lon, in_coord,
                                            out_coord, **convert_kwargs)
        method_out = apex_method(*apex_args)[mslice]

        # Compare both outputs, should be identical
        np.testing.assert_allclose(convert_out, method_out)
        return

    @pytest.mark.parametrize("bound_lat", [90, -90])
    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_at_lat_boundary(self, bound_lat, in_coord, out_coord):
        """Test the conversion at the latitude boundary, with allowed excess."""
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = self.apex_out.convert(bound_lat, 0, in_coord, out_coord)
        excess_out = self.apex_out.convert(excess_lat, 0, in_coord, out_coord)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    def test_convert_qd2apex_at_equator(self):
        """Test the quasi-dipole to apex conversion at the magnetic equator."""
        eq_out = self.apex_out.convert(lat=0.0, lon=0, source='qd', dest='apex',
                                       height=320.0)
        close_out = self.apex_out.convert(lat=0.001, lon=0, source='qd',
                                          dest='apex', height=320.0)
        np.testing.assert_allclose(eq_out, close_out, atol=1e-4)
        return

    @pytest.mark.parametrize("src", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("dest", ["geo", "apex", "qd"])
    def test_convert_withnan(self, src, dest):
        """Test Apex.convert success with NaN input."""
        if src == dest:
            pytest.skip("Test not needed for same src and dest coordinates")

        num_nans = 5
        in_loc = np.arange(0, 10, dtype=float)
        in_loc[:num_nans] = np.nan

        out_loc = self.apex_out.convert(in_loc, in_loc, src, dest, height=320)

        for out in out_loc:
            assert np.all(np.isnan(out[:num_nans])), "NaN output expected"
            assert np.all(np.isfinite(out[num_nans:])), "Finite output expected"

        return

    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_convert_invalid_lat(self, bad_lat):
        """Test convert raises ValueError for invalid latitudes."""

        with pytest.raises(ValueError) as verr:
            self.apex_out.convert(bad_lat, 0, 'geo', 'geo')

        assert "must be in [-90, 90]" in str(verr.value)
        return

    @pytest.mark.parametrize("coords", [("foobar", "geo"), ("geo", "foobar"),
                                        ("geo", "mlt")])
    def test_convert_invalid_transformation(self, coords):
        """Test raises ValueError for bad coordinates."""
        if "mlt" in coords:
            estr = "datetime must be given for MLT calculations"
        else:
            estr = "Unknown coordinate transformation"

        with pytest.raises(ValueError) as verr:
            self.apex_out.convert(0, 0, *coords)

        # Inspect the raised exception itself; `str(verr)` describes the
        # ExceptionInfo wrapper rather than giving the error message
        assert estr in str(verr.value)
        return

    @pytest.mark.parametrize("method_name, out_comp",
                             [("geo2apex",
                               (55.94841766357422, 94.10684204101562)),
                              ("apex2geo",
                               (51.476322174072266, -66.22817993164062,
                                5.727287771151168e-06)),
                              ("geo2qd",
                               (56.531288146972656, 94.10684204101562)),
                              ("apex2qd", (60.498401178276744, 15.0)),
                              ("qd2apex", (59.49138097045895, 15.0))])
    def test_method_scalar_input(self, method_name, out_comp):
        """Test the user method against set values with scalars."""
        # Get the desired methods
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(self.in_lat, self.in_lon, self.in_alt)

        # Evaluate the user output
        np.testing.assert_allclose(user_out, out_comp)

        for out_val in user_out:
            assert np.asarray(out_val).shape == (), "output is not a scalar"
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("method_args, out_shape",
                             [([[60, 60], 15, 100], (2,)),
                              ([60, [15, 15], 100], (2,)),
                              ([60, 15, [100, 100]], (2,)),
                              ([[50, 60], [15, 16], [100, 200]], (2,))])
    def test_method_broadcast_input(self, in_coord, out_coord, method_args,
                                    out_shape):
        """Test the user method with inputs that require some broadcasting."""
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(*method_args)

        # Evaluate the user output
        for out_val in user_out:
            assert hasattr(out_val, 'shape'), "output coordinate isn't np.array"
            assert out_val.shape == out_shape
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_method_invalid_lat(self, in_coord, out_coord, bad_lat):
        """Test user methods raise ValueError for invalid latitudes."""
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        with pytest.raises(ValueError) as verr:
            user_method(bad_lat, 15, 100)

        assert "must be in [-90, 90]" in str(verr.value)
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("bound_lat", [90, -90])
    def test_method_at_lat_boundary(self, in_coord, out_coord, bound_lat):
        """Test user methods at the latitude boundary, with allowed excess."""
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        # Get a latitude just beyond the limit
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = user_method(bound_lat, 0, 100)
        excess_out = user_method(excess_lat, 0, 100)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    def test_geo2apex_undefined_warning(self):
        """Test geo2apex warning and fill values for an undefined location."""

        # Update the apex object
        self.apex_out = apexpy.Apex(date=2000, refh=10000)

        # Get the output and the warnings
        with warnings.catch_warnings(record=True) as warn_rec:
            # Record UserWarnings even if the session turns them into errors
            # (e.g. `-W error`) or has already shown this one
            warnings.simplefilter("always", UserWarning)
            user_lat, user_lon = self.apex_out.geo2apex(0, 0, 0)

        assert np.isnan(user_lat)
        assert np.isfinite(user_lon)
        assert len(warn_rec) == 1
        assert issubclass(warn_rec[-1].category, UserWarning)
        assert 'latitude set to NaN where' in str(warn_rec[-1].message)
        return

    @pytest.mark.parametrize("method_name", ["apex2qd", "qd2apex"])
    @pytest.mark.parametrize("delta_h", [1.0e-6, -1.0e-6])
    def test_quasidipole_apexheight_close(self, method_name, delta_h):
        """Test quasi-dipole success with a height close to the reference."""
        qd_method = getattr(self.apex_out, method_name)
        in_args = [0, 15, self.apex_out.refh + delta_h]
        out_coords = qd_method(*in_args)

        # At the magnetic equator and reference height, the output should
        # reproduce the input latitude and longitude
        for out_val, in_val in zip(out_coords, in_args):
            np.testing.assert_almost_equal(out_val, in_val, decimal=3)
        return

    @pytest.mark.parametrize("method_name, hinc, msg",
                             [("apex2qd", 1.0, "is > apex height"),
                              ("qd2apex", -1.0, "is < reference height")])
    def test_quasidipole_raises_apexheight(self, method_name, hinc, msg):
        """Quasi-dipole raises ApexHeightError for an invalid height."""
        qd_method = getattr(self.apex_out, method_name)

        with pytest.raises(apexpy.ApexHeightError) as aerr:
            qd_method(0, 15, self.apex_out.refh + hinc)

        # Inspect the raised exception, not the ExceptionInfo wrapper
        assert msg in str(aerr.value)
        return
670
671
672
class TestApexMLTMethods(object):
    """Test the Apex Magnetic Local Time (MLT) methods."""

    # pytest's xunit-style per-test hooks are `setup_method` and
    # `teardown_method`; the nose-style `setup`/`teardown` names are
    # deprecated since pytest 7.2 and are no longer called by pytest 8.
    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.in_time = dt.datetime(2000, 2, 3, 4, 5, 6)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.in_time

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    def test_convert_to_mlt(self, in_coord):
        """Test the conversions to MLT using Apex convert.

        Parameters
        ----------
        in_coord : str
            Input coordinate system passed to `convert`

        """
        # Get the magnetic longitude; only geographic input needs converting,
        # otherwise the input longitude of 15 is used as-is
        if in_coord == "geo":
            mlon = self.apex_out.geo2apex(60, 15, 100)[1]
        else:
            mlon = 15

        # Get the output MLT values
        convert_mlt = self.apex_out.convert(60, 15, in_coord, 'mlt',
                                            height=100, ssheight=2e5,
                                            datetime=self.in_time)[1]
        method_mlt = self.apex_out.mlon2mlt(mlon, self.in_time, ssheight=2e5)

        # Test the outputs
        np.testing.assert_allclose(convert_mlt, method_mlt)
        return

    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_mlt_to_lon(self, out_coord):
        """Test the conversions from MLT using Apex convert.

        Parameters
        ----------
        out_coord : str
            Output coordinate system passed to `convert`

        """
        # Get the output longitudes
        convert_out = self.apex_out.convert(60, 15, 'mlt', out_coord,
                                            height=100, ssheight=2e5,
                                            datetime=self.in_time,
                                            precision=1e-2)
        mlon = self.apex_out.mlt2mlon(15, self.in_time, ssheight=2e5)

        # Build the expected result with the dedicated conversion methods
        if out_coord == "geo":
            # Drop the trailing element returned by apex2geo
            method_out = self.apex_out.apex2geo(60, mlon, 100,
                                                precision=1e-2)[:-1]
        elif out_coord == "qd":
            method_out = self.apex_out.apex2qd(60, mlon, 100)
        else:
            method_out = (60, mlon)

        # Evaluate the outputs
        np.testing.assert_allclose(convert_out, method_out)
        return

    def test_convert_geo2mlt_nodate(self):
        """Test convert from geo to MLT raises ValueError with no datetime."""
        with pytest.raises(ValueError):
            self.apex_out.convert(60, 15, 'geo', 'mlt')
        return

    @pytest.mark.parametrize("mlon_kwargs,test_mlt",
                             [({}, 23.019629923502603),
                              ({"ssheight": 100000}, 23.026712036132814)])
    def test_mlon2mlt_scalar_inputs(self, mlon_kwargs, test_mlt):
        """Test mlon2mlt with scalar inputs returns a scalar."""
        mlt = self.apex_out.mlon2mlt(0, self.in_time, **mlon_kwargs)

        np.testing.assert_allclose(mlt, test_mlt)
        assert np.asarray(mlt).shape == ()
        return

    @pytest.mark.parametrize("mlt_kwargs,test_mlon",
                             [({}, 14.705535888671875),
                              ({"ssheight": 100000}, 14.599319458007812)])
    def test_mlt2mlon_scalar_inputs(self, mlt_kwargs, test_mlon):
        """Test mlt2mlon with scalar inputs returns a scalar."""
        mlon = self.apex_out.mlt2mlon(0, self.in_time, **mlt_kwargs)

        np.testing.assert_allclose(mlon, test_mlon)
        assert np.asarray(mlon).shape == ()
        return

    @pytest.mark.parametrize("mlon,test_mlt",
                             [([0, 180], [23.019261, 11.019261]),
                              (np.array([0, 180]), [23.019261, 11.019261]),
                              (np.array([[0], [180]]),
                               np.array([[23.019261], [11.019261]])),
                              ([[0, 180], [0, 180]], [[23.019261, 11.019261],
                                                      [23.019261, 11.019261]]),
                              (range(0, 361, 30),
                               [23.01963, 1.01963, 3.01963, 5.01963, 7.01963,
                                9.01963, 11.01963, 13.01963, 15.01963, 17.01963,
                                19.01963, 21.01963, 23.01963])])
    def test_mlon2mlt_array(self, mlon, test_mlt):
        """Test mlon2mlt with array inputs preserves shape and values."""
        mlt = self.apex_out.mlon2mlt(mlon, self.in_time)

        assert mlt.shape == np.asarray(test_mlt).shape
        np.testing.assert_allclose(mlt, test_mlt, rtol=1e-4)
        return

    @pytest.mark.parametrize("mlt,test_mlon",
                             [([0, 12], [14.705551, 194.705551]),
                              (np.array([0, 12]), [14.705551, 194.705551]),
                              (np.array([[0], [12]]),
                               np.array([[14.705551], [194.705551]])),
                              ([[0, 12], [0, 12]], [[14.705551, 194.705551],
                                                    [14.705551, 194.705551]]),
                              (range(0, 25, 2),
                               [14.705551, 44.705551, 74.705551, 104.705551,
                                134.705551, 164.705551, 194.705551, 224.705551,
                                254.705551, 284.705551, 314.705551, 344.705551,
                                14.705551])])
    def test_mlt2mlon_array(self, mlt, test_mlon):
        """Test mlt2mlon with array inputs preserves shape and values."""
        mlon = self.apex_out.mlt2mlon(mlt, self.in_time)

        assert mlon.shape == np.asarray(test_mlon).shape
        np.testing.assert_allclose(mlon, test_mlon, rtol=1e-4)
        return

    @pytest.mark.parametrize("method_name", ["mlon2mlt", "mlt2mlon"])
    def test_mlon2mlt_diffdates(self, method_name):
        """Test that the MLT/magnetic longitude varies with universal time."""
        apex_method = getattr(self.apex_out, method_name)
        out1 = apex_method(0, self.in_time)
        out2 = apex_method(0, self.in_time + dt.timedelta(hours=1))

        assert out1 != out2
        return

    @pytest.mark.parametrize("mlt_offset", [1.0, 10.0])
    def test_mlon2mlt_offset(self, mlt_offset):
        """Test the time wrapping logic for the MLT."""
        mlt1 = self.apex_out.mlon2mlt(0.0, self.in_time)
        # Shifting longitude by -15 deg per hour should shift MLT by one hour
        mlt2 = self.apex_out.mlon2mlt(-15.0 * mlt_offset,
                                      self.in_time) + mlt_offset

        np.testing.assert_allclose(mlt1, mlt2)
        return

    @pytest.mark.parametrize("mlon_offset", [15.0, 150.0])
    def test_mlt2mlon_offset(self, mlon_offset):
        """Test the time wrapping logic for the magnetic longitude."""
        mlon1 = self.apex_out.mlt2mlon(0, self.in_time)
        # Each hour of MLT corresponds to 15 deg of magnetic longitude
        mlon2 = self.apex_out.mlt2mlon(mlon_offset / 15.0,
                                       self.in_time) - mlon_offset

        np.testing.assert_allclose(mlon1, mlon2)
        return

    @pytest.mark.parametrize("order", [["mlt", "mlon"], ["mlon", "mlt"]])
    @pytest.mark.parametrize("start_val", [0, 6, 12, 18, 22])
    def test_convert_and_return(self, order, start_val):
        """Test the conversion to magnetic longitude or MLT and back again."""
        # e.g. order ["mlt", "mlon"] -> mlt2mlon, then mlon2mlt
        first_method = getattr(self.apex_out, "2".join(order))
        second_method = getattr(self.apex_out, "2".join(order[::-1]))

        middle_val = first_method(start_val, self.in_time)
        end_val = second_method(middle_val, self.in_time)

        np.testing.assert_allclose(start_val, end_val)
        return
835
836
837
class TestApexMapMethods(object):
    """Test the Apex height mapping methods."""

    # pytest's xunit-style per-test hooks are `setup_method` and
    # `teardown_method`; the nose-style `setup`/`teardown` names are
    # deprecated since pytest 7.2 and are no longer called by pytest 8.
    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 10000],
                               [31.841466903686523, 17.916635513305664,
                                1.7075473124350538e-6]),
                              ([30, 170, 100, 500, False, 1e-2],
                               [25.727270126342773, 169.60546875,
                                0.00017573432705830783]),
                              ([60, 15, 100, 10000, True],
                               [-25.424888610839844, 27.310426712036133,
                                1.2074182222931995e-6]),
                              ([30, 170, 100, 500, True, 1e-2],
                               [-13.76642894744873, 164.24259948730469,
                                0.00056820799363777041])])
    def test_map_to_height(self, in_args, test_mapped):
        """Test the map_to_height function."""
        mapped = self.apex_out.map_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, atol=1e-6)
        return

    def test_map_to_height_same_height(self):
        """Test the map_to_height function when mapping to same height."""
        mapped = self.apex_out.map_to_height(60, 15, 100, 100, conjugate=False,
                                             precision=1e-10)
        np.testing.assert_allclose(mapped, (60.0, 15.000003814697266, 0.0),
                                   rtol=1e-5)
        return

    @pytest.mark.parametrize('arr_shape', [(2,), (2, 2), (1, 4)])
    @pytest.mark.parametrize('ivec', range(0, 4))
    def test_map_to_height_array_location(self, arr_shape, ivec):
        """Test map_to_height with array input."""
        # Set the base input and output values
        in_args = [60, 15, 100, 100]
        test_mapped = [60, 15.00000381, 0.0]

        # Update inputs for one vectorized value
        in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Calculate and test function
        mapped = self.apex_out.map_to_height(*in_args)
        for i, test_val in enumerate(test_mapped):
            assert mapped[i].shape == arr_shape
            np.testing.assert_allclose(mapped[i], test_val, rtol=1e-5)
        return

    @pytest.mark.parametrize("method_name,in_args",
                             [("map_to_height", [0, 15, 100, 10000]),
                              ("map_E_to_height",
                               [0, 15, 100, 10000, [1, 2, 3]]),
                              ("map_V_to_height",
                               [0, 15, 100, 10000, [1, 2, 3]])])
    def test_mapping_height_raises_ApexHeightError(self, method_name, in_args):
        """Test height mapping methods raise ApexHeightError."""
        apex_method = getattr(self.apex_out, method_name)

        with pytest.raises(apexpy.ApexHeightError) as aerr:
            apex_method(*in_args)

        assert aerr.match("is > apex height")
        return

    @pytest.mark.parametrize("method_name",
                             ["map_E_to_height", "map_V_to_height"])
    @pytest.mark.parametrize("ev_input", [([1, 2, 3, 4, 5]),
                                          ([[1, 2], [3, 4], [5, 6], [7, 8]])])
    def test_mapping_EV_bad_shape(self, method_name, ev_input):
        """Test height mapping of E/V with badly shaped input raises Error."""
        apex_method = getattr(self.apex_out, method_name)
        in_args = [60, 15, 100, 500, ev_input]
        with pytest.raises(ValueError) as verr:
            apex_method(*in_args)

        assert "must be (3, N) or (3,) ndarray" in str(verr.value)
        return

    def test_mapping_EV_bad_flag(self):
        """Test _map_EV_to_height raises error for bad data type flag."""
        with pytest.raises(ValueError) as verr:
            self.apex_out._map_EV_to_height(60, 15, 100, 500, [1, 2, 3], "P")

        assert "unknown electric field/drift flag" in str(verr.value)
        return

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 500, [1, 2, 3]],
                               [0.71152183, 2.35624876, 0.57260784]),
                              ([60, 15, 100, 500, [2, 3, 4]],
                               [1.56028502, 3.43916636, 0.78235384]),
                              ([60, 15, 100, 1000, [1, 2, 3]],
                               [0.67796492, 2.08982134, 0.55860785]),
                              ([60, 15, 200, 500, [1, 2, 3]],
                               [0.72377397, 2.42737471, 0.59083726]),
                              ([60, 30, 100, 500, [1, 2, 3]],
                               [0.68626344, 2.37530133, 0.60060124]),
                              ([70, 15, 100, 500, [1, 2, 3]],
                               [0.72760378, 2.18082305, 0.29141979])])
    def test_map_E_to_height_scalar_location(self, in_args, test_mapped):
        """Test mapping of E-field to a specified height."""
        mapped = self.apex_out.map_E_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return

    @pytest.mark.parametrize('ev_flag, test_mapped',
                             [('E', [0.71152183, 2.35624876, 0.57260784]),
                              ('V', [0.81971957, 2.84512495, 0.69545001])])
    @pytest.mark.parametrize('arr_shape', [(2,), (5,)])
    @pytest.mark.parametrize('ivec', range(0, 5))
    def test_map_EV_to_height_array_location(self, ev_flag, test_mapped,
                                             arr_shape, ivec):
        """Test mapping of E-field/drift to a specified height with arrays."""
        # Build a (3, *arr_shape) E/V array filled with the vector [1, 2, 3].
        # `np.prod` replaces `np.product`, which was removed in NumPy 2.0.
        eshape = (3,) + tuple(arr_shape)
        edata = np.array([[1, 2, 3]] * int(np.prod(arr_shape))).transpose()
        in_args = [60, 15, 100, 500, edata.reshape(eshape)]

        # Update inputs for one vectorized value if this is a location input
        if ivec < 4:
            in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Get the mapped output
        apex_method = getattr(self.apex_out,
                              "map_{:s}_to_height".format(ev_flag))
        mapped = apex_method(*in_args)

        # Test the results
        for i, test_val in enumerate(test_mapped):
            assert mapped[i].shape == arr_shape
            np.testing.assert_allclose(mapped[i], test_val, rtol=1e-5)
        return

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 500, [1, 2, 3]],
                               [0.81971957, 2.84512495, 0.69545001]),
                              ([60, 15, 100, 500, [2, 3, 4]],
                               [1.83027746, 4.14346436, 0.94764179]),
                              ([60, 15, 100, 1000, [1, 2, 3]],
                               [0.92457698, 3.14997661, 0.85135187]),
                              ([60, 15, 200, 500, [1, 2, 3]],
                               [0.80388262, 2.79321504, 0.68285158]),
                              ([60, 30, 100, 500, [1, 2, 3]],
                               [0.76141245, 2.87884673, 0.73655941]),
                              ([70, 15, 100, 500, [1, 2, 3]],
                               [0.84681866, 2.5925821, 0.34792655])])
    def test_map_V_to_height_scalar_location(self, in_args, test_mapped):
        """Test mapping of velocity to a specified height."""
        mapped = self.apex_out.map_V_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return
996
997
998
class TestApexBasevectorMethods(object):
    """Test the Apex height base vector methods."""

    # pytest's xunit-style per-test hooks are `setup_method` and
    # `teardown_method`; the nose-style `setup`/`teardown` names are
    # deprecated since pytest 7.2 and are no longer called by pytest 8.
    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.lat = 60
        self.lon = 15
        self.height = 100
        self.test_basevec = None

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.test_basevec, self.lat, self.lon, self.height

    def get_comparison_results(self, bv_coord, coords, precision):
        """Get the base vector results using the hidden function for comparison.

        The expected values are stored in `self.test_basevec` for the test
        location (`self.lat`, `self.lon`, `self.height`).

        Parameters
        ----------
        bv_coord : str
            Basevector coordinate scheme, expects one of 'apex', 'qd',
            or 'bvectors_apex'
        coords : str
            Expects one of 'geo', 'apex', or 'qd'
        precision : float
            Float specifying precision

        """
        # Express the test location in geographic coordinates
        if coords == "geo":
            glat = self.lat
            glon = self.lon
        else:
            apex_method = getattr(self.apex_out, "{:s}2geo".format(coords))
            glat, glon, _ = apex_method(self.lat, self.lon, self.height,
                                        precision=precision)

        if bv_coord == 'qd':
            self.test_basevec = self.apex_out._basevec(glat, glon, self.height)
        elif bv_coord == 'apex':
            # Use the test height rather than a duplicated literal
            (_, _, _, _, f1, f2, _, d1, d2, d3, _, e1, e2,
             e3) = self.apex_out._geo2apexall(glat, glon, self.height)
            self.test_basevec = (f1, f2, d1, d2, d3, e1, e2, e3)
        else:
            # These are set results that need to be updated with IGRF
            if coords == "geo":
                self.test_basevec = (
                    np.array([4.42368795e-05, 4.42368795e-05]),
                    np.array([[0.01047826, 0.01047826],
                              [0.33089194, 0.33089194],
                              [-1.04941, -1.04941]]),
                    np.array([5.3564698e-05, 5.3564698e-05]),
                    np.array([[0.00865356, 0.00865356],
                              [0.27327004, 0.27327004],
                              [-0.8666646, -0.8666646]]))
            elif coords == "apex":
                self.test_basevec = (
                    np.array([4.48672735e-05, 4.48672735e-05]),
                    np.array([[-0.12510721, -0.12510721],
                              [0.28945938, 0.28945938],
                              [-1.1505738, -1.1505738]]),
                    np.array([6.38577444e-05, 6.38577444e-05]),
                    np.array([[-0.08790194, -0.08790194],
                              [0.2033779, 0.2033779],
                              [-0.808408, -0.808408]]))
            else:
                self.test_basevec = (
                    np.array([4.46348578e-05, 4.46348578e-05]),
                    np.array([[-0.12642345, -0.12642345],
                              [0.29695055, 0.29695055],
                              [-1.1517885, -1.1517885]]),
                    np.array([6.38626285e-05, 6.38626285e-05]),
                    np.array([[-0.08835986, -0.08835986],
                              [0.20754464, 0.20754464],
                              [-0.8050078, -0.8050078]]))

        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    @pytest.mark.parametrize("coords,precision",
                             [("geo", 1e-10), ("apex", 1.0e-2), ("qd", 1.0e-2)])
    def test_basevectors_scalar(self, bv_coord, coords, precision):
        """Test the base vector calculations with scalars."""
        # Get the base vectors
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(self.lat, self.lon, self.height, coords=coords,
                              precision=precision)
        self.get_comparison_results(bv_coord, coords, precision)
        if bv_coord == "apex":
            # Indices 2-5 (f3, g1, g2, g3) have no comparison values
            basevec = list(basevec)
            del basevec[2:6]

        # Test the results
        for i, vec in enumerate(basevec):
            np.testing.assert_allclose(vec, self.test_basevec[i])
        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    def test_basevectors_scalar_shape(self, bv_coord):
        """Test the shape of the scalar output."""
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(self.lat, self.lon, self.height)

        # The first two vectors have 2 components, the rest have 3
        for i, vec in enumerate(basevec):
            if i < 2:
                assert vec.shape == (2,)
            else:
                assert vec.shape == (3,)
        return

    @pytest.mark.parametrize('arr_shape', [(2,), (5,)])
    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    @pytest.mark.parametrize("ivec", range(3))
    def test_basevectors_array(self, arr_shape, bv_coord, ivec):
        """Test the output shape for array inputs."""
        # Define the input arguments, vectorizing one of them
        in_args = [self.lat, self.lon, self.height]
        in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Get the basevectors
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(*in_args, coords='geo', precision=1e-10)
        self.get_comparison_results(bv_coord, "geo", 1e-10)
        if bv_coord == "apex":
            # Indices 2-5 (f3, g1, g2, g3) have no comparison values
            basevec = list(basevec)
            del basevec[2:6]

        # Evaluate the shape and the values
        for i, vec in enumerate(basevec):
            test_shape = (2 if i < 2 else 3,) + tuple(arr_shape)
            assert vec.shape == test_shape
            assert np.all(self.test_basevec[i][0] == vec[0])
            assert np.all(self.test_basevec[i][1] == vec[1])
        return

    @pytest.mark.parametrize("coords", ["geo", "apex", "qd"])
    def test_bvectors_apex(self, coords):
        """Test the bvectors_apex method."""
        in_args = [[self.lat, self.lat], [self.lon, self.lon],
                   [self.height, self.height]]
        self.get_comparison_results("bvectors_apex", coords, 1e-10)

        basevec = self.apex_out.bvectors_apex(*in_args, coords=coords,
                                              precision=1e-10)
        for i, vec in enumerate(basevec):
            np.testing.assert_array_almost_equal(vec, self.test_basevec[i],
                                                 decimal=5)
        return

    def test_basevectors_apex_extra_values(self):
        """Test specific values in the apex base vector output."""
        # Expected values for output indices 2-5 (f3, g1, g2, g3)
        self.test_basevec = [np.array([0.092637, -0.245951, 0.938848]),
                             np.array([0.939012, 0.073416, -0.07342]),
                             np.array([0.055389, 1.004155, 0.257594]),
                             np.array([0, 0, 1.065135])]

        # Get the desired output
        basevec = self.apex_out.basevectors_apex(0, 15, 100, coords='geo')

        # Test the values not covered by `test_basevectors_scalar`
        for itest, ibase in enumerate(range(2, 6)):
            np.testing.assert_allclose(basevec[ibase],
                                       self.test_basevec[itest], rtol=1e-4)
        return

    @pytest.mark.parametrize("lat", range(0, 90, 10))
    @pytest.mark.parametrize("lon", range(0, 360, 15))
    def test_basevectors_apex_delta(self, lat, lon):
        """Test that f/g and d/e vectors satisfy the Kronecker delta."""
        # Get the apex base vectors and sort them for easy testing
        (f1, f2, f3, g1, g2, g3, d1, d2, d3, e1, e2,
         e3) = self.apex_out.basevectors_apex(lat, lon, 500)
        # f1 and f2 have 2 components; pad them to 3 for the dot products
        fvec = [np.append(f1, 0), np.append(f2, 0), f3]
        gvec = [g1, g2, g3]
        dvec = [d1, d2, d3]
        evec = [e1, e2, e3]

        for idelta in range(3):
            for jdelta in range(3):
                delta = 1 if idelta == jdelta else 0
                np.testing.assert_allclose(
                    np.sum(fvec[idelta] * gvec[jdelta]), delta, rtol=0,
                    atol=1e-5)
                np.testing.assert_allclose(
                    np.sum(dvec[idelta] * evec[jdelta]), delta, rtol=0,
                    atol=1e-5)
        return

    def test_basevectors_apex_invalid_scalar(self):
        """Test warning and fill values for base vectors with bad inputs."""
        self.apex_out = apexpy.Apex(date=2000, refh=10000)
        invalid = np.full(shape=(3,), fill_value=np.nan)

        # Get the output and the warnings
        with warnings.catch_warnings(record=True) as warn_rec:
            basevec = self.apex_out.basevectors_apex(0, 0, 0)

        # The first two vectors are still computed, the rest are NaN-filled
        for i, bvec in enumerate(basevec):
            if i < 2:
                assert not np.allclose(bvec, invalid[:2])
            else:
                np.testing.assert_allclose(bvec, invalid)

        # Fail with a clear assertion (not an IndexError) if nothing warned
        assert len(warn_rec) > 0, "expected a warning about NaN fill values"
        assert issubclass(warn_rec[-1].category, UserWarning)
        assert 'set to NaN where' in str(warn_rec[-1].message)
        return
1209
1210
1211
class TestApexGetMethods(object):
1212
    """Test the Apex `get` methods."""
1213
    def setup_method(self):
        """Initialize all tests.

        Uses pytest's xunit-style `setup_method` hook; the nose-style
        `setup` name is deprecated since pytest 7.2 and not called by pytest 8.

        """
        self.apex_out = apexpy.Apex(date=2000, refh=300)
1216
1217
    def teardown_method(self):
        """Clean up after each test.

        Uses pytest's xunit-style `teardown_method` hook; the nose-style
        `teardown` name is deprecated since pytest 7.2 and not called by
        pytest 8.

        """
        del self.apex_out
1220
1221
    @pytest.mark.parametrize("alat, aheight",
                             [(10, 507.409702543805),
                              (60, 20313.026999999987),
                              ([10, 60],
                               [507.409702543805, 20313.026999999987]),
                              ([[10], [60]],
                               [[507.409702543805], [20313.026999999987]])])
    def test_get_apex(self, alat, aheight):
        """Check `get_apex` returns the expected apex heights.

        Parameters
        ----------
        alat : int or array-like
            Apex latitude(s) to look up
        aheight : float or array-like
            Expected apex height(s), shaped like `alat`

        """
        np.testing.assert_allclose(self.apex_out.get_apex(alat), aheight)
        return
1233
1234
    @pytest.mark.parametrize("glat,glon,height,test_bmag",
                             [([80], [100], [300], 5.100682377815247e-05),
                              ([80, 80], [100], [300],
                               [5.100682377815247e-05, 5.100682377815247e-05]),
                              ([[80], [80]], [100], [300],
                               [[5.100682377815247e-05],
                                [5.100682377815247e-05]]),
                              (range(50, 90, 8), range(0, 360, 80), [300] * 5,
                               np.array([4.18657154e-05, 5.11118114e-05,
                                         4.91969854e-05, 5.10519207e-05,
                                         4.90054816e-05])),
                              (90.0, 0, 1000, 3.7834718823432923e-05)])
    def test_get_babs(self, glat, glon, height, test_bmag):
        """Check the magnetic field magnitude returned by `get_babs`.

        Parameters
        ----------
        glat, glon, height : scalar or array-like
            Geographic location passed to `get_babs`
        test_bmag : float or array-like
            Expected field magnitude(s)

        """
        np.testing.assert_allclose(self.apex_out.get_babs(glat, glon, height),
                                   test_bmag, rtol=0, atol=1e-5)
        return
1251
1252
    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_get_apex_with_invalid_lat(self, bad_lat):
        """Test `get_apex` raises ValueError for invalid latitudes.

        Parameters
        ----------
        bad_lat : int
            Latitude outside the allowed [-90, 90] range

        """
        with pytest.raises(ValueError) as verr:
            self.apex_out.get_apex(bad_lat)

        # `in` also accepts a message that starts with the expected text,
        # which `str.find(...) > 0` would wrongly reject
        assert "must be in [-90, 90]" in str(verr.value)
        return
1261
1262
    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_get_babs_with_invalid_lat(self, bad_lat):
        """Test `get_babs` raises ValueError for invalid latitudes.

        Parameters
        ----------
        bad_lat : int
            Latitude outside the allowed [-90, 90] range

        """
        with pytest.raises(ValueError) as verr:
            self.apex_out.get_babs(bad_lat, 15, 100)

        # `in` also accepts a message that starts with the expected text,
        # which `str.find(...) > 0` would wrongly reject
        assert "must be in [-90, 90]" in str(verr.value)
        return
1271
1272
    @pytest.mark.parametrize("bound_lat", [(90), (-90)])
    def test_get_at_lat_boundary(self, bound_lat):
        """Check `get_apex` tolerates a latitude just past the pole.

        Parameters
        ----------
        bound_lat : int
            Latitude boundary (90 or -90)

        """
        # Push the latitude 1e-5 beyond the boundary, keeping its sign
        pole_sign = np.sign(bound_lat)
        excess_lat = pole_sign * (abs(bound_lat) + 1.0e-5)

        # The slight excess is allowed and should match the boundary value
        at_bound = self.apex_out.get_apex(bound_lat)
        past_bound = self.apex_out.get_apex(excess_lat)

        np.testing.assert_allclose(past_bound, at_bound, rtol=0, atol=1e-8)
        return
1285
1286
    @pytest.mark.parametrize("apex_height", [-100, 0, 300, 10000])
    def test_get_height_at_equator(self, apex_height):
        """Check `get_height` at zero latitude equals the apex height.

        Parameters
        ----------
        apex_height : float
            Apex height passed to `get_height`

        """
        equator_height = self.apex_out.get_height(0.0, apex_height)
        assert apex_height == equator_height
        return
1299
1300
    @pytest.mark.parametrize("lat, height", [
        (-90, -6371.009), (-80, -6088.438503309167), (-70, -5274.8091854339655),
        (-60, -4028.256749999999), (-50, -2499.1338178752017),
        (-40, -871.8751821247979), (-30, 657.2477500000014),
        (-20, 1903.8001854339655), (-10, 2717.4295033091657), (0, 3000.0),
        (10, 2717.4295033091657), (20, 1903.8001854339655),
        (30, 657.2477500000014), (40, -871.8751821247979),
        (50, -2499.1338178752017), (60, -4028.256749999999),
        (70, -5274.8091854339655), (80, -6088.438503309167)])
    def test_get_height_along_fieldline(self, lat, height):
        """Check `get_height` along the field line whose apex is 3000.

        Parameters
        ----------
        lat : float
            Input latitude
        height : float
            Expected field-line height at `lat` for an apex height of 3000

        """
        calc_height = self.apex_out.get_height(lat, 3000.0)
        err_msg = "bad height calculation: {:.7f} != {:.7f}".format(
            height, calc_height)
        assert abs(height - calc_height) < 1.0e-7, err_msg
        return
1325