Passed
Pull Request — main (#103)
by Angeline
01:39
created

test_Apex.igrf_file()   B

Complexity

Conditions 5

Size

Total Lines 33
Code Lines 20

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 5
eloc 20
nop 1
dl 0
loc 33
rs 8.9332
c 0
b 0
f 0
1
# -*- coding: utf-8 -*-
2
"""Test the apexpy.Apex class

Notes
-----
Whenever function outputs are tested against hard-coded numbers, the test
results (numbers) were obtained by running the code that is tested.  Therefore,
these tests below only check that nothing changes when refactoring, etc., and
not if the results are actually correct.

These results are expected to change when IGRF is updated.

"""
14
15
import datetime as dt
16
import numpy as np
17
import os
18
import pytest
19
import shutil
20
import warnings
21
22
import apexpy
23
24
25
@pytest.fixture()
def igrf_file(max_attempts=100):
    """A fixture for handling the coefficient file.

    The IGRF coefficient file is moved to a temporary name before the test
    runs and moved back once the test has finished.

    Parameters
    ----------
    max_attempts : int
        Maximum rename attempts, needed for Windows (default=100)

    Yields
    ------
    original_file : str
        Path at which the coefficient file is normally found

    Warns
    -----
    UserWarning
        If the coefficient file could not be moved away, or could not be
        restored, within `max_attempts` attempts

    """
    # Ensure the coefficient file exists
    original_file = os.path.join(os.path.dirname(apexpy.helpers.__file__),
                                 'igrf13coeffs.txt')
    tmp_file = "temp_coeff.txt"
    assert os.path.isfile(original_file)

    # Move the coefficient file, retrying because a single rename attempt
    # is not guaranteed to succeed (see `max_attempts`)
    for _ in range(max_attempts):
        try:
            shutil.move(original_file, tmp_file)
            break
        except OSError:
            continue
    else:
        # Report the failure instead of hiding it; the requesting test will
        # otherwise fail for a reason that is hard to trace back to here
        warnings.warn("unable to move {:} to {:} in {:d} attempts".format(
            original_file, tmp_file, max_attempts))

    yield original_file

    # Move the coefficient file back, but only if it was actually moved
    if os.path.isfile(tmp_file):
        for _ in range(max_attempts):
            try:
                shutil.move(tmp_file, original_file)
                break
            except OSError:
                continue
        else:
            warnings.warn("unable to restore {:} from {:} in {:d} "
                          "attempts".format(original_file, tmp_file,
                                            max_attempts))
    return
58
59
60
def test_set_epoch_file_error(igrf_file):
    """Test Apex initialization raises OSError without the IGRF file."""
    # Build the expected start of the error message from the fixture output
    expected_start = "File {:} does not exist".format(igrf_file)

    # The fixture has moved the coefficient file, so initialization must fail
    with pytest.raises(OSError) as oerr:
        apexpy.Apex()

    raised_msg = str(oerr.value)
    assert raised_msg.startswith(expected_start)
    return
68
69
70
class TestApexInit(object):
    """Test the Apex class initialization and comparison/string methods."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = None
        self.test_date = dt.datetime.utcnow()
        self.test_refh = 0
        self.bad_file = 'foo/path/to/datafile.blah'

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.test_date, self.test_refh, self.bad_file

    def eval_date(self):
        """Evaluate the times in self.test_date and self.apex_out.

        If self.test_date is a date or datetime object it is replaced by its
        year-fraction equivalent before the comparison.

        """
        if isinstance(self.test_date, (dt.datetime, dt.date)):
            self.test_date = apexpy.helpers.toYearFraction(self.test_date)

        # Assert the times are the same on the order of tens of seconds.
        # Necessary to evaluate the current UTC
        np.testing.assert_almost_equal(self.test_date, self.apex_out.year, 6)
        return

    def eval_refh(self):
        """Evaluate the reference height in self.test_refh and self.apex_out.
        """
        eval_str = "".join(["expected reference height [",
                            "{:}] not equal to Apex ".format(self.test_refh),
                            "reference height ",
                            "[{:}]".format(self.apex_out.refh)])
        assert self.test_refh == self.apex_out.refh, eval_str
        return

    def test_init_defaults(self):
        """Test Apex class default initialization."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("in_date",
                             [2015, 2015.5, dt.date(2015, 1, 1),
                              dt.datetime(2015, 6, 1, 18, 23, 45)])
    def test_init_date(self, in_date):
        """Test Apex class with date initialization."""
        self.test_date = in_date
        self.apex_out = apexpy.Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("new_date", [2015, 2015.5])
    def test_set_epoch(self, new_date):
        """Test successful setting of Apex epoch after initialization."""
        # Evaluate the default initialization
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Keep a copy of the initial state, rebuilt from the repr string
        ref_apex = eval(repr(self.apex_out))

        # Update the epoch
        self.apex_out.set_epoch(new_date)
        assert ref_apex != self.apex_out
        self.test_date = new_date
        self.eval_date()
        return

    @pytest.mark.parametrize("in_refh", [0.0, 300.0, 30000.0, -1.0])
    def test_init_refh(self, in_refh):
        """Test Apex class with reference height initialization."""
        self.test_refh = in_refh
        self.apex_out = apexpy.Apex(refh=self.test_refh)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("new_refh", [0.0, 300.0, 30000.0, -1.0])
    def test_set_refh(self, new_refh):
        """Test the method used to set the reference height after the init."""
        # Verify the defaults are set
        self.apex_out = apexpy.Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()

        # Keep a copy of the initial state, rebuilt from the repr string
        ref_apex = eval(repr(self.apex_out))

        # Update to a new reference height and test
        self.apex_out.set_refh(new_refh)

        if self.test_refh == new_refh:
            # Setting the same height must leave the object unchanged
            assert ref_apex == self.apex_out
        else:
            assert ref_apex != self.apex_out
            self.test_refh = new_refh
        self.eval_refh()
        return

    def test_init_with_bad_datafile(self):
        """Test raises IOError with non-existent datafile input."""
        with pytest.raises(IOError) as oerr:
            apexpy.Apex(datafile=self.bad_file)
        assert str(oerr.value).startswith('Data file does not exist')
        return

    def test_init_with_bad_fortranlib(self):
        """Test raises IOError with non-existent fortran library input."""
        with pytest.raises(IOError) as oerr:
            apexpy.Apex(fortranlib=self.bad_file)
        assert str(oerr.value).startswith('Fortran library does not exist')
        return

    def test_repr_eval(self):
        """Test the Apex.__repr__ results."""
        # Initialize the apex object
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Get and test the repr string
        out_str = repr(self.apex_out)
        assert out_str.startswith("apexpy.Apex(")

        # Test the ability to re-create the apex object from the repr string
        new_apex = eval(out_str)
        assert new_apex == self.apex_out
        return

    def test_ne_other_class(self):
        """Test Apex class inequality to a different class."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        assert self.apex_out != self.test_date
        return

    def test_ne_missing_attr(self):
        """Test Apex class inequality when attributes are missing from one."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        ref_apex = eval(repr(self.apex_out))
        del ref_apex.RE

        # The comparison must fail regardless of the operand order
        assert ref_apex != self.apex_out
        assert self.apex_out != ref_apex
        return

    def test_eq_missing_attr(self):
        """Test Apex class equality when attributes are missing from both."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        ref_apex = eval(repr(self.apex_out))
        del ref_apex.RE, self.apex_out.RE

        assert ref_apex == self.apex_out
        return

    def test_str_eval(self):
        """Test the Apex.__str__ results."""
        # Initialize the apex object
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Get and test the printed string
        out_str = str(self.apex_out)
        assert "Decimal year" in out_str
        return
236
237
238
class TestApexMethod(object):
    """Test the Apex methods."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.in_lat = 60
        self.in_lon = 15
        self.in_alt = 100

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.in_lat, self.in_lon, self.in_alt

    def get_input_args(self, method_name, precision=0.0):
        """Set the input arguments for the different Apex methods.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method
        precision : float
            Value for the precision (default=0.0)

        Returns
        -------
        in_args : list
            List of the appropriate input arguments

        """
        in_args = [self.in_lat, self.in_lon, self.in_alt]

        # Add precision, if needed
        if method_name in ["_qd2geo", "apxq2g", "apex2geo", "qd2geo",
                           "_apex2geo"]:
            in_args.append(precision)

        # Add a reference height, if needed
        if method_name in ["apxg2all"]:
            in_args.append(300)

        # Add a vector flag, if needed
        if method_name in ["apxg2all", "apxg2q"]:
            in_args.append(1)

        return in_args

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                              ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                              ("_qd2geo", "apxq2g", slice(None)),
                              ("_basevec", "apxg2q", slice(2, 4, 1))])
    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon", [-179, -90, 0, 90, 180])
    def test_fortran_scalar_input(self, apex_method, fortran_method, fslice,
                                  lat, lon):
        """Tests Apex/fortran interface consistency for scalars."""
        # Set the input coordinates
        self.in_lat = lat
        self.in_lon = lon

        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Get the appropriate input arguments
        apex_args = self.get_input_args(apex_method)
        fortran_args = self.get_input_args(fortran_method)

        # Evaluate the equivalent function calls
        np.testing.assert_allclose(apex_func(*apex_args),
                                   fortran_func(*fortran_args)[fslice])
        return

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                              ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                              ("_qd2geo", "apxq2g", slice(None)),
                              ("_basevec", "apxg2q", slice(2, 4, 1))])
    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon1,lon2", [(180, 180), (-180, -180),
                                           (180, -180), (-180, 180),
                                           (-345, 15), (375, 15)])
    def test_fortran_longitude_rollover(self, apex_method, fortran_method,
                                        fslice, lat, lon1, lon2):
        """Tests Apex/fortran interface consistency for longitude rollover."""
        # Set the fixed input coordinate
        self.in_lat = lat

        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Get the appropriate input arguments, the Apex method and the fortran
        # function are given different (but equivalent) longitudes
        self.in_lon = lon1
        apex_args = self.get_input_args(apex_method)

        self.in_lon = lon2
        fortran_args = self.get_input_args(fortran_method)

        # Evaluate the equivalent function calls
        np.testing.assert_allclose(apex_func(*apex_args),
                                   fortran_func(*fortran_args)[fslice])
        return

    @pytest.mark.parametrize("arr_shape", [(2, 2), (4,), (1, 4)])
    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                              ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                              ("_qd2geo", "apxq2g", slice(None)),
                              ("_basevec", "apxg2q", slice(2, 4, 1))])
    def test_fortran_array_input(self, arr_shape, apex_method, fortran_method,
                                 fslice):
        """Tests Apex/fortran interface consistency for array input."""
        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Set up the input arrays
        ref_lat = np.array([0, 30, 60, 90])
        ref_alt = np.array([100, 200, 300, 400])
        self.in_lat = ref_lat.reshape(arr_shape)
        self.in_alt = ref_alt.reshape(arr_shape)
        apex_args = self.get_input_args(apex_method)

        # Get the Apex class results
        aret = apex_func(*apex_args)

        # Get the fortran function results, one lat/alt pair at a time
        flats = list()
        flons = list()

        for lat, alt in zip(ref_lat, ref_alt):
            self.in_lat = lat
            self.in_alt = alt
            fortran_args = self.get_input_args(fortran_method)
            fret = fortran_func(*fortran_args)[fslice]
            flats.append(fret[0])
            flons.append(fret[1])

        flats = np.array(flats)
        flons = np.array(flons)

        # Evaluate results
        try:
            # This returned value is array of floats
            np.testing.assert_allclose(aret[0].astype(float),
                                       flats.reshape(arr_shape).astype(float))
            np.testing.assert_allclose(aret[1].astype(float),
                                       flons.reshape(arr_shape).astype(float))
        except ValueError:
            # This returned value is array of arrays
            alats = aret[0].reshape((4,))
            alons = aret[1].reshape((4,))
            for i, flat in enumerate(flats):
                np.testing.assert_array_almost_equal(alats[i], flat, 2)
                np.testing.assert_array_almost_equal(alons[i], flons[i], 2)

        return

    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon", [-179, -90, 0, 90, 180])
    def test_geo2apexall_scalar(self, lat, lon):
        """Test Apex/fortran geo2apexall interface consistency for scalars."""
        # Get the Apex and Fortran results
        aret = self.apex_out._geo2apexall(lat, lon, self.in_alt)
        fret = apexpy.fortranapex.apxg2all(lat, lon, self.in_alt, 300, 1)

        # Evaluate each element in the results
        for aval, fval in zip(aret, fret):
            np.testing.assert_allclose(aval, fval)

    @pytest.mark.parametrize("arr_shape", [(2, 2), (4,), (1, 4)])
    def test_geo2apexall_array(self, arr_shape):
        """Test Apex/fortran geo2apexall interface consistency for arrays."""
        # Set the input
        self.in_lat = np.array([0, 30, 60, 90])
        self.in_alt = np.array([100, 200, 300, 400])

        # Get the Apex class results
        aret = self.apex_out._geo2apexall(self.in_lat.reshape(arr_shape),
                                          self.in_lon,
                                          self.in_alt.reshape(arr_shape))

        # For each lat/alt pair, get the Fortran results
        fret = list()
        for i, lat in enumerate(self.in_lat):
            fret.append(apexpy.fortranapex.apxg2all(lat, self.in_lon,
                                                    self.in_alt[i], 300, 1))

        # Cycle through all returned values
        for i, ret in enumerate(aret):
            try:
                # This returned value is array of floats
                fret_test = np.array([fret[0][i], fret[1][i], fret[2][i],
                                      fret[3][i]]).reshape(arr_shape)
                np.testing.assert_allclose(ret.astype(float),
                                           fret_test.astype(float))
            except ValueError:
                # This returned value is array of arrays
                ret = ret.reshape((4,))
                for j, single_fret in enumerate(fret):
                    np.testing.assert_allclose(ret[j], single_fret[i])
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_consistency(self, in_coord, out_coord):
        """Test the self-consistency of the Apex convert method."""
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Define the method name
        method_name = "2".join([in_coord, out_coord])

        # Get the method and method inputs
        convert_kwargs = {'height': self.in_alt, 'precision': 0.0}
        apex_args = self.get_input_args(method_name)
        apex_method = getattr(self.apex_out, method_name)

        # Define the slice needed to get equivalent output from the named
        # method; the last element is dropped when converting to geographic
        mslice = slice(0, -1, 1) if out_coord == "geo" else slice(None)

        # Get output using convert and named method
        convert_out = self.apex_out.convert(self.in_lat, self.in_lon, in_coord,
                                            out_coord, **convert_kwargs)
        method_out = apex_method(*apex_args)[mslice]

        # Compare both outputs, should be identical
        np.testing.assert_allclose(convert_out, method_out)
        return

    @pytest.mark.parametrize("bound_lat", [90, -90])
    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_at_lat_boundary(self, bound_lat, in_coord, out_coord):
        """Test the conversion at the latitude boundary, with allowed excess."""
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = self.apex_out.convert(bound_lat, 0, in_coord, out_coord)
        excess_out = self.apex_out.convert(excess_lat, 0, in_coord, out_coord)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    def test_convert_qd2apex_at_equator(self):
        """Test the quasi-dipole to apex conversion at the magnetic equator."""
        eq_out = self.apex_out.convert(lat=0.0, lon=0, source='qd', dest='apex',
                                       height=320.0)
        close_out = self.apex_out.convert(lat=0.001, lon=0, source='qd',
                                          dest='apex', height=320.0)
        np.testing.assert_allclose(eq_out, close_out, atol=1e-4)
        return

    @pytest.mark.parametrize("src", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("dest", ["geo", "apex", "qd"])
    def test_convert_withnan(self, src, dest):
        """Test Apex.convert success with NaN input."""
        if src == dest:
            pytest.skip("Test not needed for same src and dest coordinates")

        # The first `num_nans` input locations are NaN, the rest are finite
        num_nans = 5
        in_loc = np.arange(0, 10, dtype=float)
        in_loc[:num_nans] = np.nan

        out_loc = self.apex_out.convert(in_loc, in_loc, src, dest, height=320)

        for out in out_loc:
            assert np.all(np.isnan(out[:num_nans])), "NaN output expected"
            assert np.all(np.isfinite(out[num_nans:])), "Finite output expected"

        return

    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_convert_invalid_lat(self, bad_lat):
        """Test convert raises ValueError for invalid latitudes."""

        with pytest.raises(ValueError) as verr:
            self.apex_out.convert(bad_lat, 0, 'geo', 'geo')

        assert "must be in [-90, 90]" in str(verr.value)
        return

    @pytest.mark.parametrize("coords", [("foobar", "geo"), ("geo", "foobar"),
                                        ("geo", "mlt")])
    def test_convert_invalid_transformation(self, coords):
        """Test convert raises ValueError for bad coordinates."""
        if "mlt" in coords:
            estr = "datetime must be given for MLT calculations"
        else:
            estr = "Unknown coordinate transformation"

        with pytest.raises(ValueError) as verr:
            self.apex_out.convert(0, 0, *coords)

        # Check the message of the raised exception itself, rather than the
        # string form of the pytest ExceptionInfo wrapper
        assert estr in str(verr.value)
        return

    @pytest.mark.parametrize("method_name, out_comp",
                             [("geo2apex",
                               (55.94841766357422, 94.10684204101562)),
                              ("apex2geo",
                               (51.476322174072266, -66.22817993164062,
                                5.727287771151168e-06)),
                              ("geo2qd",
                               (56.531288146972656, 94.10684204101562)),
                              ("apex2qd", (60.498401178276744, 15.0)),
                              ("qd2apex", (59.49138097045895, 15.0))])
    def test_method_scalar_input(self, method_name, out_comp):
        """Test the user method against set values with scalars."""
        # Get the desired methods
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(self.in_lat, self.in_lon, self.in_alt)

        # Evaluate the user output
        np.testing.assert_allclose(user_out, out_comp, rtol=1e-5, atol=1e-5)

        for out_val in user_out:
            assert np.asarray(out_val).shape == (), "output is not a scalar"
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("method_args, out_shape",
                             [([[60, 60], 15, 100], (2,)),
                              ([60, [15, 15], 100], (2,)),
                              ([60, 15, [100, 100]], (2,)),
                              ([[50, 60], [15, 16], [100, 200]], (2,))])
    def test_method_broadcast_input(self, in_coord, out_coord, method_args,
                                    out_shape):
        """Test the user method with inputs that require some broadcasting."""
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(*method_args)

        # Evaluate the user output
        for out_val in user_out:
            assert hasattr(out_val, 'shape'), "output coordinate isn't np.array"
            assert out_val.shape == out_shape
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_method_invalid_lat(self, in_coord, out_coord, bad_lat):
        """Test user methods raise ValueError for invalid latitudes."""
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        with pytest.raises(ValueError) as verr:
            user_method(bad_lat, 15, 100)

        assert "must be in [-90, 90]" in str(verr.value)
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("bound_lat", [90, -90])
    def test_method_at_lat_boundary(self, in_coord, out_coord, bound_lat):
        """Test user methods at the latitude boundary, with allowed excess."""
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        # Get a latitude just beyond the limit
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = user_method(bound_lat, 0, 100)
        excess_out = user_method(excess_lat, 0, 100)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    def test_geo2apex_undefined_warning(self):
        """Test geo2apex warning and fill values for an undefined location."""

        # Update the apex object
        self.apex_out = apexpy.Apex(date=2000, refh=10000)

        # Get the output and the warnings
        with warnings.catch_warnings(record=True) as warn_rec:
            user_lat, user_lon = self.apex_out.geo2apex(0, 0, 0)

        assert np.isnan(user_lat)
        assert np.isfinite(user_lon)
        assert len(warn_rec) == 1
        assert issubclass(warn_rec[-1].category, UserWarning)
        assert 'latitude set to NaN where' in str(warn_rec[-1].message)
        return

    @pytest.mark.parametrize("method_name", ["apex2qd", "qd2apex"])
    @pytest.mark.parametrize("delta_h", [1.0e-6, -1.0e-6])
    def test_quasidipole_apexheight_close(self, method_name, delta_h):
        """Test quasi-dipole success with a height close to the reference."""
        qd_method = getattr(self.apex_out, method_name)
        in_args = [0, 15, self.apex_out.refh + delta_h]
        out_coords = qd_method(*in_args)

        # Each output coordinate is compared to the input with the same index
        for i, out_val in enumerate(out_coords):
            np.testing.assert_almost_equal(out_val, in_args[i], decimal=3)
        return

    @pytest.mark.parametrize("method_name, hinc, msg",
                             [("apex2qd", 1.0, "is > apex height"),
                              ("qd2apex", -1.0, "is < reference height")])
    def test_quasidipole_raises_apexheight(self, method_name, hinc, msg):
        """Test quasi-dipole raises ApexHeightError for an invalid height.

        The height is offset from the reference height by `hinc`: above it
        for apex2qd and below it for qd2apex.

        """
        qd_method = getattr(self.apex_out, method_name)

        with pytest.raises(apexpy.ApexHeightError) as aerr:
            qd_method(0, 15, self.apex_out.refh + hinc)

        # Check the message of the raised exception itself, rather than the
        # string form of the pytest ExceptionInfo wrapper
        assert msg in str(aerr.value)
        return
670
671
672
class TestApexMLTMethods(object):
    """Test the Apex Magnetic Local Time (MLT) methods."""

    def setup_method(self):
        """Create the Apex object and the time shared by every test."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.in_time = dt.datetime(2000, 2, 3, 4, 5, 6)

    def teardown_method(self):
        """Remove the attributes created during set up."""
        del self.apex_out
        del self.in_time

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    def test_convert_to_mlt(self, in_coord):
        """Compare `convert` to MLT against a direct `mlon2mlt` call.

        Parameters
        ----------
        in_coord : str
            Input coordinate system passed to `convert`

        """
        # Magnetic inputs use the longitude directly; geographic inputs are
        # first converted using the matching `<coord>2apex` method
        mlon = 15
        if in_coord == "geo":
            to_apex = getattr(self.apex_out, "{:s}2apex".format(in_coord))
            mlon = to_apex(60, 15, 100)[1]

        # Calculate the MLT through both routes
        converted = self.apex_out.convert(60, 15, in_coord, 'mlt', height=100,
                                          ssheight=2e5, datetime=self.in_time)
        direct_mlt = self.apex_out.mlon2mlt(mlon, self.in_time, ssheight=2e5)

        # The second element of the `convert` output is compared
        np.testing.assert_allclose(converted[1], direct_mlt)
        return

    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_mlt_to_lon(self, out_coord):
        """Compare `convert` from MLT against the individual methods.

        Parameters
        ----------
        out_coord : str
            Output coordinate system passed to `convert`

        """
        # Route one: a single `convert` call
        convert_out = self.apex_out.convert(
            60, 15, 'mlt', out_coord, height=100, ssheight=2e5,
            datetime=self.in_time, precision=1e-2)

        # Route two: MLT to magnetic longitude, then on to the output system
        mlon = self.apex_out.mlt2mlon(15, self.in_time, ssheight=2e5)

        if out_coord == "qd":
            method_out = self.apex_out.apex2qd(60, mlon, 100)
        elif out_coord == "geo":
            # The last element of the apex2geo output is not compared
            geo_out = self.apex_out.apex2geo(60, mlon, 100, precision=1e-2)
            method_out = geo_out[:-1]
        else:
            method_out = (60, mlon)

        np.testing.assert_allclose(convert_out, method_out)
        return

    def test_convert_geo2mlt_nodate(self):
        """Test that converting geo to MLT without a datetime fails."""
        with pytest.raises(ValueError):
            self.apex_out.convert(60, 15, 'geo', 'mlt')
        return

    @pytest.mark.parametrize("mlon_kwargs,test_mlt",
                             [({}, 23.019629923502603),
                              ({"ssheight": 100000}, 23.026712036132814)])
    def test_mlon2mlt_scalar_inputs(self, mlon_kwargs, test_mlt):
        """Test the value and shape of mlon2mlt output for a scalar input.

        Parameters
        ----------
        mlon_kwargs : dict
            Keyword arguments passed to `mlon2mlt`
        test_mlt : float
            Expected MLT

        """
        out_mlt = self.apex_out.mlon2mlt(0, self.in_time, **mlon_kwargs)

        np.testing.assert_allclose(out_mlt, test_mlt)
        assert np.asarray(out_mlt).shape == ()
        return

    @pytest.mark.parametrize("mlt_kwargs,test_mlon",
                             [({}, 14.705535888671875),
                              ({"ssheight": 100000}, 14.599319458007812)])
    def test_mlt2mlon_scalar_inputs(self, mlt_kwargs, test_mlon):
        """Test the value and shape of mlt2mlon output for a scalar input.

        Parameters
        ----------
        mlt_kwargs : dict
            Keyword arguments passed to `mlt2mlon`
        test_mlon : float
            Expected magnetic longitude

        """
        out_mlon = self.apex_out.mlt2mlon(0, self.in_time, **mlt_kwargs)

        np.testing.assert_allclose(out_mlon, test_mlon)
        assert np.asarray(out_mlon).shape == ()
        return

    @pytest.mark.parametrize("mlon,test_mlt",
                             [([0, 180], [23.019261, 11.019261]),
                              (np.array([0, 180]), [23.019261, 11.019261]),
                              (np.array([[0], [180]]),
                               np.array([[23.019261], [11.019261]])),
                              ([[0, 180], [0, 180]], [[23.019261, 11.019261],
                                                      [23.019261, 11.019261]]),
                              (range(0, 361, 30),
                               [23.01963, 1.01963, 3.01963, 5.01963, 7.01963,
                                9.01963, 11.01963, 13.01963, 15.01963, 17.01963,
                                19.01963, 21.01963, 23.01963])])
    def test_mlon2mlt_array(self, mlon, test_mlt):
        """Test the shape and values of mlon2mlt output for array-like input.

        Parameters
        ----------
        mlon : array-like
            Input magnetic longitudes
        test_mlt : array-like
            Expected MLT, with the expected output shape

        """
        out_mlt = self.apex_out.mlon2mlt(mlon, self.in_time)
        expected_shape = np.asarray(test_mlt).shape

        assert out_mlt.shape == expected_shape
        np.testing.assert_allclose(out_mlt, test_mlt, rtol=1e-4)
        return

    @pytest.mark.parametrize("mlt,test_mlon",
                             [([0, 12], [14.705551, 194.705551]),
                              (np.array([0, 12]), [14.705551, 194.705551]),
                              (np.array([[0], [12]]),
                               np.array([[14.705551], [194.705551]])),
                              ([[0, 12], [0, 12]], [[14.705551, 194.705551],
                                                    [14.705551, 194.705551]]),
                              (range(0, 25, 2),
                               [14.705551, 44.705551, 74.705551, 104.705551,
                                134.705551, 164.705551, 194.705551, 224.705551,
                                254.705551, 284.705551, 314.705551, 344.705551,
                                14.705551])])
    def test_mlt2mlon_array(self, mlt, test_mlon):
        """Test the shape and values of mlt2mlon output for array-like input.

        Parameters
        ----------
        mlt : array-like
            Input MLT values
        test_mlon : array-like
            Expected magnetic longitudes, with the expected output shape

        """
        out_mlon = self.apex_out.mlt2mlon(mlt, self.in_time)
        expected_shape = np.asarray(test_mlon).shape

        assert out_mlon.shape == expected_shape
        np.testing.assert_allclose(out_mlon, test_mlon, rtol=1e-4)
        return

    @pytest.mark.parametrize("method_name", ["mlon2mlt", "mlt2mlon"])
    def test_mlon2mlt_diffdates(self, method_name):
        """Test that the output changes when the time shifts by an hour.

        Parameters
        ----------
        method_name : str
            Name of the Apex conversion method to call

        """
        apex_method = getattr(self.apex_out, method_name)
        later_time = self.in_time + dt.timedelta(hours=1)

        first_out = apex_method(0, self.in_time)
        second_out = apex_method(0, later_time)

        assert first_out != second_out
        return

    @pytest.mark.parametrize("mlt_offset", [1.0, 10.0])
    def test_mlon2mlt_offset(self, mlt_offset):
        """Test the time wrapping logic for the MLT.

        Parameters
        ----------
        mlt_offset : float
            MLT offset; the input longitude is shifted by -15 times this value

        """
        base_mlt = self.apex_out.mlon2mlt(0.0, self.in_time)

        # Shift the longitude, then undo the shift on the output MLT
        shifted_mlon = -15.0 * mlt_offset
        shifted_mlt = self.apex_out.mlon2mlt(shifted_mlon, self.in_time)
        restored_mlt = shifted_mlt + mlt_offset

        np.testing.assert_allclose(base_mlt, restored_mlt)
        return

    @pytest.mark.parametrize("mlon_offset", [15.0, 150.0])
    def test_mlt2mlon_offset(self, mlon_offset):
        """Test the time wrapping logic for the magnetic longitude.

        Parameters
        ----------
        mlon_offset : float
            Longitude offset; the input MLT is this value divided by 15

        """
        base_mlon = self.apex_out.mlt2mlon(0, self.in_time)

        # Shift the MLT, then undo the shift on the output longitude
        shifted_mlt = mlon_offset / 15.0
        shifted_mlon = self.apex_out.mlt2mlon(shifted_mlt, self.in_time)
        restored_mlon = shifted_mlon - mlon_offset

        np.testing.assert_allclose(base_mlon, restored_mlon)
        return

    @pytest.mark.parametrize("order", [["mlt", "mlon"], ["mlon", "mlt"]])
    @pytest.mark.parametrize("start_val", [0, 6, 12, 18, 22])
    def test_convert_and_return(self, order, start_val):
        """Test a round trip between magnetic longitude and MLT.

        Parameters
        ----------
        order : list
            Two coordinate names; the first is converted to the second and
            then back again
        start_val : int
            Starting value for the round trip

        """
        forward_name = "{:s}2{:s}".format(order[0], order[1])
        reverse_name = "{:s}2{:s}".format(order[1], order[0])
        forward = getattr(self.apex_out, forward_name)
        reverse = getattr(self.apex_out, reverse_name)

        there = forward(start_val, self.in_time)
        back = reverse(there, self.in_time)

        np.testing.assert_allclose(start_val, back)
        return
835
836
837
class TestApexMapMethods(object):
    """Test the Apex height mapping methods."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 10000],
                               [31.841466903686523, 17.916635513305664,
                                1.7075473124350538e-6]),
                              ([30, 170, 100, 500, False, 1e-2],
                               [25.727270126342773, 169.60546875,
                                0.00017573432705830783]),
                              ([60, 15, 100, 10000, True],
                               [-25.424888610839844, 27.310426712036133,
                                1.2074182222931995e-6]),
                              ([30, 170, 100, 500, True, 1e-2],
                               [-13.76642894744873, 164.24259948730469,
                                0.00056820799363777041])])
    def test_map_to_height(self, in_args, test_mapped):
        """Test the map_to_height function.

        Parameters
        ----------
        in_args : list
            List of input arguments
        test_mapped : list
            List of expected outputs

        """
        mapped = self.apex_out.map_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5, atol=1e-5)
        return

    def test_map_to_height_same_height(self):
        """Test the map_to_height function when mapping to same height."""
        mapped = self.apex_out.map_to_height(60, 15, 100, 100, conjugate=False,
                                             precision=1e-10)
        np.testing.assert_allclose(mapped, (60.0, 15.000003814697266, 0.0),
                                   rtol=1e-5, atol=1e-5)
        return

    @pytest.mark.parametrize('arr_shape', [(2,), (2, 2), (1, 4)])
    @pytest.mark.parametrize('ivec', range(0, 4))
    def test_map_to_height_array_location(self, arr_shape, ivec):
        """Test map_to_height with array input.

        Parameters
        ----------
        arr_shape : tuple
            Expected array shape
        ivec : int
            Input argument index for vectorized input

        """
        # Set the base input and output values
        in_args = [60, 15, 100, 100]
        test_mapped = [60, 15.00000381, 0.0]

        # Update inputs for one vectorized value
        in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Calculate and test function
        mapped = self.apex_out.map_to_height(*in_args)
        for i, test_val in enumerate(test_mapped):
            assert mapped[i].shape == arr_shape
            np.testing.assert_allclose(mapped[i], test_val, rtol=1e-5,
                                       atol=1e-5)
        return

    @pytest.mark.parametrize("method_name,in_args",
                             [("map_to_height", [0, 15, 100, 10000]),
                              ("map_E_to_height",
                               [0, 15, 100, 10000, [1, 2, 3]]),
                              ("map_V_to_height",
                               [0, 15, 100, 10000, [1, 2, 3]])])
    def test_mapping_height_raises_ApexHeightError(self, method_name, in_args):
        """Test the height mapping methods raise ApexHeightError.

        Parameters
        ----------
        method_name : str
            Name of the Apex mapping method to call
        in_args : list
            List of input arguments

        """
        apex_method = getattr(self.apex_out, method_name)

        with pytest.raises(apexpy.ApexHeightError) as aerr:
            apex_method(*in_args)

        assert aerr.match("is > apex height")
        return

    @pytest.mark.parametrize("method_name",
                             ["map_E_to_height", "map_V_to_height"])
    @pytest.mark.parametrize("ev_input", [([1, 2, 3, 4, 5]),
                                          ([[1, 2], [3, 4], [5, 6], [7, 8]])])
    def test_mapping_EV_bad_shape(self, method_name, ev_input):
        """Test height mapping of E/V with badly shaped input raises Error.

        Parameters
        ----------
        method_name : str
            Name of the Apex mapping method to call
        ev_input : list
            Electric field or drift input with an unsupported shape

        """
        apex_method = getattr(self.apex_out, method_name)
        in_args = [60, 15, 100, 500, ev_input]
        with pytest.raises(ValueError) as verr:
            apex_method(*in_args)

        assert str(verr.value).find("must be (3, N) or (3,) ndarray") >= 0
        return

    def test_mapping_EV_bad_flag(self):
        """Test _map_EV_to_height raises error for bad data type flag."""
        with pytest.raises(ValueError) as verr:
            self.apex_out._map_EV_to_height(60, 15, 100, 500, [1, 2, 3], "P")

        assert str(verr.value).find("unknown electric field/drift flag") >= 0
        return

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 500, [1, 2, 3]],
                               [0.71152183, 2.35624876, 0.57260784]),
                              ([60, 15, 100, 500, [2, 3, 4]],
                               [1.56028502, 3.43916636, 0.78235384]),
                              ([60, 15, 100, 1000, [1, 2, 3]],
                               [0.67796492, 2.08982134, 0.55860785]),
                              ([60, 15, 200, 500, [1, 2, 3]],
                               [0.72377397, 2.42737471, 0.59083726]),
                              ([60, 30, 100, 500, [1, 2, 3]],
                               [0.68626344, 2.37530133, 0.60060124]),
                              ([70, 15, 100, 500, [1, 2, 3]],
                               [0.72760378, 2.18082305, 0.29141979])])
    def test_map_E_to_height_scalar_location(self, in_args, test_mapped):
        """Test mapping of E-field to a specified height.

        Parameters
        ----------
        in_args : list
            List of input arguments
        test_mapped : list
            List of expected outputs

        """
        mapped = self.apex_out.map_E_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return

    @pytest.mark.parametrize('ev_flag, test_mapped',
                             [('E', [0.71152183, 2.35624876, 0.57260784]),
                              ('V', [0.81971957, 2.84512495, 0.69545001])])
    @pytest.mark.parametrize('arr_shape', [(2,), (5,)])
    @pytest.mark.parametrize('ivec', range(0, 5))
    def test_map_EV_to_height_array_location(self, ev_flag, test_mapped,
                                             arr_shape, ivec):
        """Test mapping of E-field/drift to a specified height with arrays.

        Parameters
        ----------
        ev_flag : str
            'E' or 'V', selects `map_E_to_height` or `map_V_to_height`
        test_mapped : list
            List of expected outputs
        arr_shape : tuple
            Expected array shape
        ivec : int
            Input argument index for vectorized input; an index of 4 leaves
            all location inputs as scalars

        """
        # Set the base input and output values.  `np.prod` replaces the
        # `np.product` alias, which was removed in NumPy 2.0; the result is
        # cast to a built-in int so that it always repeats the list
        eshape = list(arr_shape)
        eshape.insert(0, 3)
        num_vals = int(np.prod(arr_shape))
        edata = np.array([[1, 2, 3]] * num_vals).transpose()
        in_args = [60, 15, 100, 500, edata.reshape(tuple(eshape))]

        # Update inputs for one vectorized value if this is a location input
        if ivec < 4:
            in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Get the mapped output
        apex_method = getattr(self.apex_out,
                              "map_{:s}_to_height".format(ev_flag))
        mapped = apex_method(*in_args)

        # Test the results
        for i, test_val in enumerate(test_mapped):
            assert mapped[i].shape == arr_shape
            np.testing.assert_allclose(mapped[i], test_val, rtol=1e-5)
        return

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 500, [1, 2, 3]],
                               [0.81971957, 2.84512495, 0.69545001]),
                              ([60, 15, 100, 500, [2, 3, 4]],
                               [1.83027746, 4.14346436, 0.94764179]),
                              ([60, 15, 100, 1000, [1, 2, 3]],
                               [0.92457698, 3.14997661, 0.85135187]),
                              ([60, 15, 200, 500, [1, 2, 3]],
                               [0.80388262, 2.79321504, 0.68285158]),
                              ([60, 30, 100, 500, [1, 2, 3]],
                               [0.76141245, 2.87884673, 0.73655941]),
                              ([70, 15, 100, 500, [1, 2, 3]],
                               [0.84681866, 2.5925821, 0.34792655])])
    def test_map_V_to_height_scalar_location(self, in_args, test_mapped):
        """Test mapping of velocity to a specified height.

        Parameters
        ----------
        in_args : list
            List of input arguments
        test_mapped : list
            List of expected outputs

        """
        mapped = self.apex_out.map_V_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return
