Passed
Pull Request — main (#103)
by Angeline
01:45
created

test_Apex.igrf_file()   B

Complexity

Conditions 5

Size

Total Lines 33
Code Lines 20

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 5
eloc 20
nop 1
dl 0
loc 33
rs 8.9332
c 0
b 0
f 0
1
# -*- coding: utf-8 -*-
2
"""Test the apexpy.Apex class
3
4
Notes
5
-----
6
Whenever function outputs are tested against hard-coded numbers, the test
7
results (numbers) were obtained by running the code that is tested.  Therefore,
8
these tests below only check that nothing changes when refactoring, etc., and
9
not whether the results are actually correct.
10
11
These results are expected to change when IGRF is updated.
12
13
"""
14
15
import datetime as dt
16
import numpy as np
17
import os
18
import pytest
19
import shutil
20
import warnings
21
22
import apexpy
23
24
25
@pytest.fixture()
def igrf_file(max_attempts=100):
    """A fixture for handling the coefficient file.

    Moves the IGRF coefficient file out of the installed apexpy package for
    the duration of a test, then moves it back during fixture teardown.

    Parameters
    ----------
    max_attempts : int
        Maximum rename attempts, needed for Windows (default=100)

    Yields
    ------
    original_file : str
        Full path of the coefficient file, which is absent during the test

    Raises
    ------
    OSError
        If the coefficient file cannot be moved away, or cannot be restored,
        within `max_attempts` tries

    """
    # Ensure the coefficient file exists
    original_file = os.path.join(os.path.dirname(apexpy.helpers.__file__),
                                 'igrf13coeffs.txt')
    tmp_file = "temp_coeff.txt"
    assert os.path.isfile(original_file)

    def _move_with_retries(src, dest):
        """Move `src` to `dest`, retrying on OSError (e.g., Windows locks)."""
        last_err = None
        for _ in range(max_attempts):
            try:
                shutil.move(src, dest)
                return
            except OSError as err:
                # Retry: the file may be temporarily locked on Windows
                last_err = err

        # Fail loudly instead of silently continuing in a broken state
        raise OSError("unable to move {:} to {:} after {:} attempts".format(
            src, dest, max_attempts)) from last_err

    # Move the coefficient file
    _move_with_retries(original_file, tmp_file)

    yield original_file

    # Move the coefficient file back; pytest runs this teardown even if the
    # test fails.  Raising here reports a lost package file instead of hiding
    # it from every subsequent test run.
    _move_with_retries(tmp_file, original_file)
    return
58
59
60
def test_set_epoch_file_error(igrf_file):
    """Ensure Apex initialization raises OSError without IGRF coefficients.

    Parameters
    ----------
    igrf_file : str
        Path of the IGRF coefficient file, moved away by the fixture

    """
    expected_start = "File {:} does not exist".format(igrf_file)

    # With the coefficient file absent, initialization must fail
    with pytest.raises(OSError) as oerr:
        apexpy.Apex()

    assert str(oerr.value).startswith(expected_start)
    return
68
69
70
class TestApexInit(object):
    """Test the Apex class initialization, comparison, and representation."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = None
        # Naive UTC time, equivalent to the deprecated `dt.datetime.utcnow()`
        self.test_date = dt.datetime.now(dt.timezone.utc).replace(tzinfo=None)
        self.test_refh = 0
        self.bad_file = 'foo/path/to/datafile.blah'

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.test_date, self.test_refh, self.bad_file

    def eval_date(self):
        """Evaluate the times in self.test_date and self.apex_out."""
        # `dt.datetime` is a subclass of `dt.date`, so one check covers both
        if isinstance(self.test_date, dt.date):
            self.test_date = apexpy.helpers.toYearFraction(self.test_date)

        # Assert the times are the same on the order of tens of seconds.
        # Necessary to evaluate the current UTC
        np.testing.assert_almost_equal(self.test_date, self.apex_out.year, 6)
        return

    def eval_refh(self):
        """Compare self.test_refh to the reference height in self.apex_out."""
        eval_str = "".join(["expected reference height [",
                            "{:}] not equal to Apex ".format(self.test_refh),
                            "reference height ",
                            "[{:}]".format(self.apex_out.refh)])
        assert self.test_refh == self.apex_out.refh, eval_str
        return

    def test_init_defaults(self):
        """Test Apex class default initialization."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("in_date",
                             [2015, 2015.5, dt.date(2015, 1, 1),
                              dt.datetime(2015, 6, 1, 18, 23, 45)])
    def test_init_date(self, in_date):
        """Test Apex class with date initialization.

        Parameters
        ----------
        in_date : int, float, dt.date, or dt.datetime
            Input date in a variety of formats

        """
        self.test_date = in_date
        self.apex_out = apexpy.Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("new_date", [2015, 2015.5])
    def test_set_epoch(self, new_date):
        """Test successful setting of Apex epoch after initialization.

        Parameters
        ----------
        new_date : int or float
            New date for the Apex class

        """
        # Evaluate the default initialization
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Update the epoch
        ref_apex = eval(repr(self.apex_out))
        self.apex_out.set_epoch(new_date)
        assert ref_apex != self.apex_out
        self.test_date = new_date
        self.eval_date()
        return

    @pytest.mark.parametrize("in_refh", [0.0, 300.0, 30000.0, -1.0])
    def test_init_refh(self, in_refh):
        """Test Apex class with reference height initialization.

        Parameters
        ----------
        in_refh : float
            Input reference height in km

        """
        self.test_refh = in_refh
        self.apex_out = apexpy.Apex(refh=self.test_refh)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("new_refh", [0.0, 300.0, 30000.0, -1.0])
    def test_set_refh(self, new_refh):
        """Test the method used to set the reference height after the init.

        Parameters
        ----------
        new_refh : float
            Reference height in km

        """
        # Verify the defaults are set
        self.apex_out = apexpy.Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()

        # Update to a new reference height and test
        ref_apex = eval(repr(self.apex_out))
        self.apex_out.set_refh(new_refh)

        if self.test_refh == new_refh:
            assert ref_apex == self.apex_out
        else:
            assert ref_apex != self.apex_out
            self.test_refh = new_refh
        self.eval_refh()
        return

    def test_init_with_bad_datafile(self):
        """Test raises IOError with non-existent datafile input."""
        with pytest.raises(IOError) as oerr:
            apexpy.Apex(datafile=self.bad_file)
        assert str(oerr.value).startswith('Data file does not exist')
        return

    def test_init_with_bad_fortranlib(self):
        """Test raises IOError with non-existent fortranlib input."""
        with pytest.raises(IOError) as oerr:
            apexpy.Apex(fortranlib=self.bad_file)
        assert str(oerr.value).startswith('Fortran library does not exist')
        return

    def test_repr_eval(self):
        """Test the Apex.__repr__ results."""
        # Initialize the apex object
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Get and test the repr string
        out_str = repr(self.apex_out)
        assert out_str.startswith("apexpy.Apex(")

        # Test the ability to re-create the apex object from the repr string
        new_apex = eval(out_str)
        assert new_apex == self.apex_out
        return

    def test_ne_other_class(self):
        """Test Apex class inequality to a different class."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        assert self.apex_out != self.test_date
        return

    def test_ne_missing_attr(self):
        """Test Apex class inequality when attributes are missing from one."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        ref_apex = eval(repr(self.apex_out))
        del ref_apex.RE

        # Inequality must hold regardless of which operand is missing RE
        assert ref_apex != self.apex_out
        assert self.apex_out != ref_apex
        return

    def test_eq_missing_attr(self):
        """Test Apex class equality when attributes are missing from both."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        ref_apex = eval(repr(self.apex_out))
        del ref_apex.RE, self.apex_out.RE

        assert ref_apex == self.apex_out
        return

    def test_str_eval(self):
        """Test the Apex.__str__ results."""
        # Initialize the apex object
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Get and test the printed string
        out_str = str(self.apex_out)
        assert out_str.find("Decimal year") > 0
        return
264
265
266
class TestApexMethod(object):
    """Test the Apex methods."""
    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.in_lat = 60
        self.in_lon = 15
        self.in_alt = 100

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.in_lat, self.in_lon, self.in_alt

    def get_input_args(self, method_name, precision=0.0):
        """Set the input arguments for the different Apex methods.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method
        precision : float
            Value for the precision (default=0.0)

        Returns
        -------
        in_args : list
            List of the appropriate input arguments

        """
        in_args = [self.in_lat, self.in_lon, self.in_alt]

        # Add precision, if needed
        if method_name in ["_qd2geo", "apxq2g", "apex2geo", "qd2geo",
                           "_apex2geo"]:
            in_args.append(precision)

        # Add a reference height, if needed
        if method_name in ["apxg2all"]:
            in_args.append(300)

        # Add a vector flag, if needed
        if method_name in ["apxg2all", "apxg2q"]:
            in_args.append(1)

        return in_args

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                              ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                              ("_qd2geo", "apxq2g", slice(None)),
                              ("_basevec", "apxg2q", slice(2, 4, 1))])
    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon", [-179, -90, 0, 90, 180])
    def test_fortran_scalar_input(self, apex_method, fortran_method, fslice,
                                  lat, lon):
        """Tests Apex/fortran interface consistency for scalars.

        Parameters
        ----------
        apex_method : str
            Name of the Apex class method to test
        fortran_method : str
            Name of the Fortran function to test
        fslice : slice
            Slice used to select the appropriate Fortran outputs
        lat : int or float
            Latitude in degrees N
        lon : int or float
            Longitude in degrees E

        """
        # Set the input coordinates
        self.in_lat = lat
        self.in_lon = lon

        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Get the appropriate input arguments
        apex_args = self.get_input_args(apex_method)
        fortran_args = self.get_input_args(fortran_method)

        # Evaluate the equivalent function calls
        np.testing.assert_allclose(apex_func(*apex_args),
                                   fortran_func(*fortran_args)[fslice])
        return

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                              ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                              ("_qd2geo", "apxq2g", slice(None)),
                              ("_basevec", "apxg2q", slice(2, 4, 1))])
    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon1,lon2", [(180, 180), (-180, -180),
                                           (180, -180), (-180, 180),
                                           (-345, 15), (375, 15)])
    def test_fortran_longitude_rollover(self, apex_method, fortran_method,
                                        fslice, lat, lon1, lon2):
        """Tests Apex/fortran interface consistency for longitude rollover.

        Parameters
        ----------
        apex_method : str
            Name of the Apex class method to test
        fortran_method : str
            Name of the Fortran function to test
        fslice : slice
            Slice used to select the appropriate Fortran outputs
        lat : int or float
            Latitude in degrees N
        lon1 : int or float
            Longitude in degrees E
        lon2 : int or float
            Equivalent longitude in degrees E

        """
        # Set the fixed input coordinate
        self.in_lat = lat

        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Get the appropriate input arguments
        self.in_lon = lon1
        apex_args = self.get_input_args(apex_method)

        self.in_lon = lon2
        fortran_args = self.get_input_args(fortran_method)

        # Evaluate the equivalent function calls
        np.testing.assert_allclose(apex_func(*apex_args),
                                   fortran_func(*fortran_args)[fslice])
        return

    @pytest.mark.parametrize("arr_shape", [(2, 2), (4,), (1, 4)])
    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                              ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                              ("_qd2geo", "apxq2g", slice(None)),
                              ("_basevec", "apxg2q", slice(2, 4, 1))])
    def test_fortran_array_input(self, arr_shape, apex_method, fortran_method,
                                 fslice):
        """Tests Apex/fortran interface consistency for array input.

        Parameters
        ----------
        arr_shape : tuple
            Shape of the input arrays and expected output shape
        apex_method : str
            Name of the Apex class method to test
        fortran_method : str
            Name of the Fortran function to test
        fslice : slice
            Slice used to select the appropriate Fortran outputs

        """
        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Set up the input arrays
        ref_lat = np.array([0, 30, 60, 90])
        ref_alt = np.array([100, 200, 300, 400])
        self.in_lat = ref_lat.reshape(arr_shape)
        self.in_alt = ref_alt.reshape(arr_shape)
        apex_args = self.get_input_args(apex_method)

        # Get the Apex class results
        aret = apex_func(*apex_args)

        # Get the fortran function results, one lat/alt pair at a time
        flats = list()
        flons = list()

        for i, lat in enumerate(ref_lat):
            self.in_lat = lat
            self.in_alt = ref_alt[i]
            fortran_args = self.get_input_args(fortran_method)
            fret = fortran_func(*fortran_args)[fslice]
            flats.append(fret[0])
            flons.append(fret[1])

        flats = np.array(flats)
        flons = np.array(flons)

        # Evaluate results
        try:
            # This returned value is array of floats
            np.testing.assert_allclose(aret[0].astype(float),
                                       flats.reshape(arr_shape).astype(float))
            np.testing.assert_allclose(aret[1].astype(float),
                                       flons.reshape(arr_shape).astype(float))
        except ValueError:
            # This returned value is array of arrays
            alats = aret[0].reshape((4,))
            alons = aret[1].reshape((4,))
            for i, flat in enumerate(flats):
                np.testing.assert_array_almost_equal(alats[i], flat, 2)
                np.testing.assert_array_almost_equal(alons[i], flons[i], 2)

        return

    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon", [-179, -90, 0, 90, 180])
    def test_geo2apexall_scalar(self, lat, lon):
        """Test Apex/fortran geo2apexall interface consistency for scalars.

        Parameters
        ----------
        lat : int or float
            Latitude in degrees N
        lon : int or float
            Longitude in degrees E

        """
        # Get the Apex and Fortran results
        aret = self.apex_out._geo2apexall(lat, lon, self.in_alt)
        fret = apexpy.fortranapex.apxg2all(lat, lon, self.in_alt, 300, 1)

        # Evaluate each element in the results
        for aval, fval in zip(aret, fret):
            np.testing.assert_allclose(aval, fval)
        return

    @pytest.mark.parametrize("arr_shape", [(2, 2), (4,), (1, 4)])
    def test_geo2apexall_array(self, arr_shape):
        """Test Apex/fortran geo2apexall interface consistency for arrays.

        Parameters
        ----------
        arr_shape : tuple
            Shape of the input arrays and expected output shape

        """
        # Set the input
        self.in_lat = np.array([0, 30, 60, 90])
        self.in_alt = np.array([100, 200, 300, 400])

        # Get the Apex class results
        aret = self.apex_out._geo2apexall(self.in_lat.reshape(arr_shape),
                                          self.in_lon,
                                          self.in_alt.reshape(arr_shape))

        # For each lat/alt pair, get the Fortran results
        fret = list()
        for i, lat in enumerate(self.in_lat):
            fret.append(apexpy.fortranapex.apxg2all(lat, self.in_lon,
                                                    self.in_alt[i], 300, 1))

        # Cycle through all returned values
        for i, ret in enumerate(aret):
            try:
                # This returned value is array of floats
                fret_test = np.array([fret[0][i], fret[1][i], fret[2][i],
                                      fret[3][i]]).reshape(arr_shape)
                np.testing.assert_allclose(ret.astype(float),
                                           fret_test.astype(float))
            except ValueError:
                # This returned value is array of arrays
                ret = ret.reshape((4,))
                for j, single_fret in enumerate(fret):
                    np.testing.assert_allclose(ret[j], single_fret[i])
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_consistency(self, in_coord, out_coord):
        """Test the self-consistency of the Apex convert method.

        Parameters
        ----------
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Define the method name
        method_name = "2".join([in_coord, out_coord])

        # Get the method and method inputs
        convert_kwargs = {'height': self.in_alt, 'precision': 0.0}
        apex_args = self.get_input_args(method_name)
        apex_method = getattr(self.apex_out, method_name)

        # Define the slice needed to get equivalent output from the named method
        mslice = slice(0, -1, 1) if out_coord == "geo" else slice(None)

        # Get output using convert and named method
        convert_out = self.apex_out.convert(self.in_lat, self.in_lon, in_coord,
                                            out_coord, **convert_kwargs)
        method_out = apex_method(*apex_args)[mslice]

        # Compare both outputs, should be identical
        np.testing.assert_allclose(convert_out, method_out)
        return

    @pytest.mark.parametrize("bound_lat", [90, -90])
    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_at_lat_boundary(self, bound_lat, in_coord, out_coord):
        """Test the conversion at the latitude boundary, with allowed excess.

        Parameters
        ----------
        bound_lat : int or float
            Boundary latitude in degrees N
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system

        """
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = self.apex_out.convert(bound_lat, 0, in_coord, out_coord)
        excess_out = self.apex_out.convert(excess_lat, 0, in_coord, out_coord)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    def test_convert_qd2apex_at_equator(self):
        """Test the quasi-dipole to apex conversion at the magnetic equator."""
        eq_out = self.apex_out.convert(lat=0.0, lon=0, source='qd', dest='apex',
                                       height=320.0)
        close_out = self.apex_out.convert(lat=0.001, lon=0, source='qd',
                                          dest='apex', height=320.0)
        np.testing.assert_allclose(eq_out, close_out, atol=1e-4)
        return

    @pytest.mark.parametrize("src", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("dest", ["geo", "apex", "qd"])
    def test_convert_withnan(self, src, dest):
        """Test Apex.convert success with NaN input.

        Parameters
        ----------
        src : str
            Input coordinate system
        dest : str
            Output coordinate system

        """
        if src == dest:
            pytest.skip("Test not needed for same src and dest coordinates")

        num_nans = 5
        in_loc = np.arange(0, 10, dtype=float)
        in_loc[:num_nans] = np.nan

        out_loc = self.apex_out.convert(in_loc, in_loc, src, dest, height=320)

        for out in out_loc:
            assert np.all(np.isnan(out[:num_nans])), "NaN output expected"
            assert np.all(np.isfinite(out[num_nans:])), "Finite output expected"

        return

    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_convert_invalid_lat(self, bad_lat):
        """Test convert raises ValueError for invalid latitudes.

        Parameters
        ----------
        bad_lat : int or float
            Latitude outside the supported range in degrees N

        """

        with pytest.raises(ValueError) as verr:
            self.apex_out.convert(bad_lat, 0, 'geo', 'geo')

        assert str(verr.value).find("must be in [-90, 90]") > 0
        return

    @pytest.mark.parametrize("coords", [("foobar", "geo"), ("geo", "foobar"),
                                        ("geo", "mlt")])
    def test_convert_invalid_transformation(self, coords):
        """Test raises ValueError for bad coordinates.

        Parameters
        ----------
        coords : tuple
            Tuple specifying the input and output coordinate systems

        """
        if "mlt" in coords:
            estr = "datetime must be given for MLT calculations"
        else:
            estr = "Unknown coordinate transformation"

        with pytest.raises(ValueError) as verr:
            self.apex_out.convert(0, 0, *coords)

        # Check the exception message itself, not the ExceptionInfo wrapper
        assert estr in str(verr.value)
        return

    @pytest.mark.parametrize("method_name, out_comp",
                             [("geo2apex",
                               (55.94841766357422, 94.10684204101562)),
                              ("apex2geo",
                               (51.476322174072266, -66.22817993164062,
                                5.727287771151168e-06)),
                              ("geo2qd",
                               (56.531288146972656, 94.10684204101562)),
                              ("apex2qd", (60.498401178276744, 15.0)),
                              ("qd2apex", (59.49138097045895, 15.0))])
    def test_method_scalar_input(self, method_name, out_comp):
        """Test the user method against set values with scalars.

        Parameters
        ----------
        method_name : str
            Apex class method to be tested
        out_comp : tuple of floats
            Expected output values

        """
        # Get the desired methods
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(self.in_lat, self.in_lon, self.in_alt)

        # Evaluate the user output
        np.testing.assert_allclose(user_out, out_comp, rtol=1e-5, atol=1e-5)

        for out_val in user_out:
            assert np.asarray(out_val).shape == (), "output is not a scalar"
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("method_args, out_shape",
                             [([[60, 60], 15, 100], (2,)),
                              ([60, [15, 15], 100], (2,)),
                              ([60, 15, [100, 100]], (2,)),
                              ([[50, 60], [15, 16], [100, 200]], (2,))])
    def test_method_broadcast_input(self, in_coord, out_coord, method_args,
                                    out_shape):
        """Test the user method with inputs that require some broadcasting.

        Parameters
        ----------
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system
        method_args : list
            List of input arguments
        out_shape : tuple
            Expected shape of output values

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(*method_args)

        # Evaluate the user output
        for out_val in user_out:
            assert hasattr(out_val, 'shape'), "output coordinate isn't np.array"
            assert out_val.shape == out_shape
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_method_invalid_lat(self, in_coord, out_coord, bad_lat):
        """Test user methods raise ValueError for invalid latitudes.

        Parameters
        ----------
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system
        bad_lat : int
            Latitude in degrees N that is out of bounds

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        with pytest.raises(ValueError) as verr:
            user_method(bad_lat, 15, 100)

        assert str(verr.value).find("must be in [-90, 90]") > 0
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("bound_lat", [90, -90])
    def test_method_at_lat_boundary(self, in_coord, out_coord, bound_lat):
        """Test user methods at the latitude boundary, with allowed excess.

        Parameters
        ----------
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system
        bound_lat : int
            Latitude in degrees N that is at the limits of the boundary

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        # Get a latitude just beyond the limit
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = user_method(bound_lat, 0, 100)
        excess_out = user_method(excess_lat, 0, 100)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    def test_geo2apex_undefined_warning(self):
        """Test geo2apex warning and fill values for an undefined location."""

        # Update the apex object
        self.apex_out = apexpy.Apex(date=2000, refh=10000)

        # Get the output and the warnings
        with warnings.catch_warnings(record=True) as warn_rec:
            # Record UserWarnings regardless of the global filter state (e.g.,
            # `-W error::UserWarning` or an "ignore" filter would otherwise
            # raise or drop the warning under test)
            warnings.simplefilter("always", UserWarning)
            user_lat, user_lon = self.apex_out.geo2apex(0, 0, 0)

        assert np.isnan(user_lat)
        assert np.isfinite(user_lon)
        assert len(warn_rec) == 1
        assert issubclass(warn_rec[-1].category, UserWarning)
        assert 'latitude set to NaN where' in str(warn_rec[-1].message)
        return

    @pytest.mark.parametrize("method_name", ["apex2qd", "qd2apex"])
    @pytest.mark.parametrize("delta_h", [1.0e-6, -1.0e-6])
    def test_quasidipole_apexheight_close(self, method_name, delta_h):
        """Test quasi-dipole success with a height close to the reference.

        Parameters
        ----------
        method_name : str
            Apex class method name to be tested
        delta_h : float
            tolerance for height in km

        """
        qd_method = getattr(self.apex_out, method_name)
        in_args = [0, 15, self.apex_out.refh + delta_h]
        out_coords = qd_method(*in_args)

        for i, out_val in enumerate(out_coords):
            np.testing.assert_almost_equal(out_val, in_args[i], decimal=3)
        return

    @pytest.mark.parametrize("method_name, hinc, msg",
                             [("apex2qd", 1.0, "is > apex height"),
                              ("qd2apex", -1.0, "is < reference height")])
    def test_quasidipole_raises_apexheight(self, method_name, hinc, msg):
        """Test quasi-dipole methods raise ApexHeightError for bad heights.

        Parameters
        ----------
        method_name : str
            Apex class method name to be tested
        hinc : float
            Height increment in km
        msg : str
            Expected output message

        """
        qd_method = getattr(self.apex_out, method_name)

        with pytest.raises(apexpy.ApexHeightError) as aerr:
            qd_method(0, 15, self.apex_out.refh + hinc)

        # Check the exception message itself, not the ExceptionInfo wrapper
        assert msg in str(aerr.value)
        return
866
867
868
class TestApexMLTMethods(object):
    """Test the Apex Magnetic Local Time (MLT) methods."""
    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.in_time = dt.datetime(2000, 2, 3, 4, 5, 6)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.in_time

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    def test_convert_to_mlt(self, in_coord):
        """Test the conversions to MLT using Apex convert.

        Parameters
        ----------
        in_coord : str
            Input coordinate system

        """
        # Get the magnetic longitude: geographic inputs must first be
        # converted, while apex/qd inputs already are magnetic longitudes
        if in_coord == "geo":
            mlon = self.apex_out.geo2apex(60, 15, 100)[1]
        else:
            mlon = 15

        # Get the output MLT values
        convert_mlt = self.apex_out.convert(60, 15, in_coord, 'mlt',
                                            height=100, ssheight=2e5,
                                            datetime=self.in_time)[1]
        method_mlt = self.apex_out.mlon2mlt(mlon, self.in_time, ssheight=2e5)

        # Test the outputs
        np.testing.assert_allclose(convert_mlt, method_mlt)
        return

    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_mlt_to_lon(self, out_coord):
        """Test the conversions from MLT using Apex convert.

        Parameters
        ----------
        out_coord : str
            Output coordinate system

        """
        # Get the output longitudes
        convert_out = self.apex_out.convert(60, 15, 'mlt', out_coord,
                                            height=100, ssheight=2e5,
                                            datetime=self.in_time,
                                            precision=1e-2)
        mlon = self.apex_out.mlt2mlon(15, self.in_time, ssheight=2e5)

        if out_coord == "geo":
            # Drop the trailing error value returned by apex2geo
            method_out = self.apex_out.apex2geo(60, mlon, 100,
                                                precision=1e-2)[:-1]
        elif out_coord == "qd":
            method_out = self.apex_out.apex2qd(60, mlon, 100)
        else:
            method_out = (60, mlon)

        # Evaluate the outputs
        np.testing.assert_allclose(convert_out, method_out)
        return

    def test_convert_geo2mlt_nodate(self):
        """Test convert from geo to MLT raises ValueError with no datetime."""
        with pytest.raises(ValueError):
            self.apex_out.convert(60, 15, 'geo', 'mlt')
        return

    @pytest.mark.parametrize("mlon_kwargs,test_mlt",
                             [({}, 23.019629923502603),
                              ({"ssheight": 100000}, 23.026712036132814)])
    def test_mlon2mlt_scalar_inputs(self, mlon_kwargs, test_mlt):
        """Test mlon2mlt with scalar inputs.

        Parameters
        ----------
        mlon_kwargs : dict
            Input kwargs
        test_mlt : float
            Output MLT in hours

        """
        mlt = self.apex_out.mlon2mlt(0, self.in_time, **mlon_kwargs)

        np.testing.assert_allclose(mlt, test_mlt)
        assert np.asarray(mlt).shape == ()
        return

    @pytest.mark.parametrize("mlt_kwargs,test_mlon",
                             [({}, 14.705535888671875),
                              ({"ssheight": 100000}, 14.599319458007812)])
    def test_mlt2mlon_scalar_inputs(self, mlt_kwargs, test_mlon):
        """Test mlt2mlon with scalar inputs.

        Parameters
        ----------
        mlt_kwargs : dict
            Input kwargs
        test_mlon : float
            Output longitude in degrees E

        """
        mlon = self.apex_out.mlt2mlon(0, self.in_time, **mlt_kwargs)

        np.testing.assert_allclose(mlon, test_mlon)
        assert np.asarray(mlon).shape == ()
        return

    @pytest.mark.parametrize("mlon,test_mlt",
                             [([0, 180], [23.019261, 11.019261]),
                              (np.array([0, 180]), [23.019261, 11.019261]),
                              (np.array([[0], [180]]),
                               np.array([[23.019261], [11.019261]])),
                              ([[0, 180], [0, 180]], [[23.019261, 11.019261],
                                                      [23.019261, 11.019261]]),
                              (range(0, 361, 30),
                               [23.01963, 1.01963, 3.01963, 5.01963, 7.01963,
                                9.01963, 11.01963, 13.01963, 15.01963, 17.01963,
                                19.01963, 21.01963, 23.01963])])
    def test_mlon2mlt_array(self, mlon, test_mlt):
        """Test mlon2mlt with array inputs.

        Parameters
        ----------
        mlon : array-like
            Input longitudes in degrees E
        test_mlt : array-like
            Expected MLT in hours, with the same shape as `mlon`

        """
        mlt = self.apex_out.mlon2mlt(mlon, self.in_time)

        assert mlt.shape == np.asarray(test_mlt).shape
        np.testing.assert_allclose(mlt, test_mlt, rtol=1e-4)
        return

    @pytest.mark.parametrize("mlt,test_mlon",
                             [([0, 12], [14.705551, 194.705551]),
                              (np.array([0, 12]), [14.705551, 194.705551]),
                              (np.array([[0], [12]]),
                               np.array([[14.705551], [194.705551]])),
                              ([[0, 12], [0, 12]], [[14.705551, 194.705551],
                                                    [14.705551, 194.705551]]),
                              (range(0, 25, 2),
                               [14.705551, 44.705551, 74.705551, 104.705551,
                                134.705551, 164.705551, 194.705551, 224.705551,
                                254.705551, 284.705551, 314.705551, 344.705551,
                                14.705551])])
    def test_mlt2mlon_array(self, mlt, test_mlon):
        """Test mlt2mlon with array inputs.

        Parameters
        ----------
        mlt : array-like
            Input MLT in hours
        test_mlon : array-like
            Expected longitude in degrees E, with the same shape as `mlt`

        """
        mlon = self.apex_out.mlt2mlon(mlt, self.in_time)

        assert mlon.shape == np.asarray(test_mlon).shape
        np.testing.assert_allclose(mlon, test_mlon, rtol=1e-4)
        return

    @pytest.mark.parametrize("method_name", ["mlon2mlt", "mlt2mlon"])
    def test_mlon2mlt_diffdates(self, method_name):
        """Test that the MLT/longitude conversion varies with universal time.

        Parameters
        ----------
        method_name : str
            Name of Apex class method to be tested

        """
        apex_method = getattr(self.apex_out, method_name)

        # The outputs are MLT or magnetic longitude depending on the method
        out1 = apex_method(0, self.in_time)
        out2 = apex_method(0, self.in_time + dt.timedelta(hours=1))

        assert out1 != out2
        return

    @pytest.mark.parametrize("mlt_offset", [1.0, 10.0])
    def test_mlon2mlt_offset(self, mlt_offset):
        """Test the time wrapping logic for the MLT.

        Parameters
        ----------
        mlt_offset : float
            MLT offset in hours

        """
        # Shifting the longitude by -15 degrees per hour and adding the hours
        # back must reproduce the unshifted MLT
        mlt1 = self.apex_out.mlon2mlt(0.0, self.in_time)
        mlt2 = self.apex_out.mlon2mlt(-15.0 * mlt_offset,
                                      self.in_time) + mlt_offset

        np.testing.assert_allclose(mlt1, mlt2)
        return

    @pytest.mark.parametrize("mlon_offset", [15.0, 150.0])
    def test_mlt2mlon_offset(self, mlon_offset):
        """Test the time wrapping logic for the magnetic longitude.

        Parameters
        ----------
        mlon_offset : float
            Magnetic longitude offset in degrees; converted to an MLT offset
            of `mlon_offset / 15` hours

        """
        mlon1 = self.apex_out.mlt2mlon(0, self.in_time)
        mlon2 = self.apex_out.mlt2mlon(mlon_offset / 15.0,
                                       self.in_time) - mlon_offset

        np.testing.assert_allclose(mlon1, mlon2)
        return

    @pytest.mark.parametrize("order", [["mlt", "mlon"], ["mlon", "mlt"]])
    @pytest.mark.parametrize("start_val", [0, 6, 12, 18, 22])
    def test_convert_and_return(self, order, start_val):
        """Test the conversion to magnetic longitude or MLT and back again.

        Parameters
        ----------
        order : list
            List of strings specifying the order to run functions
        start_val : int or float
            Input value

        """
        # Build e.g. 'mlt2mlon' and its inverse 'mlon2mlt' from `order`
        first_method = getattr(self.apex_out, "2".join(order))
        second_method = getattr(self.apex_out, "2".join(reversed(order)))

        middle_val = first_method(start_val, self.in_time)
        end_val = second_method(middle_val, self.in_time)

        np.testing.assert_allclose(start_val, end_val)
        return
1111
1112
1113
class TestApexMapMethods(object):
    """Test the Apex height mapping methods."""
    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 10000],
                               [31.841466903686523, 17.916635513305664,
                                1.7075473124350538e-6]),
                              ([30, 170, 100, 500, False, 1e-2],
                               [25.727270126342773, 169.60546875,
                                0.00017573432705830783]),
                              ([60, 15, 100, 10000, True],
                               [-25.424888610839844, 27.310426712036133,
                                1.2074182222931995e-6]),
                              ([30, 170, 100, 500, True, 1e-2],
                               [-13.76642894744873, 164.24259948730469,
                                0.00056820799363777041])])
    def test_map_to_height(self, in_args, test_mapped):
        """Test the map_to_height function.

        Parameters
        ----------
        in_args : list
            List of input arguments
        test_mapped : list
            List of expected outputs

        """
        mapped = self.apex_out.map_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5, atol=1e-5)
        return

    def test_map_to_height_same_height(self):
        """Test the map_to_height function when mapping to same height."""
        mapped = self.apex_out.map_to_height(60, 15, 100, 100, conjugate=False,
                                             precision=1e-10)
        np.testing.assert_allclose(mapped, (60.0, 15.000003814697266, 0.0),
                                   rtol=1e-5, atol=1e-5)
        return

    @pytest.mark.parametrize('arr_shape', [(2,), (2, 2), (1, 4)])
    @pytest.mark.parametrize('ivec', range(0, 4))
    def test_map_to_height_array_location(self, arr_shape, ivec):
        """Test map_to_height with array input.

        Parameters
        ----------
        arr_shape : tuple
            Shape of the vectorized input, and thus of each output
        ivec : int
            Input argument index for vectorized input

        """
        # Set the base input and output values
        in_args = [60, 15, 100, 100]
        test_mapped = [60, 15.00000381, 0.0]

        # Update inputs for one vectorized value
        in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Calculate and test function
        mapped = self.apex_out.map_to_height(*in_args)
        for i, test_val in enumerate(test_mapped):
            assert mapped[i].shape == arr_shape
            np.testing.assert_allclose(mapped[i], test_val, rtol=1e-5,
                                       atol=1e-5)
        return

    @pytest.mark.parametrize("method_name,in_args",
                             [("map_to_height", [0, 15, 100, 10000]),
                              ("map_E_to_height",
                               [0, 15, 100, 10000, [1, 2, 3]]),
                              ("map_V_to_height",
                               [0, 15, 100, 10000, [1, 2, 3]])])
    def test_mapping_height_raises_ApexHeightError(self, method_name, in_args):
        """Test the height mapping methods raise ApexHeightError.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method to test
        in_args : list
            List of input arguments

        """
        apex_method = getattr(self.apex_out, method_name)

        with pytest.raises(apexpy.ApexHeightError) as aerr:
            apex_method(*in_args)

        assert aerr.match("is > apex height")
        return

    @pytest.mark.parametrize("method_name",
                             ["map_E_to_height", "map_V_to_height"])
    @pytest.mark.parametrize("ev_input", [[1, 2, 3, 4, 5],
                                          [[1, 2], [3, 4], [5, 6], [7, 8]]])
    def test_mapping_EV_bad_shape(self, method_name, ev_input):
        """Test height mapping of E/V with badly shaped input raises Error.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method to test
        ev_input : list
            E/V input arguments

        """
        apex_method = getattr(self.apex_out, method_name)
        in_args = [60, 15, 100, 500, ev_input]
        with pytest.raises(ValueError) as verr:
            apex_method(*in_args)

        assert "must be (3, N) or (3,) ndarray" in str(verr.value)
        return

    def test_mapping_EV_bad_flag(self):
        """Test _map_EV_to_height raises error for bad data type flag."""
        with pytest.raises(ValueError) as verr:
            self.apex_out._map_EV_to_height(60, 15, 100, 500, [1, 2, 3], "P")

        assert "unknown electric field/drift flag" in str(verr.value)
        return

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 500, [1, 2, 3]],
                               [0.71152183, 2.35624876, 0.57260784]),
                              ([60, 15, 100, 500, [2, 3, 4]],
                               [1.56028502, 3.43916636, 0.78235384]),
                              ([60, 15, 100, 1000, [1, 2, 3]],
                               [0.67796492, 2.08982134, 0.55860785]),
                              ([60, 15, 200, 500, [1, 2, 3]],
                               [0.72377397, 2.42737471, 0.59083726]),
                              ([60, 30, 100, 500, [1, 2, 3]],
                               [0.68626344, 2.37530133, 0.60060124]),
                              ([70, 15, 100, 500, [1, 2, 3]],
                               [0.72760378, 2.18082305, 0.29141979])])
    def test_map_E_to_height_scalar_location(self, in_args, test_mapped):
        """Test mapping of E-field to a specified height.

        Parameters
        ----------
        in_args : list
            List of input arguments
        test_mapped : list
            List of expected outputs

        """
        mapped = self.apex_out.map_E_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return

    @pytest.mark.parametrize('ev_flag, test_mapped',
                             [('E', [0.71152183, 2.35624876, 0.57260784]),
                              ('V', [0.81971957, 2.84512495, 0.69545001])])
    @pytest.mark.parametrize('arr_shape', [(2,), (5,)])
    @pytest.mark.parametrize('ivec', range(0, 5))
    def test_map_EV_to_height_array_location(self, ev_flag, test_mapped,
                                             arr_shape, ivec):
        """Test mapping of E-field/drift to a specified height with arrays.

        Parameters
        ----------
        ev_flag : str
            Character flag specifying whether to run 'E' or 'V' methods
        test_mapped : list
            List of expected outputs
        arr_shape : tuple
            Shape of the vectorized input, and thus of each output component
        ivec : int
            Index of the location input argument to vectorize; a value of 4
            leaves every location input scalar so only the E/V data is
            vectorized

        """
        # Set the base input and output values; the E/V data is always
        # shaped (3,) + arr_shape, with [1, 2, 3] repeated at every location.
        # np.prod is used because np.product was removed in NumPy 2.0.
        eshape = (3,) + tuple(arr_shape)
        edata = np.array([[1, 2, 3]] * int(np.prod(arr_shape))).transpose()
        in_args = [60, 15, 100, 500, edata.reshape(eshape)]

        # Update inputs for one vectorized value if this is a location input
        if ivec < 4:
            in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Get the mapped output
        apex_method = getattr(self.apex_out,
                              "map_{:s}_to_height".format(ev_flag))
        mapped = apex_method(*in_args)

        # Test the results
        for i, test_val in enumerate(test_mapped):
            assert mapped[i].shape == arr_shape
            np.testing.assert_allclose(mapped[i], test_val, rtol=1e-5)
        return

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 500, [1, 2, 3]],
                               [0.81971957, 2.84512495, 0.69545001]),
                              ([60, 15, 100, 500, [2, 3, 4]],
                               [1.83027746, 4.14346436, 0.94764179]),
                              ([60, 15, 100, 1000, [1, 2, 3]],
                               [0.92457698, 3.14997661, 0.85135187]),
                              ([60, 15, 200, 500, [1, 2, 3]],
                               [0.80388262, 2.79321504, 0.68285158]),
                              ([60, 30, 100, 500, [1, 2, 3]],
                               [0.76141245, 2.87884673, 0.73655941]),
                              ([70, 15, 100, 500, [1, 2, 3]],
                               [0.84681866, 2.5925821, 0.34792655])])
    def test_map_V_to_height_scalar_location(self, in_args, test_mapped):
        """Test mapping of velocity to a specified height.

        Parameters
        ----------
        in_args : list
            List of input arguments
        test_mapped : list
            List of expected outputs

        """
        mapped = self.apex_out.map_V_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return
1340
1341
1342
class TestApexBasevectorMethods(object):
    """Test the Apex height base vector methods."""
    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.lat = 60
        self.lon = 15
        self.height = 100
        self.test_basevec = None

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.test_basevec, self.lat, self.lon, self.height

    def get_comparison_results(self, bv_coord, coords, precision):
        """Get the base vector results using the hidden function for comparison.

        Results are stored in `self.test_basevec`.

        Parameters
        ----------
        bv_coord : str
            Basevector coordinate scheme, expects one of 'apex', 'qd',
            or 'bvectors_apex'
        coords : str
            Expects one of 'geo', 'apex', or 'qd'
        precision : float
            Float specifying precision

        """
        # Convert the test location to geographic coordinates if needed
        if coords == "geo":
            glat = self.lat
            glon = self.lon
        else:
            apex_method = getattr(self.apex_out, "{:s}2geo".format(coords))
            glat, glon, _ = apex_method(self.lat, self.lon, self.height,
                                        precision=precision)

        if bv_coord == 'qd':
            self.test_basevec = self.apex_out._basevec(glat, glon, self.height)
        elif bv_coord == 'apex':
            # Keep only the outputs that basevectors_apex reports in the
            # positions compared by the tests (f1, f2, d1-d3, e1-e3)
            (_, _, _, _, f1, f2, _, d1, d2, d3, _, e1, e2,
             e3) = self.apex_out._geo2apexall(glat, glon, self.height)
            self.test_basevec = (f1, f2, d1, d2, d3, e1, e2, e3)
        else:
            # These are set results that need to be updated with IGRF
            if coords == "geo":
                self.test_basevec = (
                    np.array([4.42368795e-05, 4.42368795e-05]),
                    np.array([[0.01047826, 0.01047826],
                              [0.33089194, 0.33089194],
                              [-1.04941, -1.04941]]),
                    np.array([5.3564698e-05, 5.3564698e-05]),
                    np.array([[0.00865356, 0.00865356],
                              [0.27327004, 0.27327004],
                              [-0.8666646, -0.8666646]]))
            elif coords == "apex":
                self.test_basevec = (
                    np.array([4.48672735e-05, 4.48672735e-05]),
                    np.array([[-0.12510721, -0.12510721],
                              [0.28945938, 0.28945938],
                              [-1.1505738, -1.1505738]]),
                    np.array([6.38577444e-05, 6.38577444e-05]),
                    np.array([[-0.08790194, -0.08790194],
                              [0.2033779, 0.2033779],
                              [-0.808408, -0.808408]]))
            else:
                self.test_basevec = (
                    np.array([4.46348578e-05, 4.46348578e-05]),
                    np.array([[-0.12642345, -0.12642345],
                              [0.29695055, 0.29695055],
                              [-1.1517885, -1.1517885]]),
                    np.array([6.38626285e-05, 6.38626285e-05]),
                    np.array([[-0.08835986, -0.08835986],
                              [0.20754464, 0.20754464],
                              [-0.8050078, -0.8050078]]))

        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    @pytest.mark.parametrize("coords,precision",
                             [("geo", 1e-10), ("apex", 1.0e-2), ("qd", 1.0e-2)])
    def test_basevectors_scalar(self, bv_coord, coords, precision):
        """Test the base vector calculations with scalars.

        Parameters
        ----------
        bv_coord : str
            Base vector type to calculate, 'qd' or 'apex'
        coords : str
            Name of the input coordinate system
        precision : float
            Level of run precision requested

        """
        # Get the base vectors
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(self.lat, self.lon, self.height, coords=coords,
                              precision=precision)
        self.get_comparison_results(bv_coord, coords, precision)
        if bv_coord == "apex":
            # Not able to compare indices 2, 3, 4, and 5
            basevec = list(basevec)
            del basevec[2:6]

        # Test the results
        for i, vec in enumerate(basevec):
            np.testing.assert_allclose(vec, self.test_basevec[i])
        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    def test_basevectors_scalar_shape(self, bv_coord):
        """Test the shape of the scalar output.

        Parameters
        ----------
        bv_coord : str
            Base vector type to calculate, 'qd' or 'apex'

        """
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(self.lat, self.lon, self.height)

        # The first two vectors are 2-component, the rest 3-component
        for i, vec in enumerate(basevec):
            if i < 2:
                assert vec.shape == (2,)
            else:
                assert vec.shape == (3,)
        return

    @pytest.mark.parametrize('arr_shape', [(2,), (5,)])
    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    @pytest.mark.parametrize("ivec", range(3))
    def test_basevectors_array(self, arr_shape, bv_coord, ivec):
        """Test the output shape for array inputs.

        Parameters
        ----------
        arr_shape : tuple
            Shape of the vectorized input; each output has this shape
            prepended by its number of vector components
        bv_coord : str
            Base vector type to calculate, 'qd' or 'apex'
        ivec : int
            Index of the input argument (lat, lon, height) to vectorize

        """
        # Define the input arguments
        in_args = [self.lat, self.lon, self.height]
        in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Get the basevectors
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(*in_args, coords='geo', precision=1e-10)
        self.get_comparison_results(bv_coord, "geo", 1e-10)
        if bv_coord == "apex":
            # Not able to compare indices 2, 3, 4, and 5
            basevec = list(basevec)
            del basevec[2:6]

        # Evaluate the shape and the values
        for i, vec in enumerate(basevec):
            test_shape = (2 if i < 2 else 3,) + tuple(arr_shape)
            assert vec.shape == test_shape
            assert np.all(self.test_basevec[i][0] == vec[0])
            assert np.all(self.test_basevec[i][1] == vec[1])
        return

    @pytest.mark.parametrize("coords", ["geo", "apex", "qd"])
    def test_bvectors_apex(self, coords):
        """Test the bvectors_apex method.

        Parameters
        ----------
        coords : str
            Name of the coordinate system

        """
        in_args = [[self.lat, self.lat], [self.lon, self.lon],
                   [self.height, self.height]]
        self.get_comparison_results("bvectors_apex", coords, 1e-10)

        basevec = self.apex_out.bvectors_apex(*in_args, coords=coords,
                                              precision=1e-10)
        for i, vec in enumerate(basevec):
            np.testing.assert_array_almost_equal(vec, self.test_basevec[i],
                                                 decimal=5)
        return

    def test_basevectors_apex_extra_values(self):
        """Test specific values in the apex base vector output."""
        # Set the testing arrays
        self.test_basevec = [np.array([0.092637, -0.245951, 0.938848]),
                             np.array([0.939012, 0.073416, -0.07342]),
                             np.array([0.055389, 1.004155, 0.257594]),
                             np.array([0, 0, 1.065135])]

        # Get the desired output
        basevec = self.apex_out.basevectors_apex(0, 15, 100, coords='geo')

        # Test the values not covered by `test_basevectors_scalar`
        for itest, ibase in enumerate(range(2, 6)):
            np.testing.assert_allclose(basevec[ibase],
                                       self.test_basevec[itest], rtol=1e-4)
        return

    @pytest.mark.parametrize("lat", range(0, 90, 10))
    @pytest.mark.parametrize("lon", range(0, 360, 15))
    def test_basevectors_apex_delta(self, lat, lon):
        """Test that vectors are calculated correctly.

        Checks that f/g and d/e vector pairs satisfy the Kronecker delta
        relation, i.e. their dot product is 1 for matching indices and 0
        otherwise.

        Parameters
        ----------
        lat : int or float
            Latitude in degrees N
        lon : int or float
            Longitude in degrees E

        """
        # Get the apex base vectors and sort them for easy testing
        (f1, f2, f3, g1, g2, g3, d1, d2, d3, e1, e2,
         e3) = self.apex_out.basevectors_apex(lat, lon, 500)
        # f1 and f2 are padded with a zero so all f vectors have 3 components
        fvec = [np.append(f1, 0), np.append(f2, 0), f3]
        gvec = [g1, g2, g3]
        dvec = [d1, d2, d3]
        evec = [e1, e2, e3]

        for idelta, jdelta in [(i, j) for i in range(3) for j in range(3)]:
            delta = 1 if idelta == jdelta else 0
            np.testing.assert_allclose(np.sum(fvec[idelta] * gvec[jdelta]),
                                       delta, rtol=0, atol=1e-5)
            np.testing.assert_allclose(np.sum(dvec[idelta] * evec[jdelta]),
                                       delta, rtol=0, atol=1e-5)
        return

    def test_basevectors_apex_invalid_scalar(self):
        """Test warning and fill values for base vectors with bad inputs."""
        self.apex_out = apexpy.Apex(date=2000, refh=10000)
        invalid = np.full(shape=(3,), fill_value=np.nan)

        # Get the output and the warnings
        with warnings.catch_warnings(record=True) as warn_rec:
            # Record every occurrence; without this, a warning already issued
            # from the same location can be suppressed by the active filters
            # and never reach `warn_rec`
            warnings.simplefilter("always")
            basevec = self.apex_out.basevectors_apex(0, 0, 0)

        for i, bvec in enumerate(basevec):
            if i < 2:
                assert not np.allclose(bvec, invalid[:2])
            else:
                np.testing.assert_allclose(bvec, invalid)

        assert len(warn_rec) > 0, "expected a UserWarning to be raised"
        assert issubclass(warn_rec[-1].category, UserWarning)
        assert 'set to NaN where' in str(warn_rec[-1].message)
        return
1598
1599
1600
class TestApexGetMethods(object):
1601
    """Test the Apex `get` methods."""
1602
    def setup_method(self):
        """Create the shared Apex object (epoch 2000, 300 km reference)."""
        self.apex_out = apexpy.Apex(refh=300, date=2000)
1605
1606
    def teardown_method(self):
        """Remove the Apex object built in `setup_method`."""
        delattr(self, 'apex_out')
1609
1610
    @pytest.mark.parametrize("alat, aheight",
                             [(10, 507.409702543805),
                              (60, 20313.026999999987),
                              ([10, 60],
                               [507.409702543805, 20313.026999999987]),
                              ([[10], [60]],
                               [[507.409702543805], [20313.026999999987]])])
    def test_get_apex(self, alat, aheight):
        """Test the apex height retrieval results.

        Parameters
        ----------
        alat : int, float, or array-like
            Apex latitude(s) in degrees N; scalars, lists, and nested lists
            are all exercised
        aheight : float or array-like
            Expected apex height(s) in km, shaped like `alat`

        """
        np.testing.assert_allclose(self.apex_out.get_apex(alat), aheight)
        return
1631
1632
    @pytest.mark.parametrize("glat,glon,height,test_bmag",
                             [([80], [100], [300], 5.100682377815247e-05),
                              ([80, 80], [100], [300],
                               [5.100682377815247e-05, 5.100682377815247e-05]),
                              ([[80], [80]], [100], [300],
                               [[5.100682377815247e-05],
                                [5.100682377815247e-05]]),
                              (range(50, 90, 8), range(0, 360, 80), [300] * 5,
                               np.array([4.18657154e-05, 5.11118114e-05,
                                         4.91969854e-05, 5.10519207e-05,
                                         4.90054816e-05])),
                              (90.0, 0, 1000, 3.7834718823432923e-05)])
    def test_get_babs(self, glat, glon, height, test_bmag):
        """Test the method to get the magnitude of the magnetic field.

        Parameters
        ----------
        glat : float or array-like
            Latitude(s) in degrees N
        glon : float or array-like
            Longitude(s) in degrees E
        height : float or array-like
            Height(s) in km
        test_bmag : float or array-like
            Expected B field magnitude(s), one per broadcast input point
            (values are ~4e-5, presumably in Tesla -- confirm)

        """
        bmag = self.apex_out.get_babs(glat, glon, height)

        # The expected magnitudes are only ~4-5e-5, so the previous absolute
        # tolerance of 1e-5 accepted errors of roughly 20-25%.  Compare
        # relatively so that real regressions are detected.
        np.testing.assert_allclose(bmag, test_bmag, rtol=1e-5, atol=0)
        return
1662
1663
    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_get_apex_with_invalid_lat(self, bad_lat):
        """Test `get_apex` raises ValueError for invalid latitudes.

        Parameters
        ----------
        bad_lat : int or float
            Bad input latitude in degrees N

        """
        with pytest.raises(ValueError) as verr:
            self.apex_out.get_apex(bad_lat)

        # A plain substring check; `str.find(...) > 0` would wrongly fail if
        # the expected text appeared at the very start of the message.
        assert "must be in [-90, 90]" in str(verr.value)
        return
1679
1680
    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_get_babs_with_invalid_lat(self, bad_lat):
        """Test `get_babs` raises ValueError for invalid latitudes.

        Parameters
        ----------
        bad_lat : int or float
            Bad input latitude in degrees N

        """
        with pytest.raises(ValueError) as verr:
            self.apex_out.get_babs(bad_lat, 15, 100)

        # A plain substring check; `str.find(...) > 0` would wrongly fail if
        # the expected text appeared at the very start of the message.
        assert "must be in [-90, 90]" in str(verr.value)
        return
1696
1697
    @pytest.mark.parametrize("bound_lat", [(90), (-90)])
    def test_get_at_lat_boundary(self, bound_lat):
        """Test `get_apex` at the latitude boundary, with allowed excess.

        A latitude 1e-5 degrees beyond the pole must be accepted and give the
        same apex height as the pole itself, to an absolute tolerance of 1e-8.

        Parameters
        ----------
        bound_lat : int or float
            Boundary input latitude in degrees N

        """
        # Step just past the boundary while keeping the hemisphere's sign
        beyond_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        at_boundary = self.apex_out.get_apex(bound_lat)
        past_boundary = self.apex_out.get_apex(beyond_lat)

        np.testing.assert_allclose(past_boundary, at_boundary,
                                   rtol=0, atol=1e-8)
        return
1717
1718
    @pytest.mark.parametrize("apex_height", [-100, 0, 300, 10000])
    def test_get_height_at_equator(self, apex_height):
        """Test that `get_height` at 0 degrees latitude equals the apex height.

        Parameters
        ----------
        apex_height : float
            Apex height; the comparison is exact equality

        """
        equator_height = self.apex_out.get_height(0.0, apex_height)
        assert equator_height == apex_height
        return
1731
1732
    @pytest.mark.parametrize("lat, height", [
        (-90, -6371.009), (-80, -6088.438503309167), (-70, -5274.8091854339655),
        (-60, -4028.256749999999), (-50, -2499.1338178752017),
        (-40, -871.8751821247979), (-30, 657.2477500000014),
        (-20, 1903.8001854339655), (-10, 2717.4295033091657), (0, 3000.0),
        (10, 2717.4295033091657), (20, 1903.8001854339655),
        (30, 657.2477500000014), (40, -871.8751821247979),
        (50, -2499.1338178752017), (60, -4028.256749999999),
        (70, -5274.8091854339655), (80, -6088.438503309167)])
    def test_get_height_along_fieldline(self, lat, height):
        """Test `get_height` along a field line whose apex is at 3000 km.

        Parameters
        ----------
        lat : float
            Input latitude
        height : float
            Expected field-line height at `lat`; must match to within 1e-7

        """
        line_height = self.apex_out.get_height(lat, 3000.0)
        mismatch = abs(height - line_height)
        assert mismatch < 1.0e-7, \
            "bad height calculation: {:.7f} != {:.7f}".format(height,
                                                             line_height)
        return
1757