Passed
Push — develop ( 27451b...8436e3 )
by Angeline
01:48 queued 12s
created

test_Apex.TestApexInit.test_init_today()   A

Complexity

Conditions 1

Size

Total Lines 6
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 5
nop 1
dl 0
loc 6
rs 10
c 0
b 0
f 0
1
# -*- coding: utf-8 -*-
2
"""Test the apexpy.Apex class
3
4
Notes
5
-----
6
Whenever function outputs are tested against hard-coded numbers, the test
7
results (numbers) were obtained by running the code that is tested.  Therefore,
8
these tests below only check that nothing changes when refactoring, etc., and
9
not if the results are actually correct.
10
11
These results are expected to change when IGRF is updated.
12
13
"""
14
15
import datetime as dt
16
import numpy as np
17
import os
18
import pytest
19
import shutil
20
import warnings
21
22
import apexpy
23
24
25
def _move_with_retries(src, dst, max_attempts):
    """Move `src` to `dst`, retrying a bounded number of times.

    Parameters
    ----------
    src : str
        Path of the file to move
    dst : str
        Destination path
    max_attempts : int
        Maximum number of move attempts

    Raises
    ------
    OSError
        If the file could not be moved within `max_attempts` attempts

    """
    last_err = None
    for _ in range(max_attempts):
        try:
            shutil.move(src, dst)
            return
        except OSError as err:
            # The move can fail transiently (e.g., on Windows); try again
            last_err = err

    raise OSError("unable to move {:} to {:} after {:} attempts".format(
        src, dst, max_attempts)) from last_err


@pytest.fixture()
def igrf_file(max_attempts=100):
    """A fixture for handling the coefficient file.

    Temporarily moves the IGRF coefficient file out of the installed apexpy
    package so tests can exercise the missing-file behaviour, then moves it
    back once the test is finished.

    Parameters
    ----------
    max_attempts : int
        Maximum rename attempts, needed for Windows (default=100)

    Yields
    ------
    original_file : str
        Path of the (temporarily missing) IGRF coefficient file

    Raises
    ------
    OSError
        If the coefficient file cannot be moved away or restored

    """
    # Ensure the coefficient file exists
    original_file = os.path.join(os.path.dirname(apexpy.helpers.__file__),
                                 'igrf13coeffs.txt')
    tmp_file = "temp_coeff.txt"
    assert os.path.isfile(original_file)

    # Move the coefficient file; fail loudly instead of silently running the
    # test with the file still in place
    _move_with_retries(original_file, tmp_file, max_attempts)

    yield original_file

    # Move the coefficient file back; a failure here would leave the package
    # without its coefficients, so it must be reported rather than ignored
    _move_with_retries(tmp_file, original_file, max_attempts)
    return
58
59
60
def test_set_epoch_file_error(igrf_file):
    """Test raises OSError when IGRF coefficient file is missing.

    Parameters
    ----------
    igrf_file : str
        Path to the coefficient file, moved away for the duration of the test

    """
    expected_start = "File {:} does not exist".format(igrf_file)

    # Creating an Apex object without the coefficient file must fail
    with pytest.raises(OSError) as raised:
        apexpy.Apex()

    assert str(raised.value).startswith(expected_start)
    return