1015
1016
1017
class TestApexBasevectorMethods(object):
    """Test the Apex height base vector methods."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.lat = 60
        self.lon = 15
        self.height = 100
        self.test_basevec = None

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.test_basevec, self.lat, self.lon, self.height

    def get_comparison_results(self, bv_coord, coords, precision):
        """Get the base vector results using the hidden function for comparison.

        The results are stored in `self.test_basevec`.

        Parameters
        ----------
        bv_coord : str
            Basevector coordinate scheme, expects one of 'apex', 'qd',
            or 'bvectors_apex'
        coords : str
            Expects one of 'geo', 'apex', or 'qd'
        precision : float
            Float specifying precision

        """
        # The hidden methods take geographic locations as input
        if coords == "geo":
            glat = self.lat
            glon = self.lon
        else:
            apex_method = getattr(self.apex_out, "{:s}2geo".format(coords))
            glat, glon, _ = apex_method(self.lat, self.lon, self.height,
                                        precision=precision)

        if bv_coord == 'qd':
            self.test_basevec = self.apex_out._basevec(glat, glon, self.height)
        elif bv_coord == 'apex':
            # Use the class height attribute rather than a repeated literal,
            # so that the comparison follows the inputs used by the tests
            (_, _, _, _, f1, f2, _, d1, d2, d3, _, e1, e2,
             e3) = self.apex_out._geo2apexall(glat, glon, self.height)
            self.test_basevec = (f1, f2, d1, d2, d3, e1, e2, e3)
        else:
            # These are set results that need to be updated with IGRF
            if coords == "geo":
                self.test_basevec = (
                    np.array([4.42368795e-05, 4.42368795e-05]),
                    np.array([[0.01047826, 0.01047826],
                              [0.33089194, 0.33089194],
                              [-1.04941, -1.04941]]),
                    np.array([5.3564698e-05, 5.3564698e-05]),
                    np.array([[0.00865356, 0.00865356],
                              [0.27327004, 0.27327004],
                              [-0.8666646, -0.8666646]]))
            elif coords == "apex":
                self.test_basevec = (
                    np.array([4.48672735e-05, 4.48672735e-05]),
                    np.array([[-0.12510721, -0.12510721],
                              [0.28945938, 0.28945938],
                              [-1.1505738, -1.1505738]]),
                    np.array([6.38577444e-05, 6.38577444e-05]),
                    np.array([[-0.08790194, -0.08790194],
                              [0.2033779, 0.2033779],
                              [-0.808408, -0.808408]]))
            else:
                self.test_basevec = (
                    np.array([4.46348578e-05, 4.46348578e-05]),
                    np.array([[-0.12642345, -0.12642345],
                              [0.29695055, 0.29695055],
                              [-1.1517885, -1.1517885]]),
                    np.array([6.38626285e-05, 6.38626285e-05]),
                    np.array([[-0.08835986, -0.08835986],
                              [0.20754464, 0.20754464],
                              [-0.8050078, -0.8050078]]))

        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    @pytest.mark.parametrize("coords,precision",
                             [("geo", 1e-10), ("apex", 1.0e-2), ("qd", 1.0e-2)])
    def test_basevectors_scalar(self, bv_coord, coords, precision):
        """Test the base vector calculations with scalars.

        Parameters
        ----------
        bv_coord : str
            Basevector coordinate scheme, selects `basevectors_<bv_coord>`
        coords : str
            Coordinate system of the input location
        precision : float
            Precision passed to the base vector method

        """
        # Get the base vectors
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(self.lat, self.lon, self.height, coords=coords,
                              precision=precision)
        self.get_comparison_results(bv_coord, coords, precision)
        if bv_coord == "apex":
            # Not able to compare indices 2, 3, 4, and 5
            basevec = list(basevec)
            del basevec[2:6]

        # Test the results
        for i, vec in enumerate(basevec):
            np.testing.assert_allclose(vec, self.test_basevec[i])
        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    def test_basevectors_scalar_shape(self, bv_coord):
        """Test the shape of the scalar output.

        Parameters
        ----------
        bv_coord : str
            Basevector coordinate scheme, selects `basevectors_<bv_coord>`

        """
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(self.lat, self.lon, self.height)

        # The first two outputs have two components, the rest have three
        for i, vec in enumerate(basevec):
            if i < 2:
                assert vec.shape == (2,)
            else:
                assert vec.shape == (3,)
        return

    @pytest.mark.parametrize('arr_shape', [(2,), (5,)])
    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    @pytest.mark.parametrize("ivec", range(3))
    def test_basevectors_array(self, arr_shape, bv_coord, ivec):
        """Test the output shape for array inputs.

        Parameters
        ----------
        arr_shape : tuple
            Expected array shape
        bv_coord : str
            Basevector coordinate scheme, selects `basevectors_<bv_coord>`
        ivec : int
            Input argument index for vectorized input

        """
        # Define the input arguments
        in_args = [self.lat, self.lon, self.height]
        in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Get the basevectors
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(*in_args, coords='geo', precision=1e-10)
        self.get_comparison_results(bv_coord, "geo", 1e-10)
        if bv_coord == "apex":
            # Not able to compare indices 2, 3, 4, and 5
            basevec = list(basevec)
            del basevec[2:6]

        # Evaluate the shape and the values
        for i, vec in enumerate(basevec):
            test_shape = list(arr_shape)
            test_shape.insert(0, 2 if i < 2 else 3)
            assert vec.shape == tuple(test_shape)
            assert np.all(self.test_basevec[i][0] == vec[0])
            assert np.all(self.test_basevec[i][1] == vec[1])
        return

    @pytest.mark.parametrize("coords", ["geo", "apex", "qd"])
    def test_bvectors_apex(self, coords):
        """Test the bvectors_apex method.

        Parameters
        ----------
        coords : str
            Coordinate system of the input location

        """
        in_args = [[self.lat, self.lat], [self.lon, self.lon],
                   [self.height, self.height]]
        self.get_comparison_results("bvectors_apex", coords, 1e-10)

        basevec = self.apex_out.bvectors_apex(*in_args, coords=coords,
                                              precision=1e-10)
        for i, vec in enumerate(basevec):
            np.testing.assert_array_almost_equal(vec, self.test_basevec[i],
                                                 decimal=5)
        return

    def test_basevectors_apex_extra_values(self):
        """Test specific values in the apex base vector output."""
        # Set the testing arrays
        self.test_basevec = [np.array([0.092637, -0.245951, 0.938848]),
                             np.array([0.939012, 0.073416, -0.07342]),
                             np.array([0.055389, 1.004155, 0.257594]),
                             np.array([0, 0, 1.065135])]

        # Get the desired output
        basevec = self.apex_out.basevectors_apex(0, 15, 100, coords='geo')

        # Test the values not covered by `test_basevectors_scalar`
        for itest, ibase in enumerate(range(2, 6)):
            np.testing.assert_allclose(basevec[ibase],
                                       self.test_basevec[itest], rtol=1e-4)
        return

    @pytest.mark.parametrize("lat", range(0, 90, 10))
    @pytest.mark.parametrize("lon", range(0, 360, 15))
    def test_basevectors_apex_delta(self, lat, lon):
        """Test that vectors are calculated correctly.

        The dot product of the i-th f vector with the j-th g vector, and of
        the i-th d vector with the j-th e vector, is expected to be one when
        i equals j and zero otherwise.

        Parameters
        ----------
        lat : int
            Input latitude
        lon : int
            Input longitude

        """
        # Get the apex base vectors and sort them for easy testing
        (f1, f2, f3, g1, g2, g3, d1, d2, d3, e1, e2,
         e3) = self.apex_out.basevectors_apex(lat, lon, 500)

        # Pad the two-component vectors so all vectors have three components
        fvec = [np.append(f1, 0), np.append(f2, 0), f3]
        gvec = [g1, g2, g3]
        dvec = [d1, d2, d3]
        evec = [e1, e2, e3]

        for idelta in range(3):
            for jdelta in range(3):
                delta = 1 if idelta == jdelta else 0
                np.testing.assert_allclose(
                    np.sum(fvec[idelta] * gvec[jdelta]), delta, rtol=0,
                    atol=1e-5)
                np.testing.assert_allclose(
                    np.sum(dvec[idelta] * evec[jdelta]), delta, rtol=0,
                    atol=1e-5)
        return

    def test_basevectors_apex_invalid_scalar(self):
        """Test warning and fill values for base vectors with bad inputs."""
        self.apex_out = apexpy.Apex(date=2000, refh=10000)
        invalid = np.full(shape=(3,), fill_value=np.nan)

        # Get the output and the warnings
        with warnings.catch_warnings(record=True) as warn_rec:
            # Force UserWarnings to be recorded, so the test does not depend
            # on filters set by the caller or the test runner
            warnings.simplefilter("always", UserWarning)
            basevec = self.apex_out.basevectors_apex(0, 0, 0)

        for i, bvec in enumerate(basevec):
            if i < 2:
                # NOTE(review): `np.allclose` does not treat NaN as equal by
                # default, so this passes for any values; consider a stricter
                # check if the first two outputs must not be fill values
                assert not np.allclose(bvec, invalid[:2])
            else:
                np.testing.assert_allclose(bvec, invalid)

        # Fail with a clear assertion, not an IndexError, if nothing was raised
        assert len(warn_rec) > 0
        assert issubclass(warn_rec[-1].category, UserWarning)
        assert 'set to NaN where' in str(warn_rec[-1].message)
        return
1228
1229
1230
class TestApexGetMethods(object):
1231
    """Test the Apex `get` methods."""
1232
    def setup_method(self):
        """Create the Apex object used by every test in this class."""
        # Year 2000 epoch with a reference height of 300
        self.apex_out = apexpy.Apex(refh=300, date=2000)
1235
1236
    def teardown_method(self):
        """Remove the Apex object created during set up."""
        delattr(self, "apex_out")
1239
1240
    @pytest.mark.parametrize(
        "alat, aheight",
        [
            (10, 507.409702543805),
            (60, 20313.026999999987),
            ([10, 60], [507.409702543805, 20313.026999999987]),
            ([[10], [60]], [[507.409702543805], [20313.026999999987]]),
        ])
    def test_get_apex(self, alat, aheight):
        """Test `get_apex` against stored values for several input shapes.

        Parameters
        ----------
        alat : int or list-like
            Apex latitude input, as a scalar, list, or nested list
        aheight : float or list-like
            Expected output with the same shape as `alat`

        """
        np.testing.assert_allclose(self.apex_out.get_apex(alat), aheight)
        return
1252
1253
    @pytest.mark.parametrize(
        "glat,glon,height,test_bmag",
        [
            ([80], [100], [300], 5.100682377815247e-05),
            ([80, 80], [100], [300],
             [5.100682377815247e-05, 5.100682377815247e-05]),
            ([[80], [80]], [100], [300],
             [[5.100682377815247e-05], [5.100682377815247e-05]]),
            (range(50, 90, 8), range(0, 360, 80), [300] * 5,
             np.array([4.18657154e-05, 5.11118114e-05, 4.91969854e-05,
                       5.10519207e-05, 4.90054816e-05])),
            (90.0, 0, 1000, 3.7834718823432923e-05),
        ])
    def test_get_babs(self, glat, glon, height, test_bmag):
        """Test `get_babs` against stored magnetic field magnitudes.

        Parameters
        ----------
        glat : float or list-like
            Latitude input
        glon : float or list-like
            Longitude input
        height : float or list-like
            Height input
        test_bmag : float or array-like
            Expected output of `get_babs`

        """
        result = self.apex_out.get_babs(glat, glon, height)

        # Only an absolute tolerance is applied in this comparison
        np.testing.assert_allclose(result, test_bmag, rtol=0, atol=1e-5)
        return
1270
1271
    @pytest.mark.parametrize("bad_lat", [(91), (-91)])
    def test_get_apex_with_invalid_lat(self, bad_lat):
        """Test `get_apex` raises ValueError for invalid latitudes.

        Parameters
        ----------
        bad_lat : int
            Latitude outside of the range [-90, 90]

        """

        with pytest.raises(ValueError) as verr:
            self.apex_out.get_apex(bad_lat)

        # Use a membership test so that a message starting with the expected
        # text also passes; `str.find` returns 0 in that case
        assert "must be in [-90, 90]" in str(verr.value)
        return
1280
1281
    @pytest.mark.parametrize("bad_lat", [(91), (-91)])
    def test_get_babs_with_invalid_lat(self, bad_lat):
        """Test `get_babs` raises ValueError for invalid latitudes.

        Parameters
        ----------
        bad_lat : int
            Latitude outside of the range [-90, 90]

        """

        with pytest.raises(ValueError) as verr:
            self.apex_out.get_babs(bad_lat, 15, 100)

        # Use a membership test so that a message starting with the expected
        # text also passes; `str.find` returns 0 in that case
        assert "must be in [-90, 90]" in str(verr.value)
        return
1290
1291
    @pytest.mark.parametrize("bound_lat", [(90), (-90)])
    def test_get_at_lat_boundary(self, bound_lat):
        """Test `get_apex` accepts a latitude marginally past the boundary.

        Parameters
        ----------
        bound_lat : int
            Latitude at the edge of the allowed range

        """
        # Step a small distance past the limit, keeping the boundary's sign
        hemisphere = np.sign(bound_lat)
        excess_lat = hemisphere * (abs(bound_lat) + 1.0e-5)

        # Evaluate at the boundary first, then slightly beyond it
        at_boundary = self.apex_out.get_apex(bound_lat)
        past_boundary = self.apex_out.get_apex(excess_lat)

        # Both evaluations should agree to within the absolute tolerance
        np.testing.assert_allclose(past_boundary, at_boundary, rtol=0,
                                   atol=1e-8)
        return
1304
1305
    @pytest.mark.parametrize("apex_height", [-100, 0, 300, 10000])
    def test_get_height_at_equator(self, apex_height):
        """Test `get_height` gives back the apex height at zero latitude.

        Parameters
        ----------
        apex_height : float
            Apex height supplied to and expected from `get_height`

        """
        equator_height = self.apex_out.get_height(0.0, apex_height)

        # An exact match is required, no tolerance is applied
        assert apex_height == equator_height
        return
1318
1319
    @pytest.mark.parametrize(
        "lat, height",
        [
            (-90, -6371.009),
            (-80, -6088.438503309167),
            (-70, -5274.8091854339655),
            (-60, -4028.256749999999),
            (-50, -2499.1338178752017),
            (-40, -871.8751821247979),
            (-30, 657.2477500000014),
            (-20, 1903.8001854339655),
            (-10, 2717.4295033091657),
            (0, 3000.0),
            (10, 2717.4295033091657),
            (20, 1903.8001854339655),
            (30, 657.2477500000014),
            (40, -871.8751821247979),
            (50, -2499.1338178752017),
            (60, -4028.256749999999),
            (70, -5274.8091854339655),
            (80, -6088.438503309167),
        ])
    def test_get_height_along_fieldline(self, lat, height):
        """Test `get_height` along a field line with a fixed apex height.

        Parameters
        ----------
        lat : float
            Latitude supplied to `get_height`
        height : float
            Expected field-line height for a line with an apex of 3000 km

        """
        fheight = self.apex_out.get_height(lat, 3000.0)
        difference = abs(height - fheight)

        assert difference < 1.0e-7, (
            "bad height calculation: {:.7f} != {:.7f}".format(height, fheight))
        return
1344