Completed
Pull Request — develop (#76)
by Angeline
04:31
created

TestApexMapMethods.test_mapping_EV_bad_flag()   A

Complexity

Conditions 2

Size

Total Lines 7
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 5
dl 0
loc 7
rs 10
c 0
b 0
f 0
cc 2
nop 1
1
# -*- coding: utf-8 -*-
2
"""Test the apexpy.Apex class
3
4
Notes
5
-----
6
Whenever function outputs are tested against hard-coded numbers, the test
7
results (numbers) were obtained by running the code that is tested.  Therefore,
8
these tests below only check that nothing changes when refactoring, etc., and
9
not if the results are actually correct.
10
11
These results are expected to change when IGRF is updated.
12
13
"""
14
15
import copy
16
import datetime as dt
17
import numpy as np
18
import os
19
import pytest
20
import warnings
21
22
import apexpy
23
24
25
@pytest.fixture()
def igrf_file():
    """Temporarily hide the IGRF coefficient file from apexpy.

    Yields
    ------
    original_file : str
        Full path of ``igrf13coeffs.txt`` inside the installed apexpy package.
        While the test runs, nothing exists at this path; the file is moved
        back afterwards.

    Notes
    -----
    The file is parked next to its original location (instead of in the
    current working directory) so that ``os.rename`` never has to cross a
    filesystem boundary, where it may fail with ``OSError``.  It also means
    an interrupted run leaves the file beside the package, not in an
    unrelated directory.

    """
    # Ensure the coefficient file exists before hiding it
    original_file = os.path.join(os.path.dirname(apexpy.helpers.__file__),
                                 'igrf13coeffs.txt')
    tmp_file = original_file + ".tmp"
    assert os.path.isfile(original_file)

    # Move the coefficient file out of the way
    os.rename(original_file, tmp_file)
    try:
        yield original_file
    finally:
        # Move the coefficient file back, even if teardown is forced early
        os.rename(tmp_file, original_file)
41
42
43
def test_set_epoch_file_error(igrf_file):
    """Check that Apex raises OSError if the IGRF coefficient file is absent.

    Parameters
    ----------
    igrf_file : str
        Path of the coefficient file, temporarily removed by the
        ``igrf_file`` fixture.

    """
    expected_start = "File {:} does not exist".format(igrf_file)

    # Initializing without the coefficient file must fail
    with pytest.raises(OSError) as exc_info:
        apexpy.Apex()

    message = str(exc_info.value)
    assert message.startswith(expected_start)
    return
51
52
53
class TestApexInit():
    """Test the initialization and epoch/height setters of apexpy.Apex."""

    def setup_method(self):
        """Initialize the default test inputs before every test."""
        self.apex_out = None
        # Naive UTC "now", equivalent to the deprecated `dt.datetime.utcnow()`
        self.test_date = dt.datetime.now(dt.timezone.utc).replace(tzinfo=None)
        self.test_refh = 0
        self.bad_file = 'foo/path/to/datafile.blah'

    def teardown_method(self):
        """Clean up the test attributes after every test."""
        del self.apex_out, self.test_date, self.test_refh, self.bad_file

    def eval_date(self):
        """Evaluate the times in self.test_date and self.apex_out."""
        # `dt.datetime` subclasses `dt.date`, so this covers both types
        if isinstance(self.test_date, dt.date):
            self.test_date = apexpy.helpers.toYearFraction(self.test_date)

        # Assert the times are the same on the order of tens of seconds.
        # Necessary to evaluate the current UTC
        np.testing.assert_almost_equal(self.test_date, self.apex_out.year, 6)
        return

    def eval_refh(self):
        """Compare self.test_refh to the reference height of self.apex_out."""
        eval_str = "".join(["expected reference height [",
                            "{:}] not equal to Apex ".format(self.test_refh),
                            "reference height ",
                            "[{:}]".format(self.apex_out.refh)])
        assert self.test_refh == self.apex_out.refh, eval_str
        return

    def test_init_defaults(self):
        """Test Apex class default initialization."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("in_date",
                             [2015, 2015.5, dt.date(2015, 1, 1),
                              dt.datetime(2015, 6, 1, 18, 23, 45)])
    def test_init_date(self, in_date):
        """Test Apex class with date initialization."""
        self.test_date = in_date
        self.apex_out = apexpy.Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("new_date", [2015, 2015.5])
    def test_set_epoch(self, new_date):
        """Test successful setting of Apex epoch after initialization."""
        # Evaluate the default initialization
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Update the epoch
        ref_apex = copy.deepcopy(self.apex_out)
        self.apex_out.set_epoch(new_date)
        assert ref_apex != self.apex_out
        self.test_date = new_date
        self.eval_date()
        return

    @pytest.mark.parametrize("in_refh", [0.0, 300.0, 30000.0, -1.0])
    def test_init_refh(self, in_refh):
        """Test Apex class with reference height initialization."""
        self.test_refh = in_refh
        self.apex_out = apexpy.Apex(refh=self.test_refh)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("new_refh", [0.0, 300.0, 30000.0, -1.0])
    def test_set_refh(self, new_refh):
        """Test the method used to set the reference height after the init."""
        # Verify the defaults are set
        self.apex_out = apexpy.Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()

        # Update to a new reference height and test
        ref_apex = copy.deepcopy(self.apex_out)
        self.apex_out.set_refh(new_refh)

        if self.test_refh == new_refh:
            assert ref_apex == self.apex_out
        else:
            assert ref_apex != self.apex_out
            self.test_refh = new_refh
        self.eval_refh()
        return

    def test_init_with_bad_datafile(self):
        """Test raises IOError with non-existent datafile input."""
        with pytest.raises(IOError) as oerr:
            apexpy.Apex(datafile=self.bad_file)
        assert str(oerr.value).startswith('Data file does not exist')
        return

    def test_init_with_bad_fortranlib(self):
        """Test raises IOError with non-existent fortranlib input."""
        with pytest.raises(IOError) as oerr:
            apexpy.Apex(fortranlib=self.bad_file)
        assert str(oerr.value).startswith('Fortran library does not exist')
        return

    def test_repr_eval(self):
        """Test the Apex.__repr__ results."""
        # Initialize the apex object
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Get and test the repr string
        out_str = self.apex_out.__repr__()
        assert out_str.startswith("apexpy.Apex(")

        # Test the ability to re-create the apex object from the repr string
        new_apex = eval(out_str)
        assert new_apex == self.apex_out
        return

    def test_str_eval(self):
        """Test the Apex.__str__ results."""
        # Initialize the apex object
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Get and test the printed string
        out_str = self.apex_out.__str__()
        assert out_str.find("Decimal year") > 0
        return
187
188
189
class TestApexMethod():
    """Test the Apex methods."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.in_lat = 60
        self.in_lon = 15
        self.in_alt = 100

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.in_lat, self.in_lon, self.in_alt

    def get_input_args(self, method_name, precision=0.0):
        """Set the input arguments for the different Apex methods.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method
        precision : float
            Value for the precision (default=0.0)

        Returns
        -------
        in_args : list
            List of the appropriate input arguments

        """
        in_args = [self.in_lat, self.in_lon, self.in_alt]

        # Add precision, if needed
        if method_name in ["_qd2geo", "apxq2g", "apex2geo", "qd2geo",
                           "_apex2geo"]:
            in_args.append(precision)

        # Add a reference height, if needed (matches refh in setup_method)
        if method_name in ["apxg2all"]:
            in_args.append(300)

        # Add a vector flag, if needed
        if method_name in ["apxg2all", "apxg2q"]:
            in_args.append(1)

        return in_args

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                              ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                              ("_qd2geo", "apxq2g", slice(None)),
                              ("_basevec", "apxg2q", slice(2, 4, 1))])
    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon", [-179, -90, 0, 90, 180])
    def test_fortran_scalar_input(self, apex_method, fortran_method, fslice,
                                  lat, lon):
        """Tests Apex/fortran interface consistency for scalars."""
        # Set the input coordinates
        self.in_lat = lat
        self.in_lon = lon

        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Get the appropriate input arguments
        apex_args = self.get_input_args(apex_method)
        fortran_args = self.get_input_args(fortran_method)

        # Evaluate the equivalent function calls
        np.testing.assert_allclose(apex_func(*apex_args),
                                   fortran_func(*fortran_args)[fslice])
        return

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                              ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                              ("_qd2geo", "apxq2g", slice(None)),
                              ("_basevec", "apxg2q", slice(2, 4, 1))])
    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon1,lon2", [(180, 180), (-180, -180),
                                           (180, -180), (-180, 180),
                                           (-345, 15), (375, 15)])
    def test_fortran_longitude_rollover(self, apex_method, fortran_method,
                                        fslice, lat, lon1, lon2):
        """Tests Apex/fortran interface consistency for longitude rollover."""
        # Set the fixed input coordinate
        self.in_lat = lat

        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Get the appropriate input arguments
        self.in_lon = lon1
        apex_args = self.get_input_args(apex_method)

        self.in_lon = lon2
        fortran_args = self.get_input_args(fortran_method)

        # Evaluate the equivalent function calls
        np.testing.assert_allclose(apex_func(*apex_args),
                                   fortran_func(*fortran_args)[fslice])
        return

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                              ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                              ("_qd2geo", "apxq2g", slice(None)),
                              ("_basevec", "apxg2q", slice(2, 4, 1))])
    def test_fortran_array_input(self, apex_method, fortran_method, fslice):
        """Tests Apex/fortran interface consistency for array input."""
        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Set up the input arrays
        ref_lat = np.array([0, 30, 60, 90])
        ref_alt = np.array([100, 200, 300, 400])
        self.in_lat = ref_lat.reshape((2, 2))
        self.in_alt = ref_alt.reshape((2, 2))
        apex_args = self.get_input_args(apex_method)

        # Get the Apex class results
        aret = apex_func(*apex_args)

        # Get the fortran function results, one lat/alt pair at a time
        flats = list()
        flons = list()

        for lat, alt in zip(ref_lat, ref_alt):
            self.in_lat = lat
            self.in_alt = alt
            fortran_args = self.get_input_args(fortran_method)
            fret = fortran_func(*fortran_args)[fslice]
            flats.append(fret[0])
            flons.append(fret[1])

        flats = np.array(flats)
        flons = np.array(flons)

        # Evaluate results
        try:
            # This returned value is array of floats
            np.testing.assert_allclose(aret[0].astype(float),
                                       flats.reshape((2, 2)).astype(float))
            np.testing.assert_allclose(aret[1].astype(float),
                                       flons.reshape((2, 2)).astype(float))
        except ValueError:
            # This returned value is array of arrays
            alats = aret[0].reshape((4,))
            alons = aret[1].reshape((4,))
            for alat, alon, flat, flon in zip(alats, alons, flats, flons):
                np.testing.assert_array_almost_equal(alat, flat, 2)
                np.testing.assert_array_almost_equal(alon, flon, 2)

        return

    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon", [-179, -90, 0, 90, 180])
    def test_geo2apexall_scalar(self, lat, lon):
        """Test Apex/fortran geo2apexall interface consistency for scalars."""
        # Get the Apex and Fortran results
        aret = self.apex_out._geo2apexall(lat, lon, self.in_alt)
        fret = apexpy.fortranapex.apxg2all(lat, lon, self.in_alt, 300, 1)

        # Evaluate each element in the results
        for aval, fval in zip(aret, fret):
            np.testing.assert_allclose(aval, fval)
        return

    def test_geo2apexall_array(self):
        """Test Apex/fortran geo2apexall interface consistency for arrays."""
        # Set the input
        self.in_lat = np.array([0, 30, 60, 90])
        self.in_alt = np.array([100, 200, 300, 400])

        # Get the Apex class results
        aret = self.apex_out._geo2apexall(self.in_lat.reshape((2, 2)),
                                          self.in_lon,
                                          self.in_alt.reshape((2, 2)))

        # For each lat/alt pair, get the Fortran results
        fret = list()
        for lat, alt in zip(self.in_lat, self.in_alt):
            fret.append(apexpy.fortranapex.apxg2all(lat, self.in_lon,
                                                    alt, 300, 1))

        # Cycle through all returned values
        for i, ret in enumerate(aret):
            try:
                # This returned value is array of floats
                np.testing.assert_allclose(ret.astype(float),
                                           np.array([[fret[0][i], fret[1][i]],
                                                     [fret[2][i], fret[3][i]]],
                                                    dtype=float))
            except ValueError:
                # This returned value is array of arrays
                ret = ret.reshape((4,))
                for j, single_fret in enumerate(fret):
                    np.testing.assert_allclose(ret[j], single_fret[i])
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_consistency(self, in_coord, out_coord):
        """Test the self-consistency of the Apex convert method."""
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Define the method name
        method_name = "2".join([in_coord, out_coord])

        # Get the method and method inputs
        convert_kwargs = {'height': self.in_alt, 'precision': 0.0}
        apex_args = self.get_input_args(method_name)
        apex_method = getattr(self.apex_out, method_name)

        # Define the slice needed to get equivalent output from the named method
        mslice = slice(0, -1, 1) if out_coord == "geo" else slice(None)

        # Get output using convert and named method
        convert_out = self.apex_out.convert(self.in_lat, self.in_lon, in_coord,
                                            out_coord, **convert_kwargs)
        method_out = apex_method(*apex_args)[mslice]

        # Compare both outputs, should be identical
        np.testing.assert_allclose(convert_out, method_out)
        return

    @pytest.mark.parametrize("bound_lat", [90, -90])
    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_at_lat_boundary(self, bound_lat, in_coord, out_coord):
        """Test the conversion at the latitude boundary, with allowed excess."""
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = self.apex_out.convert(bound_lat, 0, in_coord, out_coord)
        excess_out = self.apex_out.convert(excess_lat, 0, in_coord, out_coord)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    def test_convert_qd2apex_at_equator(self):
        """Test the quasi-dipole to apex conversion at the magnetic equator."""
        eq_out = self.apex_out.convert(lat=0.0, lon=0, source='qd', dest='apex',
                                       height=320.0)
        close_out = self.apex_out.convert(lat=0.001, lon=0, source='qd',
                                          dest='apex', height=320.0)
        np.testing.assert_allclose(eq_out, close_out, atol=1e-4)
        return

    @pytest.mark.parametrize("src", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("dest", ["geo", "apex", "qd"])
    def test_convert_withnan(self, src, dest):
        """Test Apex.convert success with NaN input."""
        if src == dest:
            pytest.skip("Test not needed for same src and dest coordinates")

        num_nans = 5
        in_loc = np.arange(0, 10, dtype=float)
        in_loc[:num_nans] = np.nan

        out_loc = self.apex_out.convert(in_loc, in_loc, src, dest, height=320)

        for out in out_loc:
            assert np.all(np.isnan(out[:num_nans])), "NaN output expected"
            assert np.all(np.isfinite(out[num_nans:])), "Finite output expected"

        return

    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_convert_invalid_lat(self, bad_lat):
        """Test convert raises ValueError for invalid latitudes."""

        with pytest.raises(ValueError) as verr:
            self.apex_out.convert(bad_lat, 0, 'geo', 'geo')

        assert "must be in [-90, 90]" in str(verr.value)
        return

    @pytest.mark.parametrize("coords", [("foobar", "geo"), ("geo", "foobar"),
                                        ("geo", "mlt")])
    def test_convert_invalid_transformation(self, coords):
        """Test convert raises ValueError for bad coordinates."""
        if "mlt" in coords:
            estr = "datetime must be given for MLT calculations"
        else:
            estr = "Unknown coordinate transformation"

        with pytest.raises(ValueError) as verr:
            self.apex_out.convert(0, 0, *coords)

        # Inspect the exception message itself; `str(verr)` describes the
        # pytest ExceptionInfo wrapper rather than giving the raw message
        assert estr in str(verr.value)
        return

    @pytest.mark.parametrize("method_name, out_comp",
                             [("geo2apex",
                               (55.94841766357422, 94.10684204101562)),
                              ("apex2geo",
                               (51.476322174072266, -66.22817993164062,
                                5.727287771151168e-06)),
                              ("geo2qd",
                               (56.531288146972656, 94.10684204101562)),
                              ("apex2qd", (60.498401178276744, 15.0)),
                              ("qd2apex", (59.49138097045895, 15.0))])
    def test_method_scalar_input(self, method_name, out_comp):
        """Test the user method against set values with scalars."""
        # Get the desired methods
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(self.in_lat, self.in_lon, self.in_alt)

        # Evaluate the user output
        np.testing.assert_allclose(user_out, out_comp)

        for out_val in user_out:
            assert np.asarray(out_val).shape == (), "output is not a scalar"
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("method_args, out_shape",
                             [([[60, 60], 15, 100], (2,)),
                              ([60, [15, 15], 100], (2,)),
                              ([60, 15, [100, 100]], (2,)),
                              ([[50, 60], [15, 16], [100, 200]], (2,))])
    def test_method_broadcast_input(self, in_coord, out_coord, method_args,
                                    out_shape):
        """Test the user method with inputs that require some broadcasting."""
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(*method_args)

        # Evaluate the user output
        for out_val in user_out:
            assert hasattr(out_val, 'shape'), "output coordinate isn't np.array"
            assert out_val.shape == out_shape
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_method_invalid_lat(self, in_coord, out_coord, bad_lat):
        """Test user methods raise ValueError for invalid latitudes."""
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        with pytest.raises(ValueError) as verr:
            user_method(bad_lat, 15, 100)

        assert "must be in [-90, 90]" in str(verr.value)
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("bound_lat", [90, -90])
    def test_method_at_lat_boundary(self, in_coord, out_coord, bound_lat):
        """Test user methods at the latitude boundary, with allowed excess."""
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        # Get a latitude just beyond the limit
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = user_method(bound_lat, 0, 100)
        excess_out = user_method(excess_lat, 0, 100)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    def test_geo2apex_undefined_warning(self):
        """Test geo2apex warning and fill values for an undefined location."""

        # Update the apex object
        self.apex_out = apexpy.Apex(date=2000, refh=10000)

        # Get the output and the warnings
        with warnings.catch_warnings(record=True) as warn_rec:
            # Always record UserWarnings here, independent of the filters
            # active in the surrounding test session
            warnings.simplefilter("always", UserWarning)
            user_lat, user_lon = self.apex_out.geo2apex(0, 0, 0)

        assert np.isnan(user_lat)
        assert np.isfinite(user_lon)
        assert len(warn_rec) == 1
        assert issubclass(warn_rec[-1].category, UserWarning)
        assert 'latitude set to NaN where' in str(warn_rec[-1].message)
        return

    @pytest.mark.parametrize("method_name", ["apex2qd", "qd2apex"])
    @pytest.mark.parametrize("delta_h", [1.0e-6, -1.0e-6])
    def test_quasidipole_apexheight_close(self, method_name, delta_h):
        """Test quasi-dipole success with a height close to the reference."""
        qd_method = getattr(self.apex_out, method_name)
        in_args = [0, 15, self.apex_out.refh + delta_h]
        out_coords = qd_method(*in_args)

        for i, out_val in enumerate(out_coords):
            np.testing.assert_almost_equal(out_val, in_args[i], decimal=3)
        return

    @pytest.mark.parametrize("method_name, hinc, msg",
                             [("apex2qd", 1.0, "is > apex height"),
                              ("qd2apex", -1.0, "is < reference height")])
    def test_quasidipole_raises_apexheight(self, method_name, hinc, msg):
        """Quasi-dipole methods raise ApexHeightError for invalid heights."""
        qd_method = getattr(self.apex_out, method_name)

        with pytest.raises(apexpy.ApexHeightError) as aerr:
            qd_method(0, 15, self.apex_out.refh + hinc)

        # Inspect the exception message, not the ExceptionInfo wrapper
        assert msg in str(aerr.value)
        return
618
619
620
class TestApexMLTMethods():
    """Test the Apex Magnetic Local Time (MLT) methods."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.in_time = dt.datetime(2000, 2, 3, 4, 5, 6)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.in_time

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    def test_convert_to_mlt(self, in_coord):
        """Test the conversions to MLT using Apex convert."""

        # Get the magnetic longitude from the appropriate method; apex and
        # quasi-dipole inputs already carry the magnetic longitude
        if in_coord == "geo":
            mlon = self.apex_out.geo2apex(60, 15, 100)[1]
        else:
            mlon = 15

        # Get the output MLT values
        convert_mlt = self.apex_out.convert(60, 15, in_coord, 'mlt',
                                            height=100, ssheight=2e5,
                                            datetime=self.in_time)[1]
        method_mlt = self.apex_out.mlon2mlt(mlon, self.in_time, ssheight=2e5)

        # Test the outputs
        np.testing.assert_allclose(convert_mlt, method_mlt)
        return

    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_mlt_to_lon(self, out_coord):
        """Test the conversions from MLT using Apex convert."""
        # Get the output longitudes
        convert_out = self.apex_out.convert(60, 15, 'mlt', out_coord,
                                            height=100, ssheight=2e5,
                                            datetime=self.in_time,
                                            precision=1e-2)
        mlon = self.apex_out.mlt2mlon(15, self.in_time, ssheight=2e5)

        if out_coord == "geo":
            method_out = self.apex_out.apex2geo(60, mlon, 100,
                                                precision=1e-2)[:-1]
        elif out_coord == "qd":
            method_out = self.apex_out.apex2qd(60, mlon, 100)
        else:
            method_out = (60, mlon)

        # Evaluate the outputs
        np.testing.assert_allclose(convert_out, method_out)
        return

    def test_convert_geo2mlt_nodate(self):
        """Test convert from geo to MLT raises ValueError with no datetime."""
        with pytest.raises(ValueError):
            self.apex_out.convert(60, 15, 'geo', 'mlt')
        return

    @pytest.mark.parametrize("mlon_kwargs,test_mlt",
                             [({}, 23.019629923502603),
                              ({"ssheight": 100000}, 23.026712036132814)])
    def test_mlon2mlt_scalar_inputs(self, mlon_kwargs, test_mlt):
        """Test mlon2mlt with scalar inputs."""
        mlt = self.apex_out.mlon2mlt(0, self.in_time, **mlon_kwargs)

        np.testing.assert_allclose(mlt, test_mlt)
        assert np.asarray(mlt).shape == ()
        return

    @pytest.mark.parametrize("mlt_kwargs,test_mlon",
                             [({}, 14.705535888671875),
                              ({"ssheight": 100000}, 14.599319458007812)])
    def test_mlt2mlon_scalar_inputs(self, mlt_kwargs, test_mlon):
        """Test mlt2mlon with scalar inputs."""
        mlon = self.apex_out.mlt2mlon(0, self.in_time, **mlt_kwargs)

        np.testing.assert_allclose(mlon, test_mlon)
        assert np.asarray(mlon).shape == ()
        return

    @pytest.mark.parametrize("mlon,test_mlt",
                             [([0, 180], [23.019261, 11.019261]),
                              (np.array([0, 180]), [23.019261, 11.019261]),
                              ([[0, 180], [0, 180]], [[23.019261, 11.019261],
                                                      [23.019261, 11.019261]]),
                              (range(0, 361, 30),
                               [23.01963, 1.01963, 3.01963, 5.01963, 7.01963,
                                9.01963, 11.01963, 13.01963, 15.01963, 17.01963,
                                19.01963, 21.01963, 23.01963])])
    def test_mlon2mlt_array(self, mlon, test_mlt):
        """Test mlon2mlt with array inputs."""
        mlt = self.apex_out.mlon2mlt(mlon, self.in_time)

        assert mlt.shape == np.asarray(test_mlt).shape
        np.testing.assert_allclose(mlt, test_mlt, rtol=1e-4)
        return

    @pytest.mark.parametrize("mlt,test_mlon",
                             [([0, 12], [14.705551, 194.705551]),
                              (np.array([0, 12]), [14.705551, 194.705551]),
                              ([[0, 12], [0, 12]], [[14.705551, 194.705551],
                                                    [14.705551, 194.705551]]),
                              (range(0, 25, 2),
                               [14.705551, 44.705551, 74.705551, 104.705551,
                                134.705551, 164.705551, 194.705551, 224.705551,
                                254.705551, 284.705551, 314.705551, 344.705551,
                                14.705551])])
    def test_mlt2mlon_array(self, mlt, test_mlon):
        """Test mlt2mlon with array inputs."""
        mlon = self.apex_out.mlt2mlon(mlt, self.in_time)

        assert mlon.shape == np.asarray(test_mlon).shape
        np.testing.assert_allclose(mlon, test_mlon, rtol=1e-4)
        return

    @pytest.mark.parametrize("method_name", ["mlon2mlt", "mlt2mlon"])
    def test_mlon2mlt_diffdates(self, method_name):
        """Test that the MLT/longitude conversion varies with universal time."""
        apex_method = getattr(self.apex_out, method_name)
        out1 = apex_method(0, self.in_time)
        out2 = apex_method(0, self.in_time + dt.timedelta(hours=1))

        assert out1 != out2
        return

    @pytest.mark.parametrize("mlt_offset", [1.0, 10.0])
    def test_mlon2mlt_offset(self, mlt_offset):
        """Test the time wrapping logic for the MLT."""
        mlt1 = self.apex_out.mlon2mlt(0.0, self.in_time)
        mlt2 = self.apex_out.mlon2mlt(-15.0 * mlt_offset,
                                      self.in_time) + mlt_offset

        np.testing.assert_allclose(mlt1, mlt2)
        return

    @pytest.mark.parametrize("mlon_offset", [15.0, 150.0])
    def test_mlt2mlon_offset(self, mlon_offset):
        """Test the time wrapping logic for the magnetic longitude."""
        mlon1 = self.apex_out.mlt2mlon(0, self.in_time)
        mlon2 = self.apex_out.mlt2mlon(mlon_offset / 15.0,
                                       self.in_time) - mlon_offset

        np.testing.assert_allclose(mlon1, mlon2)
        return

    @pytest.mark.parametrize("order", [["mlt", "mlon"], ["mlon", "mlt"]])
    @pytest.mark.parametrize("start_val", [0, 6, 12, 18, 22])
    def test_convert_and_return(self, order, start_val):
        """Test the conversion to magnetic longitude or MLT and back again."""
        first_method = getattr(self.apex_out, "2".join(order))
        second_method = getattr(self.apex_out, "2".join([order[1], order[0]]))

        middle_val = first_method(start_val, self.in_time)
        end_val = second_method(middle_val, self.in_time)

        np.testing.assert_allclose(start_val, end_val)
        return
779
780
781
class TestApexMapMethods():
    """Test the Apex height mapping methods."""

    def setup_method(self):
        """Initialize all tests.

        Uses the xunit-style `setup_method` hook; the nose-style `setup`
        hook is deprecated in pytest 7.2 and no longer called in pytest 8.
        """
        self.apex_out = apexpy.Apex(date=2000, refh=300)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out

    def _eval_array_location(self, method_name, ivec, test_mapped):
        """Evaluate an E-field/drift mapping method with a vectorized input.

        Parameters
        ----------
        method_name : str
            Name of the Apex mapping method, e.g. 'map_E_to_height' or
            'map_V_to_height'
        ivec : int
            Index of the input argument to vectorize; values 0-3 duplicate
            one of the location inputs, while 4 leaves the locations scalar
            and relies only on the two-column field/drift input
        test_mapped : list-like
            Expected 3-element mapped output for each of the two columns

        """
        # Build a (3, 2) field/drift input whose two columns are identical
        ev_vec = np.array([[1, 2, 3]] * 2).transpose()
        in_args = [60, 15, 100, 500, ev_vec]
        expected = np.full(shape=(2, 3), fill_value=test_mapped).transpose()

        # Update inputs for one vectorized value if this is a location input
        if ivec < 4:
            in_args[ivec] = [in_args[ivec], in_args[ivec]]

        # Get the mapped output and test the results
        apex_method = getattr(self.apex_out, method_name)
        mapped = apex_method(*in_args)
        np.testing.assert_allclose(mapped, expected, rtol=1e-5)
        return

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 10000],
                               [31.841466903686523, 17.916635513305664,
                                1.7075473124350538e-6]),
                              ([30, 170, 100, 500, False, 1e-2],
                               [25.727270126342773, 169.60546875,
                                0.00017573432705830783]),
                              ([60, 15, 100, 10000, True],
                               [-25.424888610839844, 27.310426712036133,
                                1.2074182222931995e-6]),
                              ([30, 170, 100, 500, True, 1e-2],
                               [-13.76642894744873, 164.24259948730469,
                                0.00056820799363777041])])
    def test_map_to_height(self, in_args, test_mapped):
        """Test the map_to_height function."""
        mapped = self.apex_out.map_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, atol=1e-6)
        return

    def test_map_to_height_same_height(self):
        """Test the map_to_height function when mapping to same height."""
        mapped = self.apex_out.map_to_height(60, 15, 100, 100, conjugate=False,
                                             precision=1e-10)
        np.testing.assert_allclose(mapped, (60.0, 15.000003814697266, 0.0),
                                   rtol=1e-5)
        return

    @pytest.mark.parametrize('ivec', range(0, 4))
    def test_map_to_height_array_location(self, ivec):
        """Test map_to_height with array input."""
        # Set the base input and output values
        in_args = [60, 15, 100, 100]
        test_mapped = np.full(shape=(2, 3),
                              fill_value=[60, 15.00000381, 0.0]).transpose()

        # Update inputs for one vectorized value
        in_args[ivec] = [in_args[ivec], in_args[ivec]]

        # Calculate and test function
        mapped = self.apex_out.map_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return

    @pytest.mark.parametrize("method_name,in_args",
                             [("map_to_height", [0, 15, 100, 10000]),
                              ("map_E_to_height",
                               [0, 15, 100, 10000, [1, 2, 3]]),
                              ("map_V_to_height",
                               [0, 15, 100, 10000, [1, 2, 3]])])
    def test_mapping_height_raises_ApexHeightError(self, method_name, in_args):
        """Test the height mapping methods raise ApexHeightError.

        Each method is asked to map to a height above the apex height of the
        field line, which must raise an `apexpy.ApexHeightError`.
        """
        apex_method = getattr(self.apex_out, method_name)

        with pytest.raises(apexpy.ApexHeightError) as aerr:
            apex_method(*in_args)

        assert aerr.match("is > apex height")
        return

    @pytest.mark.parametrize("method_name",
                             ["map_E_to_height", "map_V_to_height"])
    @pytest.mark.parametrize("ev_input", [[1, 2, 3, 4, 5],
                                          [[1, 2], [3, 4], [5, 6], [7, 8]]])
    def test_mapping_EV_bad_shape(self, method_name, ev_input):
        """Test height mapping of E/V with badly shaped input raises Error."""
        apex_method = getattr(self.apex_out, method_name)
        in_args = [60, 15, 100, 500, ev_input]
        with pytest.raises(ValueError) as verr:
            apex_method(*in_args)

        # Plain substring check: the message contains regex metacharacters
        assert "must be (3, N) or (3,) ndarray" in str(verr.value)
        return

    def test_mapping_EV_bad_flag(self):
        """Test _map_EV_to_height raises error for bad data type flag."""
        with pytest.raises(ValueError) as verr:
            self.apex_out._map_EV_to_height(60, 15, 100, 500, [1, 2, 3], "P")

        assert "unknown electric field/drift flag" in str(verr.value)
        return

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 500, [1, 2, 3]],
                               [0.71152183, 2.35624876, 0.57260784]),
                              ([60, 15, 100, 500, [2, 3, 4]],
                               [1.56028502, 3.43916636, 0.78235384]),
                              ([60, 15, 100, 1000, [1, 2, 3]],
                               [0.67796492, 2.08982134, 0.55860785]),
                              ([60, 15, 200, 500, [1, 2, 3]],
                               [0.72377397, 2.42737471, 0.59083726]),
                              ([60, 30, 100, 500, [1, 2, 3]],
                               [0.68626344, 2.37530133, 0.60060124]),
                              ([70, 15, 100, 500, [1, 2, 3]],
                               [0.72760378, 2.18082305, 0.29141979])])
    def test_map_E_to_height_scalar_location(self, in_args, test_mapped):
        """Test mapping of E-field to a specified height."""
        mapped = self.apex_out.map_E_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return

    @pytest.mark.parametrize('ivec', range(0, 5))
    def test_map_E_to_height_array_location(self, ivec):
        """Test mapping of E-field to a specified height with array input."""
        self._eval_array_location("map_E_to_height", ivec,
                                  [0.71152183, 2.35624876, 0.57260784])
        return

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 500, [1, 2, 3]],
                               [0.81971957, 2.84512495, 0.69545001]),
                              ([60, 15, 100, 500, [2, 3, 4]],
                               [1.83027746, 4.14346436, 0.94764179]),
                              ([60, 15, 100, 1000, [1, 2, 3]],
                               [0.92457698, 3.14997661, 0.85135187]),
                              ([60, 15, 200, 500, [1, 2, 3]],
                               [0.80388262, 2.79321504, 0.68285158]),
                              ([60, 30, 100, 500, [1, 2, 3]],
                               [0.76141245, 2.87884673, 0.73655941]),
                              ([70, 15, 100, 500, [1, 2, 3]],
                               [0.84681866, 2.5925821, 0.34792655])])
    def test_map_V_to_height_scalar_location(self, in_args, test_mapped):
        """Test mapping of velocity to a specified height."""
        mapped = self.apex_out.map_V_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return

    @pytest.mark.parametrize('ivec', range(0, 5))
    def test_map_V_to_height_array_location(self, ivec):
        """Test mapping of velocity to a specified height with array input."""
        self._eval_array_location("map_V_to_height", ivec,
                                  [0.81971957, 2.84512495, 0.69545001])
        return
947
948
949
class TestApexBasevectorMethods():
    """Test the Apex height base vector methods."""

    def setup_method(self):
        """Initialize all tests.

        Uses the xunit-style `setup_method` hook; the nose-style `setup`
        hook is deprecated in pytest 7.2 and no longer called in pytest 8.
        """
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.lat = 60
        self.lon = 15
        self.height = 100
        self.test_basevec = None

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.test_basevec, self.lat, self.lon, self.height

    def get_comparison_results(self, bv_coord, coords, precision):
        """Get the base vector results using the hidden function for comparison.

        Parameters
        ----------
        bv_coord : str
            Basevector coordinate scheme, expects one of 'apex', 'qd',
            or 'bvectors_apex'
        coords : str
            Expects one of 'geo', 'apex', or 'qd'
        precision : float
            Float specifying precision

        Notes
        -----
        The comparison values are stored in `self.test_basevec`.

        """
        if coords == "geo":
            glat = self.lat
            glon = self.lon
        else:
            apex_method = getattr(self.apex_out, "{:s}2geo".format(coords))
            glat, glon, _ = apex_method(self.lat, self.lon, self.height,
                                        precision=precision)

        if bv_coord == 'qd':
            self.test_basevec = self.apex_out._basevec(glat, glon, self.height)
        elif bv_coord == 'apex':
            # Use the same height as the tested call instead of a literal
            (_, _, _, _, f1, f2, _, d1, d2, d3, _, e1, e2,
             e3) = self.apex_out._geo2apexall(glat, glon, self.height)
            self.test_basevec = (f1, f2, d1, d2, d3, e1, e2, e3)
        else:
            # These are set results that need to be updated with IGRF
            if coords == "geo":
                self.test_basevec = (
                    np.array([4.42368795e-05, 4.42368795e-05]),
                    np.array([[0.01047826, 0.01047826],
                              [0.33089194, 0.33089194],
                              [-1.04941, -1.04941]]),
                    np.array([5.3564698e-05, 5.3564698e-05]),
                    np.array([[0.00865356, 0.00865356],
                              [0.27327004, 0.27327004],
                              [-0.8666646, -0.8666646]]))
            elif coords == "apex":
                self.test_basevec = (
                    np.array([4.48672735e-05, 4.48672735e-05]),
                    np.array([[-0.12510721, -0.12510721],
                              [0.28945938, 0.28945938],
                              [-1.1505738, -1.1505738]]),
                    np.array([6.38577444e-05, 6.38577444e-05]),
                    np.array([[-0.08790194, -0.08790194],
                              [0.2033779, 0.2033779],
                              [-0.808408, -0.808408]]))
            else:
                self.test_basevec = (
                    np.array([4.46348578e-05, 4.46348578e-05]),
                    np.array([[-0.12642345, -0.12642345],
                              [0.29695055, 0.29695055],
                              [-1.1517885, -1.1517885]]),
                    np.array([6.38626285e-05, 6.38626285e-05]),
                    np.array([[-0.08835986, -0.08835986],
                              [0.20754464, 0.20754464],
                              [-0.8050078, -0.8050078]]))

        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    @pytest.mark.parametrize("coords,precision",
                             [("geo", 1e-10), ("apex", 1.0e-2), ("qd", 1.0e-2)])
    def test_basevectors_scalar(self, bv_coord, coords, precision):
        """Test the base vector calculations with scalars."""
        # Get the base vectors
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(self.lat, self.lon, self.height, coords=coords,
                              precision=precision)
        self.get_comparison_results(bv_coord, coords, precision)
        if bv_coord == "apex":
            # Not able to compare indices 2, 3, 4, and 5
            basevec = list(basevec)
            del basevec[2:6]

        # Test the results
        for i, vec in enumerate(basevec):
            np.testing.assert_allclose(vec, self.test_basevec[i])
        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    def test_basevectors_scalar_shape(self, bv_coord):
        """Test the shape of the scalar output."""
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(self.lat, self.lon, self.height)

        for i, vec in enumerate(basevec):
            if i < 2:
                assert vec.shape == (2,)
            else:
                assert vec.shape == (3,)
        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    @pytest.mark.parametrize("ivec", range(3))
    def test_basevectors_array(self, bv_coord, ivec):
        """Test the output shape and values for array inputs."""
        # Define the input arguments, vectorizing one of them
        in_args = [self.lat, self.lon, self.height]
        in_args[ivec] = [in_args[ivec]] * 4

        # Get the basevectors
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(*in_args, coords='geo', precision=1e-10)
        self.get_comparison_results(bv_coord, "geo", 1e-10)
        if bv_coord == "apex":
            # Not able to compare indices 2, 3, 4, and 5
            basevec = list(basevec)
            del basevec[2:6]

        # Evaluate the shape and the values
        for i, vec in enumerate(basevec):
            idim = 2 if i < 2 else 3
            assert vec.shape == (idim, 4)

            # Every column, and every vector component (including the third
            # component of 3-D vectors), must match the scalar result
            expected = np.asarray(self.test_basevec[i]).reshape(idim, 1)
            np.testing.assert_allclose(vec,
                                       np.broadcast_to(expected, vec.shape))
        return

    @pytest.mark.parametrize("coords", ["geo", "apex", "qd"])
    def test_bvectors_apex(self, coords):
        """Test the bvectors_apex method."""
        in_args = [[self.lat, self.lat], [self.lon, self.lon],
                   [self.height, self.height]]
        self.get_comparison_results("bvectors_apex", coords, 1e-10)

        basevec = self.apex_out.bvectors_apex(*in_args, coords=coords,
                                              precision=1e-10)
        for i, vec in enumerate(basevec):
            np.testing.assert_array_almost_equal(vec, self.test_basevec[i],
                                                 decimal=5)
        return

    def test_basevectors_apex_extra_values(self):
        """Test specific values in the apex base vector output."""
        # Set the testing arrays
        self.test_basevec = [np.array([0.092637, -0.245951, 0.938848]),
                             np.array([0.939012, 0.073416, -0.07342]),
                             np.array([0.055389, 1.004155, 0.257594]),
                             np.array([0, 0, 1.065135])]

        # Get the desired output
        basevec = self.apex_out.basevectors_apex(0, 15, 100, coords='geo')

        # Test the values not covered by `test_basevectors_scalar`
        for itest, ibase in enumerate(range(2, 6)):
            np.testing.assert_allclose(basevec[ibase],
                                       self.test_basevec[itest], rtol=1e-4)
        return

    @pytest.mark.parametrize("lat", range(0, 90, 10))
    @pytest.mark.parametrize("lon", range(0, 360, 15))
    def test_basevectors_apex_delta(self, lat, lon):
        """Test that vectors are calculated correctly.

        The f/g and d/e base vector sets are checked for mutual
        orthonormality: the dot product of vectors i and j must equal 1 when
        i == j and 0 otherwise.
        """
        # Get the apex base vectors and sort them for easy testing
        (f1, f2, f3, g1, g2, g3, d1, d2, d3, e1, e2,
         e3) = self.apex_out.basevectors_apex(lat, lon, 500)
        fvec = [np.append(f1, 0), np.append(f2, 0), f3]
        gvec = [g1, g2, g3]
        dvec = [d1, d2, d3]
        evec = [e1, e2, e3]

        for idelta in range(3):
            for jdelta in range(3):
                delta = 1 if idelta == jdelta else 0
                np.testing.assert_allclose(
                    np.sum(fvec[idelta] * gvec[jdelta]), delta, rtol=0,
                    atol=1e-5)
                np.testing.assert_allclose(
                    np.sum(dvec[idelta] * evec[jdelta]), delta, rtol=0,
                    atol=1e-5)
        return

    def test_basevectors_apex_invalid_scalar(self):
        """Test warning and fill values for base vectors with bad inputs."""
        self.apex_out = apexpy.Apex(date=2000, refh=10000)
        invalid = np.full(shape=(3,), fill_value=np.nan)

        # Get the output while requiring the NaN-fill warning; `pytest.warns`
        # records all warnings regardless of the active filters, so the check
        # cannot fail just because the warning was already issued elsewhere
        with pytest.warns(UserWarning, match='set to NaN where'):
            basevec = self.apex_out.basevectors_apex(0, 0, 0)

        for i, bvec in enumerate(basevec):
            if i < 2:
                assert not np.allclose(bvec, invalid[:2])
            else:
                np.testing.assert_allclose(bvec, invalid)

        return
1158
1159
1160
class TestApexGetMethods():
1161
    """Test the Apex `get` methods."""
1162
    def setup_method(self):
        """Initialize all tests with an Apex object (date 2000, refh 300).

        Uses the xunit-style `setup_method` hook; the nose-style `setup`
        hook is deprecated in pytest 7.2 and no longer called in pytest 8.
        """
        self.apex_out = apexpy.Apex(date=2000, refh=300)
1165
1166
    def teardown_method(self):
        """Clean up after each test by removing the Apex object.

        Uses the xunit-style `teardown_method` hook; the nose-style
        `teardown` hook is deprecated in pytest 7.2 and no longer called in
        pytest 8.
        """
        del self.apex_out
1169
1170
    @pytest.mark.parametrize("alat, aheight", [(10, 507.409702543805),
                                               (60, 20313.026999999987)])
    def test_get_apex(self, alat, aheight):
        """Test that `get_apex` returns the expected apex height.

        Parameters
        ----------
        alat : int
            Apex latitude passed to `get_apex`
        aheight : float
            Expected apex height

        """
        np.testing.assert_allclose(self.apex_out.get_apex(alat), aheight)
        return
1177
1178
    @pytest.mark.parametrize("glat,glon,height,test_bmag",
                             [([80], [100], [300], 5.100682377815247e-05),
                              (range(50, 90, 8), range(0, 360, 80), [300] * 5,
                               np.array([4.18657154e-05, 5.11118114e-05,
                                         4.91969854e-05, 5.10519207e-05,
                                         4.90054816e-05])),
                              (90.0, 0, 1000, 3.7834718823432923e-05)])
    def test_get_babs(self, glat, glon, height, test_bmag):
        """Test `get_babs` for scalar, list, and range location inputs.

        The magnetic field magnitude is compared to stored values using an
        absolute tolerance only.
        """
        np.testing.assert_allclose(self.apex_out.get_babs(glat, glon, height),
                                   test_bmag, rtol=0, atol=1e-5)
        return
1190
1191
    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_get_apex_with_invalid_lat(self, bad_lat):
        """Test `get_apex` raises ValueError for invalid latitudes.

        This test previously shared its name with the `get_babs` variant,
        so the later definition silently replaced it and it never ran.
        """
        with pytest.raises(ValueError) as verr:
            self.apex_out.get_apex(bad_lat)

        # Substring check accepts the phrase at any position, including 0
        assert "must be in [-90, 90]" in str(verr.value)
        return
1200
1201
    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_get_babs_with_invalid_lat(self, bad_lat):
        """Test `get_babs` raises ValueError for invalid latitudes.

        Uses a name distinct from the `get_apex` variant so that neither
        definition shadows the other in the class namespace.
        """
        with pytest.raises(ValueError) as verr:
            self.apex_out.get_babs(bad_lat, 15, 100)

        # Substring check accepts the phrase at any position, including 0
        assert "must be in [-90, 90]" in str(verr.value)
        return
1210
1211
    @pytest.mark.parametrize("bound_lat", [90, -90])
    def test_get_at_lat_boundary(self, bound_lat):
        """Test `get_apex` just beyond the latitude boundary.

        A latitude 1e-5 degrees past +/-90 is tolerated and must give the
        same apex height as the boundary latitude itself.
        """
        # Push the latitude past the limit while keeping its sign
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        bound_out, excess_out = (self.apex_out.get_apex(lat)
                                 for lat in (bound_lat, excess_lat))

        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return
1224