68
69
70
class TestApexInit(object):
    """Test class for the Apex class object."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = None
        # Naive datetime holding the current UTC time; equivalent to the
        # `dt.datetime.utcnow()` call that is deprecated as of Python 3.12
        self.test_date = dt.datetime.now(dt.timezone.utc).replace(tzinfo=None)
        self.test_refh = 0
        self.bad_file = 'foo/path/to/datafile.blah'

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.test_date, self.test_refh, self.bad_file

    def eval_date(self):
        """Evaluate the times in self.test_date and self.apex_out.

        If `self.test_date` is a date or datetime (`dt.datetime` is a
        subclass of `dt.date`), it is first converted to a decimal year.

        """
        if isinstance(self.test_date, dt.date):
            self.test_date = apexpy.helpers.toYearFraction(self.test_date)

        # Assert the times are the same on the order of tens of seconds.
        # Necessary to evaluate the current UTC
        np.testing.assert_almost_equal(self.test_date, self.apex_out.year, 6)
        return

    def eval_refh(self):
        """Evaluate the reference height in self.test_refh and self.apex_out."""
        eval_str = ("expected reference height [{:}] not equal to Apex "
                    "reference height [{:}]").format(self.test_refh,
                                                     self.apex_out.refh)
        assert self.test_refh == self.apex_out.refh, eval_str
        return

    def test_init_defaults(self):
        """Test Apex class default initialization."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        return

    def test_init_today(self):
        """Test Apex class initialization with today's date."""
        self.apex_out = apexpy.Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("in_date",
                             [2015, 2015.5, dt.date(2015, 1, 1),
                              dt.datetime(2015, 6, 1, 18, 23, 45)])
    def test_init_date(self, in_date):
        """Test Apex class with date initialization.

        Parameters
        ----------
        in_date : int, float, dt.date, or dt.datetime
            Input date in a variety of formats

        """
        self.test_date = in_date
        self.apex_out = apexpy.Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("new_date", [2015, 2015.5])
    def test_set_epoch(self, new_date):
        """Test successful setting of Apex epoch after initialization.

        Parameters
        ----------
        new_date : int or float
            New date for the Apex class

        """
        # Evaluate the default initialization
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Update the epoch; the copy re-created from the repr keeps the old one
        ref_apex = eval(repr(self.apex_out))
        self.apex_out.set_epoch(new_date)
        assert ref_apex != self.apex_out
        self.test_date = new_date
        self.eval_date()
        return

    @pytest.mark.parametrize("in_refh", [0.0, 300.0, 30000.0, -1.0])
    def test_init_refh(self, in_refh):
        """Test Apex class with reference height initialization.

        Parameters
        ----------
        in_refh : float
            Input reference height in km

        """
        self.test_refh = in_refh
        self.apex_out = apexpy.Apex(refh=self.test_refh)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("new_refh", [0.0, 300.0, 30000.0, -1.0])
    def test_set_refh(self, new_refh):
        """Test the method used to set the reference height after the init.

        Parameters
        ----------
        new_refh : float
            Reference height in km

        """
        # Verify the defaults are set
        self.apex_out = apexpy.Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()

        # Update to a new reference height and test
        ref_apex = eval(repr(self.apex_out))
        self.apex_out.set_refh(new_refh)

        # Setting the same height must leave the object unchanged
        if self.test_refh == new_refh:
            assert ref_apex == self.apex_out
        else:
            assert ref_apex != self.apex_out
            self.test_refh = new_refh
        self.eval_refh()
        return

    def test_init_with_bad_datafile(self):
        """Test raises IOError with non-existent datafile input."""
        with pytest.raises(IOError) as oerr:
            apexpy.Apex(datafile=self.bad_file)
        assert str(oerr.value).startswith('Data file does not exist')
        return

    def test_init_with_bad_fortranlib(self):
        """Test raises IOError with non-existent fortranlib input."""
        with pytest.raises(IOError) as oerr:
            apexpy.Apex(fortranlib=self.bad_file)
        assert str(oerr.value).startswith('Fortran library does not exist')
        return

    def test_repr_eval(self):
        """Test the Apex.__repr__ results."""
        # Initialize the apex object
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Get and test the repr string
        out_str = repr(self.apex_out)
        assert out_str.find("apexpy.Apex(") == 0

        # Test the ability to re-create the apex object from the repr string
        new_apex = eval(out_str)
        assert new_apex == self.apex_out
        return

    def test_ne_other_class(self):
        """Test Apex class inequality to a different class."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        assert self.apex_out != self.test_date
        return

    def test_ne_missing_attr(self):
        """Test Apex class inequality when attributes are missing from one."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        ref_apex = eval(repr(self.apex_out))
        del ref_apex.RE

        # Inequality must hold in both comparison directions
        assert ref_apex != self.apex_out
        assert self.apex_out != ref_apex
        return

    def test_eq_missing_attr(self):
        """Test Apex class equality when attributes are missing from both."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        ref_apex = eval(repr(self.apex_out))
        del ref_apex.RE, self.apex_out.RE

        assert ref_apex == self.apex_out
        return

    def test_str_eval(self):
        """Test the Apex.__str__ results."""
        # Initialize the apex object
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Get and test the printed string
        out_str = str(self.apex_out)
        assert out_str.find("Decimal year") > 0
        return
275
276
277
class TestApexMethod(object):
    """Test the Apex methods."""

    # Private Apex methods, the equivalent Fortran functions, and the slice
    # selecting the comparable Fortran outputs (shared by several tests)
    _fortran_cases = [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                      ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                      ("_qd2geo", "apxq2g", slice(None)),
                      ("_basevec", "apxg2q", slice(2, 4, 1))]

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.in_lat = 60
        self.in_lon = 15
        self.in_alt = 100

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.in_lat, self.in_lon, self.in_alt

    def get_input_args(self, method_name, precision=0.0):
        """Set the input arguments for the different Apex methods.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method
        precision : float
            Value for the precision (default=0.0)

        Returns
        -------
        in_args : list
            List of the appropriate input arguments

        """
        in_args = [self.in_lat, self.in_lon, self.in_alt]

        # Add precision, if needed
        if method_name in ["_qd2geo", "apxq2g", "apex2geo", "qd2geo",
                           "_apex2geo"]:
            in_args.append(precision)

        # Add a reference height, if needed (matches `refh` in setup_method)
        if method_name in ["apxg2all"]:
            in_args.append(300)

        # Add a vector flag, if needed
        if method_name in ["apxg2all", "apxg2q"]:
            in_args.append(1)

        return in_args

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             _fortran_cases)
    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon", [-179, -90, 0, 90, 180])
    def test_fortran_scalar_input(self, apex_method, fortran_method, fslice,
                                  lat, lon):
        """Tests Apex/fortran interface consistency for scalars.

        Parameters
        ----------
        apex_method : str
            Name of the Apex class method to test
        fortran_method : str
            Name of the Fortran function to test
        fslice : slice
            Slice used to select the appropriate Fortran outputs
        lat : int or float
            Latitude in degrees N
        lon : int or float
            Longitude in degrees E

        """
        # Set the input coordinates
        self.in_lat = lat
        self.in_lon = lon

        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Get the appropriate input arguments
        apex_args = self.get_input_args(apex_method)
        fortran_args = self.get_input_args(fortran_method)

        # Evaluate the equivalent function calls
        np.testing.assert_allclose(apex_func(*apex_args),
                                   fortran_func(*fortran_args)[fslice])
        return

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             _fortran_cases)
    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon1,lon2", [(180, 180), (-180, -180),
                                           (180, -180), (-180, 180),
                                           (-345, 15), (375, 15)])
    def test_fortran_longitude_rollover(self, apex_method, fortran_method,
                                        fslice, lat, lon1, lon2):
        """Tests Apex/fortran interface consistency for longitude rollover.

        Parameters
        ----------
        apex_method : str
            Name of the Apex class method to test
        fortran_method : str
            Name of the Fortran function to test
        fslice : slice
            Slice used to select the appropriate Fortran outputs
        lat : int or float
            Latitude in degrees N
        lon1 : int or float
            Longitude in degrees E
        lon2 : int or float
            Equivalent longitude in degrees E

        """
        # Set the fixed input coordinate
        self.in_lat = lat

        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Get the appropriate input arguments, using equivalent longitudes
        self.in_lon = lon1
        apex_args = self.get_input_args(apex_method)

        self.in_lon = lon2
        fortran_args = self.get_input_args(fortran_method)

        # Evaluate the equivalent function calls
        np.testing.assert_allclose(apex_func(*apex_args),
                                   fortran_func(*fortran_args)[fslice])
        return

    @pytest.mark.parametrize("arr_shape", [(2, 2), (4,), (1, 4)])
    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             _fortran_cases)
    def test_fortran_array_input(self, arr_shape, apex_method, fortran_method,
                                 fslice):
        """Tests Apex/fortran interface consistency for array input.

        Parameters
        ----------
        arr_shape : tuple
            Expected output shape
        apex_method : str
            Name of the Apex class method to test
        fortran_method : str
            Name of the Fortran function to test
        fslice : slice
            Slice used to select the appropriate Fortran outputs

        """
        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Set up the input arrays
        ref_lat = np.array([0, 30, 60, 90])
        ref_alt = np.array([100, 200, 300, 400])
        self.in_lat = ref_lat.reshape(arr_shape)
        self.in_alt = ref_alt.reshape(arr_shape)
        apex_args = self.get_input_args(apex_method)

        # Get the Apex class results
        aret = apex_func(*apex_args)

        # Get the fortran function results, one lat/alt pair at a time
        flats = list()
        flons = list()

        for lat, alt in zip(ref_lat, ref_alt):
            self.in_lat = lat
            self.in_alt = alt
            fortran_args = self.get_input_args(fortran_method)
            fret = fortran_func(*fortran_args)[fslice]
            flats.append(fret[0])
            flons.append(fret[1])

        flats = np.array(flats)
        flons = np.array(flons)

        # Evaluate results
        try:
            # This returned value is array of floats
            np.testing.assert_allclose(aret[0].astype(float),
                                       flats.reshape(arr_shape).astype(float))
            np.testing.assert_allclose(aret[1].astype(float),
                                       flons.reshape(arr_shape).astype(float))
        except ValueError:
            # This returned value is array of arrays
            alats = aret[0].reshape((4,))
            alons = aret[1].reshape((4,))
            for alat, alon, flat, flon in zip(alats, alons, flats, flons):
                np.testing.assert_array_almost_equal(alat, flat, 2)
                np.testing.assert_array_almost_equal(alon, flon, 2)

        return

    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon", [-179, -90, 0, 90, 180])
    def test_geo2apexall_scalar(self, lat, lon):
        """Test Apex/fortran geo2apexall interface consistency for scalars.

        Parameters
        ----------
        lat : int or float
            Latitude in degrees N
        lon : int or float
            Longitude in degrees E

        """
        # Get the Apex and Fortran results
        aret = self.apex_out._geo2apexall(lat, lon, self.in_alt)
        fret = apexpy.fortranapex.apxg2all(lat, lon, self.in_alt, 300, 1)

        # Evaluate each element in the results
        for aval, fval in zip(aret, fret):
            np.testing.assert_allclose(aval, fval)
        return

    @pytest.mark.parametrize("arr_shape", [(2, 2), (4,), (1, 4)])
    def test_geo2apexall_array(self, arr_shape):
        """Test Apex/fortran geo2apexall interface consistency for arrays.

        Parameters
        ----------
        arr_shape : tuple
            Expected output shape

        """
        # Set the input
        self.in_lat = np.array([0, 30, 60, 90])
        self.in_alt = np.array([100, 200, 300, 400])

        # Get the Apex class results
        aret = self.apex_out._geo2apexall(self.in_lat.reshape(arr_shape),
                                          self.in_lon,
                                          self.in_alt.reshape(arr_shape))

        # For each lat/alt pair, get the Fortran results
        fret = list()
        for lat, alt in zip(self.in_lat, self.in_alt):
            fret.append(apexpy.fortranapex.apxg2all(lat, self.in_lon, alt,
                                                    300, 1))

        # Cycle through all returned values
        for i, ret in enumerate(aret):
            try:
                # This returned value is array of floats
                fret_test = np.array([fret[0][i], fret[1][i], fret[2][i],
                                      fret[3][i]]).reshape(arr_shape)
                np.testing.assert_allclose(ret.astype(float),
                                           fret_test.astype(float))
            except ValueError:
                # This returned value is array of arrays
                ret = ret.reshape((4,))
                for j, single_fret in enumerate(fret):
                    np.testing.assert_allclose(ret[j], single_fret[i])
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_consistency(self, in_coord, out_coord):
        """Test the self-consistency of the Apex convert method.

        Parameters
        ----------
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Define the method name
        method_name = "2".join([in_coord, out_coord])

        # Get the method and method inputs
        convert_kwargs = {'height': self.in_alt, 'precision': 0.0}
        apex_args = self.get_input_args(method_name)
        apex_method = getattr(self.apex_out, method_name)

        # Define the slice needed to get equivalent output from the named
        # method (the geo output of the named method has an extra value)
        mslice = slice(0, -1, 1) if out_coord == "geo" else slice(None)

        # Get output using convert and named method
        convert_out = self.apex_out.convert(self.in_lat, self.in_lon, in_coord,
                                            out_coord, **convert_kwargs)
        method_out = apex_method(*apex_args)[mslice]

        # Compare both outputs, should be identical
        np.testing.assert_allclose(convert_out, method_out)
        return

    @pytest.mark.parametrize("bound_lat", [90, -90])
    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_at_lat_boundary(self, bound_lat, in_coord, out_coord):
        """Test the conversion at the latitude boundary, with allowed excess.

        Parameters
        ----------
        bound_lat : int or float
            Boundary latitude in degrees N
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system

        """
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = self.apex_out.convert(bound_lat, 0, in_coord, out_coord)
        excess_out = self.apex_out.convert(excess_lat, 0, in_coord, out_coord)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    def test_convert_qd2apex_at_equator(self):
        """Test the quasi-dipole to apex conversion at the magnetic equator."""
        eq_out = self.apex_out.convert(lat=0.0, lon=0, source='qd', dest='apex',
                                       height=320.0)
        close_out = self.apex_out.convert(lat=0.001, lon=0, source='qd',
                                          dest='apex', height=320.0)
        np.testing.assert_allclose(eq_out, close_out, atol=1e-4)
        return

    @pytest.mark.parametrize("src", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("dest", ["geo", "apex", "qd"])
    def test_convert_withnan(self, src, dest):
        """Test Apex.convert success with NaN input.

        Parameters
        ----------
        src : str
            Input coordinate system
        dest : str
            Output coordinate system

        """
        if src == dest:
            pytest.skip("Test not needed for same src and dest coordinates")

        num_nans = 5
        in_loc = np.arange(0, 10, dtype=float)
        in_loc[:num_nans] = np.nan

        out_loc = self.apex_out.convert(in_loc, in_loc, src, dest, height=320)

        # NaN inputs must give NaN outputs without affecting the other values
        for out in out_loc:
            assert np.all(np.isnan(out[:num_nans])), "NaN output expected"
            assert np.all(np.isfinite(out[num_nans:])), "Finite output expected"

        return

    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_convert_invalid_lat(self, bad_lat):
        """Test convert raises ValueError for invalid latitudes.

        Parameters
        ----------
        bad_lat : int or float
            Latitude outside the supported range in degrees N

        """

        with pytest.raises(ValueError) as verr:
            self.apex_out.convert(bad_lat, 0, 'geo', 'geo')

        assert "must be in [-90, 90]" in str(verr.value)
        return

    @pytest.mark.parametrize("coords", [("foobar", "geo"), ("geo", "foobar"),
                                        ("geo", "mlt")])
    def test_convert_invalid_transformation(self, coords):
        """Test raises ValueError for bad coordinates.

        Parameters
        ----------
        coords : tuple
            Tuple specifying the input and output coordinate systems

        """
        if "mlt" in coords:
            estr = "datetime must be given for MLT calculations"
        else:
            estr = "Unknown coordinate transformation"

        with pytest.raises(ValueError) as verr:
            self.apex_out.convert(0, 0, *coords)

        # Check the raised exception itself, not the ExceptionInfo wrapper
        assert estr in str(verr.value)
        return

    @pytest.mark.parametrize("method_name, out_comp",
                             [("geo2apex",
                               (55.94841766357422, 94.10684204101562)),
                              ("apex2geo",
                               (51.476322174072266, -66.22817993164062,
                                5.727287771151168e-06)),
                              ("geo2qd",
                               (56.531288146972656, 94.10684204101562)),
                              ("apex2qd", (60.498401178276744, 15.0)),
                              ("qd2apex", (59.49138097045895, 15.0))])
    def test_method_scalar_input(self, method_name, out_comp):
        """Test the user method against set values with scalars.

        Parameters
        ----------
        method_name : str
            Apex class method to be tested
        out_comp : tuple of floats
            Expected output values

        """
        # Get the desired methods
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(self.in_lat, self.in_lon, self.in_alt)

        # Evaluate the user output
        np.testing.assert_allclose(user_out, out_comp, rtol=1e-5, atol=1e-5)

        for out_val in user_out:
            assert np.asarray(out_val).shape == (), "output is not a scalar"
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("method_args, out_shape",
                             [([[60, 60], 15, 100], (2,)),
                              ([60, [15, 15], 100], (2,)),
                              ([60, 15, [100, 100]], (2,)),
                              ([[50, 60], [15, 16], [100, 200]], (2,))])
    def test_method_broadcast_input(self, in_coord, out_coord, method_args,
                                    out_shape):
        """Test the user method with inputs that require some broadcasting.

        Parameters
        ----------
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system
        method_args : list
            List of input arguments
        out_shape : tuple
            Expected shape of output values

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(*method_args)

        # Evaluate the user output
        for out_val in user_out:
            assert hasattr(out_val, 'shape'), "output coordinate isn't np.array"
            assert out_val.shape == out_shape
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_method_invalid_lat(self, in_coord, out_coord, bad_lat):
        """Test user methods raise ValueError for invalid latitudes.

        Parameters
        ----------
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system
        bad_lat : int
            Latitude in degrees N that is out of bounds

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        with pytest.raises(ValueError) as verr:
            user_method(bad_lat, 15, 100)

        assert "must be in [-90, 90]" in str(verr.value)
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("bound_lat", [90, -90])
    def test_method_at_lat_boundary(self, in_coord, out_coord, bound_lat):
        """Test user methods at the latitude boundary, with allowed excess.

        Parameters
        ----------
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system
        bound_lat : int
            Latitude in degrees N that is at the limits of the boundary

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        # Get a latitude just beyond the limit
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = user_method(bound_lat, 0, 100)
        excess_out = user_method(excess_lat, 0, 100)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    def test_geo2apex_undefined_warning(self):
        """Test geo2apex warning and fill values for an undefined location."""

        # Update the apex object
        self.apex_out = apexpy.Apex(date=2000, refh=10000)

        # Get the output and the warnings
        with warnings.catch_warnings(record=True) as warn_rec:
            user_lat, user_lon = self.apex_out.geo2apex(0, 0, 0)

        assert np.isnan(user_lat)
        assert np.isfinite(user_lon)
        assert len(warn_rec) == 1
        assert issubclass(warn_rec[-1].category, UserWarning)
        assert 'latitude set to NaN where' in str(warn_rec[-1].message)
        return

    @pytest.mark.parametrize("method_name", ["apex2qd", "qd2apex"])
    @pytest.mark.parametrize("delta_h", [1.0e-6, -1.0e-6])
    def test_quasidipole_apexheight_close(self, method_name, delta_h):
        """Test quasi-dipole success with a height close to the reference.

        Parameters
        ----------
        method_name : str
            Apex class method name to be tested
        delta_h : float
            tolerance for height in km

        """
        qd_method = getattr(self.apex_out, method_name)
        in_args = [0, 15, self.apex_out.refh + delta_h]
        out_coords = qd_method(*in_args)

        # Only the lat/lon outputs are compared to the matching inputs
        for out_val, in_val in zip(out_coords, in_args):
            np.testing.assert_almost_equal(out_val, in_val, decimal=3)
        return

    @pytest.mark.parametrize("method_name, hinc, msg",
                             [("apex2qd", 1.0, "is > apex height"),
                              ("qd2apex", -1.0, "is < reference height")])
    def test_quasidipole_raises_apexheight(self, method_name, hinc, msg):
        """Quasi-dipole raises ApexHeightError for heights out of range.

        Parameters
        ----------
        method_name : str
            Apex class method name to be tested
        hinc : float
            Height increment in km
        msg : str
            Expected output message

        """
        qd_method = getattr(self.apex_out, method_name)

        with pytest.raises(apexpy.ApexHeightError) as aerr:
            qd_method(0, 15, self.apex_out.refh + hinc)

        # Check the raised exception itself, not the ExceptionInfo wrapper
        assert msg in str(aerr.value)
        return
877
878
879
class TestApexMLTMethods(object):
    """Test the Apex Magnetic Local Time (MLT) methods."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.in_time = dt.datetime(2000, 2, 3, 4, 5, 6)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.in_time

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    def test_convert_to_mlt(self, in_coord):
        """Test the conversions to MLT using Apex convert.

        Parameters
        ----------
        in_coord : str
            Input coordinate system

        """
        # Get the magnetic longitude; only geographic input needs converting,
        # apex and quasi-dipole inputs use the input longitude of 15 directly
        if in_coord == "geo":
            mlon = self.apex_out.geo2apex(60, 15, 100)[1]
        else:
            mlon = 15

        # Get the output MLT values
        convert_mlt = self.apex_out.convert(60, 15, in_coord, 'mlt',
                                            height=100, ssheight=2e5,
                                            datetime=self.in_time)[1]
        method_mlt = self.apex_out.mlon2mlt(mlon, self.in_time, ssheight=2e5)

        # Test the outputs
        np.testing.assert_allclose(convert_mlt, method_mlt)
        return

    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_mlt_to_lon(self, out_coord):
        """Test the conversions from MLT using Apex convert.

        Parameters
        ----------
        out_coord : str
            Output coordinate system

        """
        # Get the output longitudes
        convert_out = self.apex_out.convert(60, 15, 'mlt', out_coord,
                                            height=100, ssheight=2e5,
                                            datetime=self.in_time,
                                            precision=1e-2)
        mlon = self.apex_out.mlt2mlon(15, self.in_time, ssheight=2e5)

        if out_coord == "geo":
            # Drop the trailing output so only latitude/longitude are compared
            method_out = self.apex_out.apex2geo(60, mlon, 100,
                                                precision=1e-2)[:-1]
        elif out_coord == "qd":
            method_out = self.apex_out.apex2qd(60, mlon, 100)
        else:
            method_out = (60, mlon)

        # Evaluate the outputs
        np.testing.assert_allclose(convert_out, method_out)
        return

    def test_convert_geo2mlt_nodate(self):
        """Test convert from geo to MLT raises ValueError with no datetime."""
        with pytest.raises(ValueError):
            self.apex_out.convert(60, 15, 'geo', 'mlt')
        return

    @pytest.mark.parametrize("mlon_kwargs,test_mlt",
                             [({}, 23.019629923502603),
                              ({"ssheight": 100000}, 23.026712036132814)])
    def test_mlon2mlt_scalar_inputs(self, mlon_kwargs, test_mlt):
        """Test mlon2mlt with scalar inputs.

        Parameters
        ----------
        mlon_kwargs : dict
            Input kwargs
        test_mlt : float
            Output MLT in hours

        """
        mlt = self.apex_out.mlon2mlt(0, self.in_time, **mlon_kwargs)

        np.testing.assert_allclose(mlt, test_mlt)
        assert np.asarray(mlt).shape == ()
        return

    @pytest.mark.parametrize("mlt_kwargs,test_mlon",
                             [({}, 14.705535888671875),
                              ({"ssheight": 100000}, 14.599319458007812)])
    def test_mlt2mlon_scalar_inputs(self, mlt_kwargs, test_mlon):
        """Test mlt2mlon with scalar inputs.

        Parameters
        ----------
        mlt_kwargs : dict
            Input kwargs
        test_mlon : float
            Output longitude in degrees E

        """
        mlon = self.apex_out.mlt2mlon(0, self.in_time, **mlt_kwargs)

        np.testing.assert_allclose(mlon, test_mlon)
        assert np.asarray(mlon).shape == ()
        return

    @pytest.mark.parametrize("mlon,test_mlt",
                             [([0, 180], [23.019261, 11.019261]),
                              (np.array([0, 180]), [23.019261, 11.019261]),
                              (np.array([[0], [180]]),
                               np.array([[23.019261], [11.019261]])),
                              ([[0, 180], [0, 180]], [[23.019261, 11.019261],
                                                      [23.019261, 11.019261]]),
                              (range(0, 361, 30),
                               [23.01963, 1.01963, 3.01963, 5.01963, 7.01963,
                                9.01963, 11.01963, 13.01963, 15.01963, 17.01963,
                                19.01963, 21.01963, 23.01963])])
    def test_mlon2mlt_array(self, mlon, test_mlt):
        """Test mlon2mlt with array inputs.

        Parameters
        ----------
        mlon : array-like
            Input longitudes in degrees E
        test_mlt : array-like
            Expected output MLT in hours, with the same shape as `mlon`

        """
        mlt = self.apex_out.mlon2mlt(mlon, self.in_time)

        assert mlt.shape == np.asarray(test_mlt).shape
        np.testing.assert_allclose(mlt, test_mlt, rtol=1e-4)
        return

    @pytest.mark.parametrize("mlt,test_mlon",
                             [([0, 12], [14.705551, 194.705551]),
                              (np.array([0, 12]), [14.705551, 194.705551]),
                              (np.array([[0], [12]]),
                               np.array([[14.705551], [194.705551]])),
                              ([[0, 12], [0, 12]], [[14.705551, 194.705551],
                                                    [14.705551, 194.705551]]),
                              (range(0, 25, 2),
                               [14.705551, 44.705551, 74.705551, 104.705551,
                                134.705551, 164.705551, 194.705551, 224.705551,
                                254.705551, 284.705551, 314.705551, 344.705551,
                                14.705551])])
    def test_mlt2mlon_array(self, mlt, test_mlon):
        """Test mlt2mlon with array inputs.

        Parameters
        ----------
        mlt : array-like
            Input MLT in hours
        test_mlon : array-like
            Expected output longitude in degrees E, with the same shape as
            `mlt`

        """
        mlon = self.apex_out.mlt2mlon(mlt, self.in_time)

        assert mlon.shape == np.asarray(test_mlon).shape
        np.testing.assert_allclose(mlon, test_mlon, rtol=1e-4)
        return

    @pytest.mark.parametrize("method_name", ["mlon2mlt", "mlt2mlon"])
    def test_mlon2mlt_diffdates(self, method_name):
        """Test that the MLT/magnetic longitude conversions vary with UT.

        Parameters
        ----------
        method_name : str
            Name of Apex class method to be tested

        """
        apex_method = getattr(self.apex_out, method_name)
        out1 = apex_method(0, self.in_time)
        out2 = apex_method(0, self.in_time + dt.timedelta(hours=1))

        assert out1 != out2
        return

    @pytest.mark.parametrize("mlt_offset", [1.0, 10.0])
    def test_mlon2mlt_offset(self, mlt_offset):
        """Test the time wrapping logic for the MLT.

        Parameters
        ----------
        mlt_offset : float
            MLT offset in hours

        """
        mlt1 = self.apex_out.mlon2mlt(0.0, self.in_time)
        # Shift the longitude west by 15 degrees per hour of offset
        mlt2 = self.apex_out.mlon2mlt(-15.0 * mlt_offset,
                                      self.in_time) + mlt_offset

        np.testing.assert_allclose(mlt1, mlt2)
        return

    @pytest.mark.parametrize("mlon_offset", [15.0, 150.0])
    def test_mlt2mlon_offset(self, mlon_offset):
        """Test the time wrapping logic for the magnetic longitude.

        Parameters
        ----------
        mlon_offset : float
            Magnetic longitude offset in degrees, converted to an MLT offset
            using 15 degrees per hour

        """
        mlon1 = self.apex_out.mlt2mlon(0, self.in_time)
        mlon2 = self.apex_out.mlt2mlon(mlon_offset / 15.0,
                                       self.in_time) - mlon_offset

        np.testing.assert_allclose(mlon1, mlon2)
        return

    @pytest.mark.parametrize("order", [["mlt", "mlon"], ["mlon", "mlt"]])
    @pytest.mark.parametrize("start_val", [0, 6, 12, 18, 22])
    def test_convert_and_return(self, order, start_val):
        """Test the conversion to magnetic longitude or MLT and back again.

        Parameters
        ----------
        order : list
            List of strings specifying the order to run functions
        start_val : int or float
            Input value

        """
        # e.g. ["mlt", "mlon"] -> mlt2mlon followed by mlon2mlt
        first_method = getattr(self.apex_out, "2".join(order))
        second_method = getattr(self.apex_out, "2".join(order[::-1]))

        middle_val = first_method(start_val, self.in_time)
        end_val = second_method(middle_val, self.in_time)

        np.testing.assert_allclose(start_val, end_val)
        return
1122
1123
1124
class TestApexMapMethods(object):
    """Test the Apex height mapping methods."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 10000],
                               [31.841466903686523, 17.916635513305664,
                                1.7075473124350538e-6]),
                              ([30, 170, 100, 500, False, 1e-2],
                               [25.727270126342773, 169.60546875,
                                0.00017573432705830783]),
                              ([60, 15, 100, 10000, True],
                               [-25.424888610839844, 27.310426712036133,
                                1.2074182222931995e-6]),
                              ([30, 170, 100, 500, True, 1e-2],
                               [-13.76642894744873, 164.24259948730469,
                                0.00056820799363777041])])
    def test_map_to_height(self, in_args, test_mapped):
        """Test the map_to_height function.

        Parameters
        ----------
        in_args : list
            List of input arguments
        test_mapped : list
            List of expected outputs

        """
        mapped = self.apex_out.map_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5, atol=1e-5)
        return

    def test_map_to_height_same_height(self):
        """Test the map_to_height function when mapping to same height."""
        mapped = self.apex_out.map_to_height(60, 15, 100, 100, conjugate=False,
                                             precision=1e-10)
        np.testing.assert_allclose(mapped, (60.0, 15.000003814697266, 0.0),
                                   rtol=1e-5, atol=1e-5)
        return

    @pytest.mark.parametrize('arr_shape', [(2,), (2, 2), (1, 4)])
    @pytest.mark.parametrize('ivec', range(0, 4))
    def test_map_to_height_array_location(self, arr_shape, ivec):
        """Test map_to_height with array input.

        Parameters
        ----------
        arr_shape : tuple
            Expected array shape
        ivec : int
            Input argument index for vectorized input

        """
        # Set the base input and output values
        in_args = [60, 15, 100, 100]
        test_mapped = [60, 15.00000381, 0.0]

        # Update inputs for one vectorized value
        in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Calculate and test function
        mapped = self.apex_out.map_to_height(*in_args)
        for i, test_val in enumerate(test_mapped):
            assert mapped[i].shape == arr_shape
            np.testing.assert_allclose(mapped[i], test_val, rtol=1e-5,
                                       atol=1e-5)
        return

    @pytest.mark.parametrize("method_name,in_args",
                             [("map_to_height", [0, 15, 100, 10000]),
                              ("map_E_to_height",
                               [0, 15, 100, 10000, [1, 2, 3]]),
                              ("map_V_to_height",
                               [0, 15, 100, 10000, [1, 2, 3]])])
    def test_mapping_height_raises_ApexHeightError(self, method_name, in_args):
        """Test the height mapping methods raise ApexHeightError.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method to test
        in_args : list
            List of input arguments

        """
        apex_method = getattr(self.apex_out, method_name)

        with pytest.raises(apexpy.ApexHeightError) as aerr:
            apex_method(*in_args)

        assert aerr.match("is > apex height")
        return

    @pytest.mark.parametrize("method_name",
                             ["map_E_to_height", "map_V_to_height"])
    @pytest.mark.parametrize("ev_input", [[1, 2, 3, 4, 5],
                                          [[1, 2], [3, 4], [5, 6], [7, 8]]])
    def test_mapping_EV_bad_shape(self, method_name, ev_input):
        """Test height mapping of E/V with badly shaped input raises Error.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method to test
        ev_input : list
            E/V input arguments

        """
        apex_method = getattr(self.apex_out, method_name)
        in_args = [60, 15, 100, 500, ev_input]
        with pytest.raises(ValueError) as verr:
            apex_method(*in_args)

        assert "must be (3, N) or (3,) ndarray" in str(verr.value)
        return

    def test_mapping_EV_bad_flag(self):
        """Test _map_EV_to_height raises error for bad data type flag."""
        with pytest.raises(ValueError) as verr:
            self.apex_out._map_EV_to_height(60, 15, 100, 500, [1, 2, 3], "P")

        assert "unknown electric field/drift flag" in str(verr.value)
        return

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 500, [1, 2, 3]],
                               [0.71152183, 2.35624876, 0.57260784]),
                              ([60, 15, 100, 500, [2, 3, 4]],
                               [1.56028502, 3.43916636, 0.78235384]),
                              ([60, 15, 100, 1000, [1, 2, 3]],
                               [0.67796492, 2.08982134, 0.55860785]),
                              ([60, 15, 200, 500, [1, 2, 3]],
                               [0.72377397, 2.42737471, 0.59083726]),
                              ([60, 30, 100, 500, [1, 2, 3]],
                               [0.68626344, 2.37530133, 0.60060124]),
                              ([70, 15, 100, 500, [1, 2, 3]],
                               [0.72760378, 2.18082305, 0.29141979])])
    def test_map_E_to_height_scalar_location(self, in_args, test_mapped):
        """Test mapping of E-field to a specified height.

        Parameters
        ----------
        in_args : list
            List of input arguments
        test_mapped : list
            List of expected outputs

        """
        mapped = self.apex_out.map_E_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return

    @pytest.mark.parametrize('ev_flag, test_mapped',
                             [('E', [0.71152183, 2.35624876, 0.57260784]),
                              ('V', [0.81971957, 2.84512495, 0.69545001])])
    @pytest.mark.parametrize('arr_shape', [(2,), (5,)])
    @pytest.mark.parametrize('ivec', range(0, 5))
    def test_map_EV_to_height_array_location(self, ev_flag, test_mapped,
                                             arr_shape, ivec):
        """Test mapping of E-field/drift to a specified height with arrays.

        Parameters
        ----------
        ev_flag : str
            Character flag specifying whether to run 'E' or 'V' methods
        test_mapped : list
            List of expected outputs
        arr_shape : tuple
            Shape of the expected output
        ivec : int
            Index of the location input argument (0-3) to vectorize; 4 keeps
            all location inputs scalar so only the E/V input is an array

        """
        # Build an E/V input of shape (3, *arr_shape) where every column is
        # the vector [1, 2, 3].  `np.prod` replaces `np.product`, which was
        # deprecated and removed in NumPy 2.0.
        eshape = (3,) + tuple(arr_shape)
        edata = np.array([[1, 2, 3]] * int(np.prod(arr_shape))).transpose()
        in_args = [60, 15, 100, 500, edata.reshape(eshape)]

        # Update inputs for one vectorized value if this is a location input
        if ivec < 4:
            in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Get the mapped output
        apex_method = getattr(self.apex_out,
                              "map_{:s}_to_height".format(ev_flag))
        mapped = apex_method(*in_args)

        # Test the results
        for i, test_val in enumerate(test_mapped):
            assert mapped[i].shape == arr_shape
            np.testing.assert_allclose(mapped[i], test_val, rtol=1e-5)
        return

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 500, [1, 2, 3]],
                               [0.81971957, 2.84512495, 0.69545001]),
                              ([60, 15, 100, 500, [2, 3, 4]],
                               [1.83027746, 4.14346436, 0.94764179]),
                              ([60, 15, 100, 1000, [1, 2, 3]],
                               [0.92457698, 3.14997661, 0.85135187]),
                              ([60, 15, 200, 500, [1, 2, 3]],
                               [0.80388262, 2.79321504, 0.68285158]),
                              ([60, 30, 100, 500, [1, 2, 3]],
                               [0.76141245, 2.87884673, 0.73655941]),
                              ([70, 15, 100, 500, [1, 2, 3]],
                               [0.84681866, 2.5925821, 0.34792655])])
    def test_map_V_to_height_scalar_location(self, in_args, test_mapped):
        """Test mapping of velocity to a specified height.

        Parameters
        ----------
        in_args : list
            List of input arguments
        test_mapped : list
            List of expected outputs

        """
        mapped = self.apex_out.map_V_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return
1351
1352
1353
class TestApexBasevectorMethods(object):
    """Test the Apex height base vector methods."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.lat = 60
        self.lon = 15
        self.height = 100
        self.test_basevec = None

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.test_basevec, self.lat, self.lon, self.height

    def get_comparison_results(self, bv_coord, coords, precision):
        """Get the base vector results using the hidden function for comparison.

        Parameters
        ----------
        bv_coord : str
            Basevector coordinate scheme, expects one of 'apex', 'qd',
            or 'bvectors_apex'
        coords : str
            Expects one of 'geo', 'apex', or 'qd'
        precision : float
            Float specifying precision

        Notes
        -----
        The results are stored in `self.test_basevec`.  The 'bvectors_apex'
        values are hard-coded and must be updated with IGRF.

        """
        if coords == "geo":
            glat = self.lat
            glon = self.lon
        else:
            apex_method = getattr(self.apex_out, "{:s}2geo".format(coords))
            glat, glon, _ = apex_method(self.lat, self.lon, self.height,
                                        precision=precision)

        if bv_coord == 'qd':
            self.test_basevec = self.apex_out._basevec(glat, glon, self.height)
        elif bv_coord == 'apex':
            # Use the same height as the public methods being compared
            (_, _, _, _, f1, f2, _, d1, d2, d3, _, e1, e2,
             e3) = self.apex_out._geo2apexall(glat, glon, self.height)
            self.test_basevec = (f1, f2, d1, d2, d3, e1, e2, e3)
        else:
            # These are set results that need to be updated with IGRF
            if coords == "geo":
                self.test_basevec = (
                    np.array([4.42368795e-05, 4.42368795e-05]),
                    np.array([[0.01047826, 0.01047826],
                              [0.33089194, 0.33089194],
                              [-1.04941, -1.04941]]),
                    np.array([5.3564698e-05, 5.3564698e-05]),
                    np.array([[0.00865356, 0.00865356],
                              [0.27327004, 0.27327004],
                              [-0.8666646, -0.8666646]]))
            elif coords == "apex":
                self.test_basevec = (
                    np.array([4.48672735e-05, 4.48672735e-05]),
                    np.array([[-0.12510721, -0.12510721],
                              [0.28945938, 0.28945938],
                              [-1.1505738, -1.1505738]]),
                    np.array([6.38577444e-05, 6.38577444e-05]),
                    np.array([[-0.08790194, -0.08790194],
                              [0.2033779, 0.2033779],
                              [-0.808408, -0.808408]]))
            else:
                self.test_basevec = (
                    np.array([4.46348578e-05, 4.46348578e-05]),
                    np.array([[-0.12642345, -0.12642345],
                              [0.29695055, 0.29695055],
                              [-1.1517885, -1.1517885]]),
                    np.array([6.38626285e-05, 6.38626285e-05]),
                    np.array([[-0.08835986, -0.08835986],
                              [0.20754464, 0.20754464],
                              [-0.8050078, -0.8050078]]))

        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    @pytest.mark.parametrize("coords,precision",
                             [("geo", 1e-10), ("apex", 1.0e-2), ("qd", 1.0e-2)])
    def test_basevectors_scalar(self, bv_coord, coords, precision):
        """Test the base vector calculations with scalars.

        Parameters
        ----------
        bv_coord : str
            Name of the input coordinate system
        coords : str
            Name of the output coordinate system
        precision : float
            Level of run precision requested

        """
        # Get the base vectors
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(self.lat, self.lon, self.height, coords=coords,
                              precision=precision)
        self.get_comparison_results(bv_coord, coords, precision)
        if bv_coord == "apex":
            # Drop f3, g1, g2, and g3 (indices 2-5), which have no
            # counterpart in the comparison values
            basevec = list(basevec)
            del basevec[2:6]

        # Test the results
        for i, vec in enumerate(basevec):
            np.testing.assert_allclose(vec, self.test_basevec[i])
        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    def test_basevectors_scalar_shape(self, bv_coord):
        """Test the shape of the scalar output.

        Parameters
        ----------
        bv_coord : str
            Name of the input coordinate system

        """
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(self.lat, self.lon, self.height)

        for i, vec in enumerate(basevec):
            assert vec.shape == ((2,) if i < 2 else (3,))
        return

    @pytest.mark.parametrize('arr_shape', [(2,), (5,)])
    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    @pytest.mark.parametrize("ivec", range(3))
    def test_basevectors_array(self, arr_shape, bv_coord, ivec):
        """Test the output shape for array inputs.

        Parameters
        ----------
        arr_shape : tuple
            Expected output shape
        bv_coord : str
            Name of the input coordinate system
        ivec : int
            Index of the input argument (lat, lon, or height) to vectorize

        """
        # Define the input arguments
        in_args = [self.lat, self.lon, self.height]
        in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Get the basevectors
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(*in_args, coords='geo', precision=1e-10)
        self.get_comparison_results(bv_coord, "geo", 1e-10)
        if bv_coord == "apex":
            # Drop f3, g1, g2, and g3 (indices 2-5), which have no
            # counterpart in the comparison values
            basevec = list(basevec)
            del basevec[2:6]

        # Evaluate the shape and the values
        for i, vec in enumerate(basevec):
            test_shape = (2 if i < 2 else 3,) + tuple(arr_shape)
            assert vec.shape == test_shape
            assert np.all(self.test_basevec[i][0] == vec[0])
            assert np.all(self.test_basevec[i][1] == vec[1])
        return

    @pytest.mark.parametrize("coords", ["geo", "apex", "qd"])
    def test_bvectors_apex(self, coords):
        """Test the bvectors_apex method.

        Parameters
        ----------
        coords : str
            Name of the coordinate system

        """
        in_args = [[self.lat, self.lat], [self.lon, self.lon],
                   [self.height, self.height]]
        self.get_comparison_results("bvectors_apex", coords, 1e-10)

        basevec = self.apex_out.bvectors_apex(*in_args, coords=coords,
                                              precision=1e-10)
        for i, vec in enumerate(basevec):
            np.testing.assert_array_almost_equal(vec, self.test_basevec[i],
                                                 decimal=5)
        return

    def test_basevectors_apex_extra_values(self):
        """Test specific values in the apex base vector output."""
        # Set the testing arrays
        self.test_basevec = [np.array([0.092637, -0.245951, 0.938848]),
                             np.array([0.939012, 0.073416, -0.07342]),
                             np.array([0.055389, 1.004155, 0.257594]),
                             np.array([0, 0, 1.065135])]

        # Get the desired output
        basevec = self.apex_out.basevectors_apex(0, 15, 100, coords='geo')

        # Test the values not covered by `test_basevectors_scalar`
        for itest, ibase in enumerate(range(2, 6)):
            np.testing.assert_allclose(basevec[ibase],
                                       self.test_basevec[itest], rtol=1e-4)
        return

    @pytest.mark.parametrize("lat", range(0, 90, 10))
    @pytest.mark.parametrize("lon", range(0, 360, 15))
    def test_basevectors_apex_delta(self, lat, lon):
        """Test that vectors are calculated correctly.

        Parameters
        ----------
        lat : int or float
            Latitude in degrees N
        lon : int or float
            Longitude in degrees E

        """
        # Get the apex base vectors and sort them for easy testing
        (f1, f2, f3, g1, g2, g3, d1, d2, d3, e1, e2,
         e3) = self.apex_out.basevectors_apex(lat, lon, 500)
        # f1 and f2 are 2-element vectors; pad them to 3 elements
        fvec = [np.append(f1, 0), np.append(f2, 0), f3]
        gvec = [g1, g2, g3]
        dvec = [d1, d2, d3]
        evec = [e1, e2, e3]

        # Each pair of vector sets must satisfy a Kronecker-delta relation
        for idelta in range(3):
            for jdelta in range(3):
                delta = 1 if idelta == jdelta else 0
                np.testing.assert_allclose(
                    np.sum(fvec[idelta] * gvec[jdelta]), delta, rtol=0,
                    atol=1e-5)
                np.testing.assert_allclose(
                    np.sum(dvec[idelta] * evec[jdelta]), delta, rtol=0,
                    atol=1e-5)
        return

    def test_basevectors_apex_invalid_scalar(self):
        """Test warning and fill values for base vectors with bad inputs."""
        self.apex_out = apexpy.Apex(date=2000, refh=10000)
        invalid = np.full(shape=(3,), fill_value=np.nan)

        # Get the output and the warnings; force every warning to be recorded
        # so a warning already issued earlier in the session is not hidden
        with warnings.catch_warnings(record=True) as warn_rec:
            warnings.simplefilter("always")
            basevec = self.apex_out.basevectors_apex(0, 0, 0)

        for i, bvec in enumerate(basevec):
            if i < 2:
                assert not np.allclose(bvec, invalid[:2])
            else:
                np.testing.assert_allclose(bvec, invalid)

        # Look for the expected UserWarning among everything recorded rather
        # than assuming it is the final entry
        user_msgs = [str(wrec.message) for wrec in warn_rec
                     if issubclass(wrec.category, UserWarning)]
        assert any('set to NaN where' in msg for msg in user_msgs)
        return
1609
1610
1611
class TestApexGetMethods(object):
1612
    """Test the Apex `get` methods."""
1613
    def setup_method(self):
        """Build the Apex object (year 2000, 300 km reference) used by tests."""
        apex_obj = apexpy.Apex(refh=300, date=2000)
        self.apex_out = apex_obj
1616
1617
    def teardown_method(self):
        """Drop the Apex object created in `setup_method`."""
        delattr(self, "apex_out")
1620
1621
    @pytest.mark.parametrize("alat, aheight",
1622
                             [(10, 507.409702543805),
1623
                              (60, 20313.026999999987),
1624
                              ([10, 60],
1625
                               [507.409702543805, 20313.026999999987]),
1626
                              ([[10], [60]],
1627
                               [[507.409702543805], [20313.026999999987]])])
1628
    def test_get_apex(self, alat, aheight):
        """Compare `get_apex` output against reference apex heights.

        Parameters
        ----------
        alat : int, float, or array-like
            Apex latitude(s) in degrees N
        aheight : int, float, or array-like
            Expected apex height(s) in km, shaped like `alat`

        """
        np.testing.assert_allclose(self.apex_out.get_apex(alat), aheight)
        return
1642
1643
    @pytest.mark.parametrize("glat,glon,height,test_bmag",
1644
                             [([80], [100], [300], 5.100682377815247e-05),
1645
                              ([80, 80], [100], [300],
1646
                               [5.100682377815247e-05, 5.100682377815247e-05]),
1647
                              ([[80], [80]], [100], [300],
1648
                               [[5.100682377815247e-05],
1649
                                [5.100682377815247e-05]]),
1650
                              (range(50, 90, 8), range(0, 360, 80), [300] * 5,
1651
                               np.array([4.18657154e-05, 5.11118114e-05,
1652
                                         4.91969854e-05, 5.10519207e-05,
1653
                                         4.90054816e-05])),
1654
                              (90.0, 0, 1000, 3.7834718823432923e-05)])
1655
    def test_get_babs(self, glat, glon, height, test_bmag):
        """Test the method to get the magnitude of the magnetic field.

        Parameters
        ----------
        glat : int, float, or array-like
            Latitude(s) in degrees N
        glon : int, float, or array-like
            Longitude(s) in degrees E
        height : int, float, or array-like
            Height(s) in km
        test_bmag : float or array-like
            Expected B field magnitude(s), in the units returned by
            `get_babs`, shaped like the broadcast inputs

        Notes
        -----
        The expected magnitudes are of order 1e-5, so a relative tolerance is
        used.  An absolute tolerance of 1e-5 would accept results that are off
        by tens of percent.

        """
        bmag = self.apex_out.get_babs(glat, glon, height)
        np.testing.assert_allclose(bmag, test_bmag, rtol=1e-5, atol=0)
        return
1673
1674
    @pytest.mark.parametrize("bad_lat", [(91), (-91)])
1675
    def test_get_apex_with_invalid_lat(self, bad_lat):
        """Test `get_apex` raises ValueError for invalid latitudes.

        Parameters
        ----------
        bad_lat : int or float
            Bad input latitude in degrees N

        """
        with pytest.raises(ValueError) as verr:
            self.apex_out.get_apex(bad_lat)

        # Substring test; `str.find(...) > 0` would wrongly fail if the
        # message started with the expected text
        assert "must be in [-90, 90]" in str(verr.value)
        return
1690
1691
    @pytest.mark.parametrize("bad_lat", [(91), (-91)])
1692
    def test_get_babs_with_invalid_lat(self, bad_lat):
        """Test `get_babs` raises ValueError for invalid latitudes.

        Parameters
        ----------
        bad_lat : int or float
            Bad input latitude in degrees N

        """
        with pytest.raises(ValueError) as verr:
            self.apex_out.get_babs(bad_lat, 15, 100)

        # Substring test; `str.find(...) > 0` would wrongly fail if the
        # message started with the expected text
        assert "must be in [-90, 90]" in str(verr.value)
        return
1707
1708
    @pytest.mark.parametrize("bound_lat", [(90), (-90)])
1709
    def test_get_at_lat_boundary(self, bound_lat):
        """Check `get_apex` accepts latitudes marginally beyond +/-90 degrees.

        Parameters
        ----------
        bound_lat : int or float
            Boundary input latitude in degrees N

        """
        # Push the latitude 1e-5 degrees further from the equator
        nudge = 1.0e-5
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + nudge)

        # The nudged latitude should be tolerated and give the boundary result
        at_bound = self.apex_out.get_apex(bound_lat)
        past_bound = self.apex_out.get_apex(excess_lat)

        np.testing.assert_allclose(past_bound, at_bound, rtol=0, atol=1e-8)
        return
1728
1729
    @pytest.mark.parametrize("apex_height", [-100, 0, 300, 10000])
1730
    def test_get_height_at_equator(self, apex_height):
        """Check `get_height` at 0 degrees latitude equals the apex height.

        Parameters
        ----------
        apex_height : float
            Apex height

        """
        equator_height = self.apex_out.get_height(0.0, apex_height)
        assert equator_height == apex_height
        return
1742
1743
    @pytest.mark.parametrize("lat, height", [
1744
        (-90, -6371.009), (-80, -6088.438503309167), (-70, -5274.8091854339655),
1745
        (-60, -4028.256749999999), (-50, -2499.1338178752017),
1746
        (-40, -871.8751821247979), (-30, 657.2477500000014),
1747
        (-20, 1903.8001854339655), (-10, 2717.4295033091657), (0, 3000.0),
1748
        (10, 2717.4295033091657), (20, 1903.8001854339655),
1749
        (30, 657.2477500000014), (40, -871.8751821247979),
1750
        (50, -2499.1338178752017), (60, -4028.256749999999),
1751
        (70, -5274.8091854339655), (80, -6088.438503309167)])
1752
    def test_get_height_along_fieldline(self, lat, height):
        """Check `get_height` along the field line with a 3000 km apex.

        Parameters
        ----------
        lat : float
            Input latitude
        height : float
            Expected field-line height at `lat` for a line with a 3000 km apex

        """
        fheight = self.apex_out.get_height(lat, 3000.0)
        height_err = abs(height - fheight)
        assert height_err < 1.0e-7, \
            "bad height calculation: {:.7f} != {:.7f}".format(height, fheight)
        return
1768