Completed
Push — main ( ba9c59...ae4cf1 )
by Angeline
19s queued 12s
created

TestApexBasevectorMethods.get_comparison_results()   B

Complexity

Conditions 6

Size

Total Lines 62
Code Lines 43

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 43
dl 0
loc 62
rs 7.9146
c 0
b 0
f 0
cc 6
nop 4

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

- Extract Method

1
# -*- coding: utf-8 -*-
2
"""Test the apexpy.Apex class
3
4
Notes
5
-----
6
Whenever function outputs are tested against hard-coded numbers, the test
7
results (numbers) were obtained by running the code that is tested.  Therefore,
8
these tests below only check that nothing changes when refactoring, etc., and
9
not if the results are actually correct.
10
11
These results are expected to change when IGRF is updated.
12
13
"""
14
15
import datetime as dt
16
import numpy as np
17
import os
18
import pytest
19
import shutil
20
import warnings
21
22
import apexpy
23
24
25
def _move_with_retries(src, dst, max_attempts):
    """Move a file, retrying when the operating system reports an error.

    Parameters
    ----------
    src : str
        Path of the file to move
    dst : str
        Destination path
    max_attempts : int
        Maximum number of move attempts; retries are needed on Windows

    Raises
    ------
    OSError
        If the file could not be moved within `max_attempts` attempts.  The
        last underlying error, if any, is chained as the cause.

    """
    last_err = None
    for _ in range(max_attempts):
        try:
            shutil.move(src, dst)
            return
        except OSError as err:
            last_err = err

    raise OSError("unable to move {:} to {:} in {:} attempts".format(
        src, dst, max_attempts)) from last_err


@pytest.fixture()
def igrf_file(max_attempts=100):
    """A fixture for handling the coefficient file.

    Temporarily moves the IGRF coefficient file out of the apexpy package
    directory for the duration of a test, then moves it back.

    Parameters
    ----------
    max_attempts : int
        Maximum rename attempts, needed for Windows (default=100)

    Yields
    ------
    original_file : str
        Full path of the coefficient file, which is absent while the test runs

    Raises
    ------
    OSError
        If the file cannot be moved away or restored.  Failures are raised
        rather than ignored so that a coefficient file that was never moved
        (or never restored) cannot silently corrupt this or later tests.

    """
    # Ensure the coefficient file exists
    original_file = os.path.join(os.path.dirname(apexpy.helpers.__file__),
                                 'igrf13coeffs.txt')
    tmp_file = "temp_coeff.txt"
    assert os.path.isfile(original_file)

    # Move the coefficient file out of the way
    _move_with_retries(original_file, tmp_file, max_attempts)
    yield original_file

    # Move the coefficient file back; pytest runs this teardown code even
    # when the test itself fails
    _move_with_retries(tmp_file, original_file, max_attempts)
    return
58
59
60
def test_set_epoch_file_error(igrf_file):
    """Test raises OSError when IGRF coefficient file is missing.

    Parameters
    ----------
    igrf_file : str
        Path of the coefficient file, moved away by the `igrf_file` fixture

    """
    expected_start = "File {:} does not exist".format(igrf_file)

    # Without the coefficient file, initialization must fail
    with pytest.raises(OSError) as os_err:
        apexpy.Apex()

    assert str(os_err.value).startswith(expected_start)
    return
68
69
70
class TestApexInit(object):
    """Test class for the Apex class object."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = None
        # Naive datetime holding the current UTC time; built this way because
        # `dt.datetime.utcnow` is deprecated since Python 3.12
        self.test_date = dt.datetime.now(dt.timezone.utc).replace(tzinfo=None)
        self.test_refh = 0
        self.bad_file = 'foo/path/to/datafile.blah'

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.test_date, self.test_refh, self.bad_file

    def init_and_evaluate(self, **kwargs):
        """Create `self.apex_out` and check its date and reference height.

        Parameters
        ----------
        **kwargs : dict
            Keyword arguments passed through to `apexpy.Apex`

        """
        self.apex_out = apexpy.Apex(**kwargs)
        self.eval_date()
        self.eval_refh()
        return

    def eval_date(self):
        """Evaluate the times in self.test_date and self.apex_out."""
        # `dt.datetime` is a subclass of `dt.date`, so one check covers both
        if isinstance(self.test_date, dt.date):
            self.test_date = apexpy.helpers.toYearFraction(self.test_date)

        # Six decimal places of a fractional year tolerates differences of
        # order tens of seconds, needed when the default (current UTC) date
        # is evaluated a moment after `self.test_date` was set
        np.testing.assert_almost_equal(self.test_date, self.apex_out.year, 6)
        return

    def eval_refh(self):
        """Evaluate the reference height in self.test_refh and self.apex_out."""
        eval_str = ("expected reference height [{:}] not equal to Apex "
                    "reference height [{:}]").format(self.test_refh,
                                                      self.apex_out.refh)
        assert self.test_refh == self.apex_out.refh, eval_str
        return

    def test_init_defaults(self):
        """Test Apex class default initialization."""
        self.init_and_evaluate()
        return

    @pytest.mark.parametrize("in_date",
                             [2015, 2015.5, dt.date(2015, 1, 1),
                              dt.datetime(2015, 6, 1, 18, 23, 45)])
    def test_init_date(self, in_date):
        """Test Apex class with date initialization.

        Parameters
        ----------
        in_date : int, float, dt.date, or dt.datetime
            Input date in a variety of formats

        """
        self.test_date = in_date
        self.init_and_evaluate(date=self.test_date)
        return

    @pytest.mark.parametrize("new_date", [2015, 2015.5])
    def test_set_epoch(self, new_date):
        """Test successful setting of Apex epoch after initialization.

        Parameters
        ----------
        new_date : int or float
            New date for the Apex class

        """
        # Evaluate the default initialization
        self.init_and_evaluate()

        # Update the epoch, comparing against an independent copy of the
        # original object re-created from its repr string
        ref_apex = eval(repr(self.apex_out))
        self.apex_out.set_epoch(new_date)
        assert ref_apex != self.apex_out
        self.test_date = new_date
        self.eval_date()
        return

    @pytest.mark.parametrize("in_refh", [0.0, 300.0, 30000.0, -1.0])
    def test_init_refh(self, in_refh):
        """Test Apex class with reference height initialization.

        Parameters
        ----------
        in_refh : float
            Input reference height in km

        """
        self.test_refh = in_refh
        self.init_and_evaluate(refh=self.test_refh)
        return

    @pytest.mark.parametrize("new_refh", [0.0, 300.0, 30000.0, -1.0])
    def test_set_refh(self, new_refh):
        """Test the method used to set the reference height after the init.

        Parameters
        ----------
        new_refh : float
            Reference height in km

        """
        # Verify the defaults are set
        self.init_and_evaluate(date=self.test_date)

        # Update to a new reference height and test against an independent
        # copy of the original object
        ref_apex = eval(repr(self.apex_out))
        self.apex_out.set_refh(new_refh)

        if self.test_refh == new_refh:
            assert ref_apex == self.apex_out
        else:
            assert ref_apex != self.apex_out
            self.test_refh = new_refh
        self.eval_refh()
        return

    def test_init_with_bad_datafile(self):
        """Test raises IOError with non-existent datafile input."""
        with pytest.raises(IOError) as oerr:
            apexpy.Apex(datafile=self.bad_file)
        assert str(oerr.value).startswith('Data file does not exist')
        return

    def test_init_with_bad_fortranlib(self):
        """Test raises IOError with non-existent fortranlib input."""
        with pytest.raises(IOError) as oerr:
            apexpy.Apex(fortranlib=self.bad_file)
        assert str(oerr.value).startswith('Fortran library does not exist')
        return

    def test_repr_eval(self):
        """Test the Apex.__repr__ results."""
        # Initialize the apex object
        self.init_and_evaluate()

        # Get and test the repr string
        out_str = repr(self.apex_out)
        assert out_str.startswith("apexpy.Apex(")

        # Test the ability to re-create the apex object from the repr string
        new_apex = eval(out_str)
        assert new_apex == self.apex_out
        return

    def test_ne_other_class(self):
        """Test Apex class inequality to a different class."""
        self.init_and_evaluate()

        # After `eval_date`, `self.test_date` holds a float year
        assert self.apex_out != self.test_date
        return

    def test_ne_missing_attr(self):
        """Test Apex class inequality when attributes are missing from one."""
        self.init_and_evaluate()
        ref_apex = eval(repr(self.apex_out))
        del ref_apex.RE

        assert ref_apex != self.apex_out
        assert self.apex_out != ref_apex
        return

    def test_eq_missing_attr(self):
        """Test Apex class equality when attributes are missing from both."""
        self.init_and_evaluate()
        ref_apex = eval(repr(self.apex_out))
        del ref_apex.RE, self.apex_out.RE

        assert ref_apex == self.apex_out
        return

    def test_str_eval(self):
        """Test the Apex.__str__ results."""
        # Initialize the apex object
        self.init_and_evaluate()

        # Get and test the printed string
        out_str = str(self.apex_out)
        assert "Decimal year" in out_str
        return
268
269
270
class TestApexMethod(object):
    """Test the Apex methods."""

    # Each entry is (Apex method, equivalent Fortran function, slice selecting
    # the Fortran outputs that correspond to the Apex method output); shared
    # by the Apex/Fortran interface consistency tests below
    fortran_equivalents = [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                           ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                           ("_qd2geo", "apxq2g", slice(None)),
                           ("_basevec", "apxg2q", slice(2, 4, 1))]

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.in_lat = 60
        self.in_lon = 15
        self.in_alt = 100

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.in_lat, self.in_lon, self.in_alt

    def get_input_args(self, method_name, precision=0.0):
        """Set the input arguments for the different Apex methods.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method or Fortran function
        precision : float
            Value for the precision (default=0.0)

        Returns
        -------
        in_args : list
            List of the appropriate input arguments, starting with
            `self.in_lat`, `self.in_lon`, and `self.in_alt`

        """
        in_args = [self.in_lat, self.in_lon, self.in_alt]

        # Add precision, if needed
        if method_name in ("_qd2geo", "apxq2g", "apex2geo", "qd2geo",
                           "_apex2geo"):
            in_args.append(precision)

        # Add a reference height, if needed (matches refh in setup_method)
        if method_name == "apxg2all":
            in_args.append(300)

        # Add a vector flag, if needed
        if method_name in ("apxg2all", "apxg2q"):
            in_args.append(1)

        return in_args

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             fortran_equivalents)
    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon", [-179, -90, 0, 90, 180])
    def test_fortran_scalar_input(self, apex_method, fortran_method, fslice,
                                  lat, lon):
        """Tests Apex/fortran interface consistency for scalars.

        Parameters
        ----------
        apex_method : str
            Name of the Apex class method to test
        fortran_method : str
            Name of the Fortran function to test
        fslice : slice
            Slice used to select the appropriate Fortran outputs
        lat : int or float
            Latitude in degrees N
        lon : int or float
            Longitude in degrees E

        """
        # Set the input coordinates
        self.in_lat = lat
        self.in_lon = lon

        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Get the appropriate input arguments
        apex_args = self.get_input_args(apex_method)
        fortran_args = self.get_input_args(fortran_method)

        # Evaluate the equivalent function calls
        np.testing.assert_allclose(apex_func(*apex_args),
                                   fortran_func(*fortran_args)[fslice])
        return

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             fortran_equivalents)
    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon1,lon2", [(180, 180), (-180, -180),
                                           (180, -180), (-180, 180),
                                           (-345, 15), (375, 15)])
    def test_fortran_longitude_rollover(self, apex_method, fortran_method,
                                        fslice, lat, lon1, lon2):
        """Tests Apex/fortran interface consistency for longitude rollover.

        Parameters
        ----------
        apex_method : str
            Name of the Apex class method to test
        fortran_method : str
            Name of the Fortran function to test
        fslice : slice
            Slice used to select the appropriate Fortran outputs
        lat : int or float
            Latitude in degrees N
        lon1 : int or float
            Longitude in degrees E
        lon2 : int or float
            Equivalent longitude in degrees E

        """
        # Set the fixed input coordinate
        self.in_lat = lat

        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Get the appropriate input arguments, using a different (but
        # equivalent) longitude for each interface
        self.in_lon = lon1
        apex_args = self.get_input_args(apex_method)

        self.in_lon = lon2
        fortran_args = self.get_input_args(fortran_method)

        # Evaluate the equivalent function calls
        np.testing.assert_allclose(apex_func(*apex_args),
                                   fortran_func(*fortran_args)[fslice])
        return

    @pytest.mark.parametrize("arr_shape", [(2, 2), (4,), (1, 4)])
    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             fortran_equivalents)
    def test_fortran_array_input(self, arr_shape, apex_method, fortran_method,
                                 fslice):
        """Tests Apex/fortran interface consistency for array input.

        Parameters
        ----------
        arr_shape : tuple
            Expected output shape
        apex_method : str
            Name of the Apex class method to test
        fortran_method : str
            Name of the Fortran function to test
        fslice : slice
            Slice used to select the appropriate Fortran outputs

        """
        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Set up the input arrays
        ref_lat = np.array([0, 30, 60, 90])
        ref_alt = np.array([100, 200, 300, 400])
        self.in_lat = ref_lat.reshape(arr_shape)
        self.in_alt = ref_alt.reshape(arr_shape)
        apex_args = self.get_input_args(apex_method)

        # Get the Apex class results
        aret = apex_func(*apex_args)

        # Get the fortran function results, one lat/alt pair at a time
        flats = list()
        flons = list()

        for lat, alt in zip(ref_lat, ref_alt):
            self.in_lat = lat
            self.in_alt = alt
            fortran_args = self.get_input_args(fortran_method)
            fret = fortran_func(*fortran_args)[fslice]
            flats.append(fret[0])
            flons.append(fret[1])

        flats = np.array(flats)
        flons = np.array(flons)

        # Evaluate results
        try:
            # This returned value is array of floats
            np.testing.assert_allclose(aret[0].astype(float),
                                       flats.reshape(arr_shape).astype(float))
            np.testing.assert_allclose(aret[1].astype(float),
                                       flons.reshape(arr_shape).astype(float))
        except ValueError:
            # This returned value is array of arrays
            alats = aret[0].reshape((4,))
            alons = aret[1].reshape((4,))
            for i, flat in enumerate(flats):
                np.testing.assert_array_almost_equal(alats[i], flat, 2)
                np.testing.assert_array_almost_equal(alons[i], flons[i], 2)

        return

    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon", [-179, -90, 0, 90, 180])
    def test_geo2apexall_scalar(self, lat, lon):
        """Test Apex/fortran geo2apexall interface consistency for scalars.

        Parameters
        ----------
        lat : int or float
            Latitude in degrees N
        lon : int or float
            Longitude in degrees E

        """
        # Get the Apex and Fortran results
        aret = self.apex_out._geo2apexall(lat, lon, self.in_alt)
        fret = apexpy.fortranapex.apxg2all(lat, lon, self.in_alt, 300, 1)

        # Evaluate each element in the results
        for aval, fval in zip(aret, fret):
            np.testing.assert_allclose(aval, fval)
        return

    @pytest.mark.parametrize("arr_shape", [(2, 2), (4,), (1, 4)])
    def test_geo2apexall_array(self, arr_shape):
        """Test Apex/fortran geo2apexall interface consistency for arrays.

        Parameters
        ----------
        arr_shape : tuple
            Expected output shape

        """
        # Set the input
        self.in_lat = np.array([0, 30, 60, 90])
        self.in_alt = np.array([100, 200, 300, 400])

        # Get the Apex class results
        aret = self.apex_out._geo2apexall(self.in_lat.reshape(arr_shape),
                                          self.in_lon,
                                          self.in_alt.reshape(arr_shape))

        # For each lat/alt pair, get the Fortran results
        fret = [apexpy.fortranapex.apxg2all(lat, self.in_lon, alt, 300, 1)
                for lat, alt in zip(self.in_lat, self.in_alt)]

        # Cycle through all returned values
        for i, ret in enumerate(aret):
            try:
                # This returned value is array of floats
                fret_test = np.array([single_fret[i] for single_fret in fret])
                fret_test = fret_test.reshape(arr_shape)
                np.testing.assert_allclose(ret.astype(float),
                                           fret_test.astype(float))
            except ValueError:
                # This returned value is array of arrays
                ret = ret.reshape((4,))
                for j, single_fret in enumerate(fret):
                    np.testing.assert_allclose(ret[j], single_fret[i])
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_consistency(self, in_coord, out_coord):
        """Test the self-consistency of the Apex convert method.

        Parameters
        ----------
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Define the method name
        method_name = "2".join([in_coord, out_coord])

        # Get the method and method inputs
        convert_kwargs = {'height': self.in_alt, 'precision': 0.0}
        apex_args = self.get_input_args(method_name)
        apex_method = getattr(self.apex_out, method_name)

        # Define the slice needed to get equivalent output from the named
        # method; the trailing output of the `*2geo` methods is dropped
        mslice = slice(0, -1, 1) if out_coord == "geo" else slice(None)

        # Get output using convert and named method
        convert_out = self.apex_out.convert(self.in_lat, self.in_lon, in_coord,
                                            out_coord, **convert_kwargs)
        method_out = apex_method(*apex_args)[mslice]

        # Compare both outputs, should be identical
        np.testing.assert_allclose(convert_out, method_out)
        return

    @pytest.mark.parametrize("bound_lat", [90, -90])
    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_at_lat_boundary(self, bound_lat, in_coord, out_coord):
        """Test the conversion at the latitude boundary, with allowed excess.

        Parameters
        ----------
        bound_lat : int or float
            Boundary latitude in degrees N
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system

        """
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = self.apex_out.convert(bound_lat, 0, in_coord, out_coord)
        excess_out = self.apex_out.convert(excess_lat, 0, in_coord, out_coord)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    def test_convert_qd2apex_at_equator(self):
        """Test the quasi-dipole to apex conversion at the magnetic equator."""
        eq_out = self.apex_out.convert(lat=0.0, lon=0, source='qd', dest='apex',
                                       height=320.0)
        close_out = self.apex_out.convert(lat=0.001, lon=0, source='qd',
                                          dest='apex', height=320.0)
        np.testing.assert_allclose(eq_out, close_out, atol=1e-4)
        return

    @pytest.mark.parametrize("src", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("dest", ["geo", "apex", "qd"])
    def test_convert_withnan(self, src, dest):
        """Test Apex.convert success with NaN input.

        Parameters
        ----------
        src : str
            Input coordinate system
        dest : str
            Output coordinate system

        """
        if src == dest:
            pytest.skip("Test not needed for same src and dest coordinates")

        num_nans = 5
        in_loc = np.arange(0, 10, dtype=float)
        in_loc[:num_nans] = np.nan

        out_loc = self.apex_out.convert(in_loc, in_loc, src, dest, height=320)

        for out in out_loc:
            assert np.all(np.isnan(out[:num_nans])), "NaN output expected"
            assert np.all(np.isfinite(out[num_nans:])), "Finite output expected"

        return

    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_convert_invalid_lat(self, bad_lat):
        """Test convert raises ValueError for invalid latitudes.

        Parameters
        ----------
        bad_lat : int or float
            Latitude outside the supported range in degrees N

        """

        with pytest.raises(ValueError) as verr:
            self.apex_out.convert(bad_lat, 0, 'geo', 'geo')

        assert "must be in [-90, 90]" in str(verr.value)
        return

    @pytest.mark.parametrize("coords", [("foobar", "geo"), ("geo", "foobar"),
                                        ("geo", "mlt")])
    def test_convert_invalid_transformation(self, coords):
        """Test raises ValueError for bad coordinates.

        Parameters
        ----------
        coords : tuple
            Tuple specifying the input and output coordinate systems

        """
        if "mlt" in coords:
            estr = "datetime must be given for MLT calculations"
        else:
            estr = "Unknown coordinate transformation"

        with pytest.raises(ValueError) as verr:
            self.apex_out.convert(0, 0, *coords)

        # Inspect the raised exception itself, not the ExceptionInfo wrapper
        assert estr in str(verr.value)
        return

    @pytest.mark.parametrize("method_name, out_comp",
                             [("geo2apex",
                               (55.94841766357422, 94.10684204101562)),
                              ("apex2geo",
                               (51.476322174072266, -66.22817993164062,
                                5.727287771151168e-06)),
                              ("geo2qd",
                               (56.531288146972656, 94.10684204101562)),
                              ("apex2qd", (60.498401178276744, 15.0)),
                              ("qd2apex", (59.49138097045895, 15.0))])
    def test_method_scalar_input(self, method_name, out_comp):
        """Test the user method against set values with scalars.

        Parameters
        ----------
        method_name : str
            Apex class method to be tested
        out_comp : tuple of floats
            Expected output values

        """
        # Get the desired methods
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(self.in_lat, self.in_lon, self.in_alt)

        # Evaluate the user output
        np.testing.assert_allclose(user_out, out_comp, rtol=1e-5, atol=1e-5)

        for out_val in user_out:
            assert np.asarray(out_val).shape == (), "output is not a scalar"
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("method_args, out_shape",
                             [([[60, 60], 15, 100], (2,)),
                              ([60, [15, 15], 100], (2,)),
                              ([60, 15, [100, 100]], (2,)),
                              ([[50, 60], [15, 16], [100, 200]], (2,))])
    def test_method_broadcast_input(self, in_coord, out_coord, method_args,
                                    out_shape):
        """Test the user method with inputs that require some broadcasting.

        Parameters
        ----------
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system
        method_args : list
            List of input arguments
        out_shape : tuple
            Expected shape of output values

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(*method_args)

        # Evaluate the user output
        for out_val in user_out:
            assert hasattr(out_val, 'shape'), "output coordinate isn't np.array"
            assert out_val.shape == out_shape
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_method_invalid_lat(self, in_coord, out_coord, bad_lat):
        """Test user methods raise ValueError for invalid latitudes.

        Parameters
        ----------
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system
        bad_lat : int
            Latitude in degrees N that is out of bounds

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        with pytest.raises(ValueError) as verr:
            user_method(bad_lat, 15, 100)

        assert "must be in [-90, 90]" in str(verr.value)
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("bound_lat", [90, -90])
    def test_method_at_lat_boundary(self, in_coord, out_coord, bound_lat):
        """Test user methods at the latitude boundary, with allowed excess.

        Parameters
        ----------
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system
        bound_lat : int
            Latitude in degrees N that is at the limits of the boundary

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        # Get a latitude just beyond the limit
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = user_method(bound_lat, 0, 100)
        excess_out = user_method(excess_lat, 0, 100)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    def test_geo2apex_undefined_warning(self):
        """Test geo2apex warning and fill values for an undefined location."""

        # Update the apex object
        self.apex_out = apexpy.Apex(date=2000, refh=10000)

        # Get the output and the warnings; "always" ensures the warning is
        # recorded regardless of any ignore/error filters configured outside
        with warnings.catch_warnings(record=True) as warn_rec:
            warnings.simplefilter("always")
            user_lat, user_lon = self.apex_out.geo2apex(0, 0, 0)

        assert np.isnan(user_lat)
        assert np.isfinite(user_lon)
        assert len(warn_rec) == 1
        assert issubclass(warn_rec[-1].category, UserWarning)
        assert 'latitude set to NaN where' in str(warn_rec[-1].message)
        return

    @pytest.mark.parametrize("method_name", ["apex2qd", "qd2apex"])
    @pytest.mark.parametrize("delta_h", [1.0e-6, -1.0e-6])
    def test_quasidipole_apexheight_close(self, method_name, delta_h):
        """Test quasi-dipole success with a height close to the reference.

        Parameters
        ----------
        method_name : str
            Apex class method name to be tested
        delta_h : float
            tolerance for height in km

        """
        qd_method = getattr(self.apex_out, method_name)
        in_args = [0, 15, self.apex_out.refh + delta_h]
        out_coords = qd_method(*in_args)

        # Outputs are compared element-wise against the leading inputs
        for out_val, in_val in zip(out_coords, in_args):
            np.testing.assert_almost_equal(out_val, in_val, decimal=3)
        return

    @pytest.mark.parametrize("method_name, hinc, msg",
                             [("apex2qd", 1.0, "is > apex height"),
                              ("qd2apex", -1.0, "is < reference height")])
    def test_quasidipole_raises_apexheight(self, method_name, hinc, msg):
        """Quasi-dipole raises ApexHeightError for heights out of range.

        Parameters
        ----------
        method_name : str
            Apex class method name to be tested
        hinc : float
            Height increment relative to the reference height in km
        msg : str
            Expected output message

        """
        qd_method = getattr(self.apex_out, method_name)

        with pytest.raises(apexpy.ApexHeightError) as aerr:
            qd_method(0, 15, self.apex_out.refh + hinc)

        # Inspect the raised exception itself, not the ExceptionInfo wrapper
        assert msg in str(aerr.value)
        return
870
871
872
class TestApexMLTMethods(object):
    """Test the Apex Magnetic Local Time (MLT) methods."""
    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.in_time = dt.datetime(2000, 2, 3, 4, 5, 6)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.in_time

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    def test_convert_to_mlt(self, in_coord):
        """Test the conversions to MLT using Apex convert.

        Parameters
        ----------
        in_coord : str
            Input coordinate system

        """
        # Get the magnetic longitude: geodetic input must be converted first,
        # while apex and quasi-dipole input longitudes are used directly
        if in_coord == "geo":
            mlon = self.apex_out.geo2apex(60, 15, 100)[1]
        else:
            mlon = 15

        # Get the output MLT values
        convert_mlt = self.apex_out.convert(60, 15, in_coord, 'mlt',
                                            height=100, ssheight=2e5,
                                            datetime=self.in_time)[1]
        method_mlt = self.apex_out.mlon2mlt(mlon, self.in_time, ssheight=2e5)

        # Test the outputs
        np.testing.assert_allclose(convert_mlt, method_mlt)
        return

    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_mlt_to_lon(self, out_coord):
        """Test the conversions from MLT using Apex convert.

        Parameters
        ----------
        out_coord : str
            Output coordinate system

        """
        # Get the output longitudes
        convert_out = self.apex_out.convert(60, 15, 'mlt', out_coord,
                                            height=100, ssheight=2e5,
                                            datetime=self.in_time,
                                            precision=1e-2)
        mlon = self.apex_out.mlt2mlon(15, self.in_time, ssheight=2e5)

        # Build the expected output from the individual conversion methods;
        # the height returned by `apex2geo` is dropped
        if out_coord == "geo":
            method_out = self.apex_out.apex2geo(60, mlon, 100,
                                                precision=1e-2)[:-1]
        elif out_coord == "qd":
            method_out = self.apex_out.apex2qd(60, mlon, 100)
        else:
            method_out = (60, mlon)

        # Evaluate the outputs
        np.testing.assert_allclose(convert_out, method_out)
        return

    def test_convert_geo2mlt_nodate(self):
        """Test convert from geo to MLT raises ValueError with no datetime."""
        with pytest.raises(ValueError):
            self.apex_out.convert(60, 15, 'geo', 'mlt')
        return

    @pytest.mark.parametrize("mlon_kwargs,test_mlt",
                             [({}, 23.019629923502603),
                              ({"ssheight": 100000}, 23.026712036132814)])
    def test_mlon2mlt_scalar_inputs(self, mlon_kwargs, test_mlt):
        """Test mlon2mlt with scalar inputs.

        Parameters
        ----------
        mlon_kwargs : dict
            Input kwargs
        test_mlt : float
            Output MLT in hours

        """
        mlt = self.apex_out.mlon2mlt(0, self.in_time, **mlon_kwargs)

        np.testing.assert_allclose(mlt, test_mlt)
        assert np.asarray(mlt).shape == ()
        return

    @pytest.mark.parametrize("mlt_kwargs,test_mlon",
                             [({}, 14.705535888671875),
                              ({"ssheight": 100000}, 14.599319458007812)])
    def test_mlt2mlon_scalar_inputs(self, mlt_kwargs, test_mlon):
        """Test mlt2mlon with scalar inputs.

        Parameters
        ----------
        mlt_kwargs : dict
            Input kwargs
        test_mlon : float
            Output longitude in degrees E

        """
        mlon = self.apex_out.mlt2mlon(0, self.in_time, **mlt_kwargs)

        np.testing.assert_allclose(mlon, test_mlon)
        assert np.asarray(mlon).shape == ()
        return

    @pytest.mark.parametrize("mlon,test_mlt",
                             [([0, 180], [23.019261, 11.019261]),
                              (np.array([0, 180]), [23.019261, 11.019261]),
                              (np.array([[0], [180]]),
                               np.array([[23.019261], [11.019261]])),
                              ([[0, 180], [0, 180]], [[23.019261, 11.019261],
                                                      [23.019261, 11.019261]]),
                              (range(0, 361, 30),
                               [23.01963, 1.01963, 3.01963, 5.01963, 7.01963,
                                9.01963, 11.01963, 13.01963, 15.01963, 17.01963,
                                19.01963, 21.01963, 23.01963])])
    def test_mlon2mlt_array(self, mlon, test_mlt):
        """Test mlon2mlt with array inputs.

        Parameters
        ----------
        mlon : array-like
            Input longitudes in degrees E
        test_mlt : array-like
            Expected output MLT in hours, with the same shape as `mlon`

        """
        mlt = self.apex_out.mlon2mlt(mlon, self.in_time)

        assert mlt.shape == np.asarray(test_mlt).shape
        np.testing.assert_allclose(mlt, test_mlt, rtol=1e-4)
        return

    @pytest.mark.parametrize("mlt,test_mlon",
                             [([0, 12], [14.705551, 194.705551]),
                              (np.array([0, 12]), [14.705551, 194.705551]),
                              (np.array([[0], [12]]),
                               np.array([[14.705551], [194.705551]])),
                              ([[0, 12], [0, 12]], [[14.705551, 194.705551],
                                                    [14.705551, 194.705551]]),
                              (range(0, 25, 2),
                               [14.705551, 44.705551, 74.705551, 104.705551,
                                134.705551, 164.705551, 194.705551, 224.705551,
                                254.705551, 284.705551, 314.705551, 344.705551,
                                14.705551])])
    def test_mlt2mlon_array(self, mlt, test_mlon):
        """Test mlt2mlon with array inputs.

        Parameters
        ----------
        mlt : array-like
            Input MLT in hours
        test_mlon : array-like
            Expected output longitude in degrees E, with the same shape as
            `mlt`

        """
        mlon = self.apex_out.mlt2mlon(mlt, self.in_time)

        assert mlon.shape == np.asarray(test_mlon).shape
        np.testing.assert_allclose(mlon, test_mlon, rtol=1e-4)
        return

    @pytest.mark.parametrize("method_name", ["mlon2mlt", "mlt2mlon"])
    def test_mlon2mlt_diffdates(self, method_name):
        """Test that the MLT/longitude conversions vary with universal time.

        Parameters
        ----------
        method_name : str
            Name of Apex class method to be tested

        """
        apex_method = getattr(self.apex_out, method_name)
        out_first = apex_method(0, self.in_time)
        out_later = apex_method(0, self.in_time + dt.timedelta(hours=1))

        assert out_first != out_later
        return

    @pytest.mark.parametrize("mlt_offset", [1.0, 10.0])
    def test_mlon2mlt_offset(self, mlt_offset):
        """Test the time wrapping logic for the MLT.

        Parameters
        ----------
        mlt_offset : float
            MLT offset in hours

        """
        # A westward shift of 15 degrees per hour lowers MLT by one hour
        mlt1 = self.apex_out.mlon2mlt(0.0, self.in_time)
        mlt2 = self.apex_out.mlon2mlt(-15.0 * mlt_offset,
                                      self.in_time) + mlt_offset

        np.testing.assert_allclose(mlt1, mlt2)
        return

    @pytest.mark.parametrize("mlon_offset", [15.0, 150.0])
    def test_mlt2mlon_offset(self, mlon_offset):
        """Test the time wrapping logic for the magnetic longitude.

        Parameters
        ----------
        mlon_offset : float
            Magnetic longitude offset in degrees

        """
        # An MLT increase of one hour corresponds to 15 degrees of longitude
        mlon1 = self.apex_out.mlt2mlon(0, self.in_time)
        mlon2 = self.apex_out.mlt2mlon(mlon_offset / 15.0,
                                       self.in_time) - mlon_offset

        np.testing.assert_allclose(mlon1, mlon2)
        return

    @pytest.mark.parametrize("order", [["mlt", "mlon"], ["mlon", "mlt"]])
    @pytest.mark.parametrize("start_val", [0, 6, 12, 18, 22])
    def test_convert_and_return(self, order, start_val):
        """Test the conversion to magnetic longitude or MLT and back again.

        Parameters
        ----------
        order : list
            List of strings specifying the order to run functions
        start_val : int or float
            Input value

        """
        first_method = getattr(self.apex_out, "2".join(order))
        second_method = getattr(self.apex_out, "2".join([order[1], order[0]]))

        middle_val = first_method(start_val, self.in_time)
        end_val = second_method(middle_val, self.in_time)

        np.testing.assert_allclose(start_val, end_val)
        return
1115
1116
1117
class TestApexMapMethods(object):
    """Test the Apex height mapping methods."""
    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 10000],
                               [31.841466903686523, 17.916635513305664,
                                1.7075473124350538e-6]),
                              ([30, 170, 100, 500, False, 1e-2],
                               [25.727270126342773, 169.60546875,
                                0.00017573432705830783]),
                              ([60, 15, 100, 10000, True],
                               [-25.424888610839844, 27.310426712036133,
                                1.2074182222931995e-6]),
                              ([30, 170, 100, 500, True, 1e-2],
                               [-13.76642894744873, 164.24259948730469,
                                0.00056820799363777041])])
    def test_map_to_height(self, in_args, test_mapped):
        """Test the map_to_height function.

        Parameters
        ----------
        in_args : list
            List of input arguments
        test_mapped : list
            List of expected outputs

        """
        mapped = self.apex_out.map_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5, atol=1e-5)
        return

    def test_map_to_height_same_height(self):
        """Test the map_to_height function when mapping to same height."""
        mapped = self.apex_out.map_to_height(60, 15, 100, 100, conjugate=False,
                                             precision=1e-10)
        np.testing.assert_allclose(mapped, (60.0, 15.000003814697266, 0.0),
                                   rtol=1e-5, atol=1e-5)
        return

    @pytest.mark.parametrize('arr_shape', [(2,), (2, 2), (1, 4)])
    @pytest.mark.parametrize('ivec', range(0, 4))
    def test_map_to_height_array_location(self, arr_shape, ivec):
        """Test map_to_height with array input.

        Parameters
        ----------
        arr_shape : tuple
            Expected array shape
        ivec : int
            Input argument index for vectorized input

        """
        # Set the base input and output values
        in_args = [60, 15, 100, 100]
        test_mapped = [60, 15.00000381, 0.0]

        # Update inputs for one vectorized value
        in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Calculate and test function
        mapped = self.apex_out.map_to_height(*in_args)
        for i, test_val in enumerate(test_mapped):
            assert mapped[i].shape == arr_shape
            np.testing.assert_allclose(mapped[i], test_val, rtol=1e-5,
                                       atol=1e-5)
        return

    @pytest.mark.parametrize("method_name,in_args",
                             [("map_to_height", [0, 15, 100, 10000]),
                              ("map_E_to_height",
                               [0, 15, 100, 10000, [1, 2, 3]]),
                              ("map_V_to_height",
                               [0, 15, 100, 10000, [1, 2, 3]])])
    def test_mapping_height_raises_ApexHeightError(self, method_name, in_args):
        """Test the height mapping methods raise ApexHeightError.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method to test
        in_args : list
            List of input arguments

        """
        apex_method = getattr(self.apex_out, method_name)

        with pytest.raises(apexpy.ApexHeightError) as aerr:
            apex_method(*in_args)

        assert aerr.match("is > apex height")
        return

    @pytest.mark.parametrize("method_name",
                             ["map_E_to_height", "map_V_to_height"])
    @pytest.mark.parametrize("ev_input", [([1, 2, 3, 4, 5]),
                                          ([[1, 2], [3, 4], [5, 6], [7, 8]])])
    def test_mapping_EV_bad_shape(self, method_name, ev_input):
        """Test height mapping of E/V with badly shaped input raises Error.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method to test
        ev_input : list
            E/V input arguments

        """
        apex_method = getattr(self.apex_out, method_name)
        in_args = [60, 15, 100, 500, ev_input]
        with pytest.raises(ValueError) as verr:
            apex_method(*in_args)

        assert "must be (3, N) or (3,) ndarray" in str(verr.value)
        return

    def test_mapping_EV_bad_flag(self):
        """Test _map_EV_to_height raises error for bad data type flag."""
        with pytest.raises(ValueError) as verr:
            self.apex_out._map_EV_to_height(60, 15, 100, 500, [1, 2, 3], "P")

        assert "unknown electric field/drift flag" in str(verr.value)
        return

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 500, [1, 2, 3]],
                               [0.71152183, 2.35624876, 0.57260784]),
                              ([60, 15, 100, 500, [2, 3, 4]],
                               [1.56028502, 3.43916636, 0.78235384]),
                              ([60, 15, 100, 1000, [1, 2, 3]],
                               [0.67796492, 2.08982134, 0.55860785]),
                              ([60, 15, 200, 500, [1, 2, 3]],
                               [0.72377397, 2.42737471, 0.59083726]),
                              ([60, 30, 100, 500, [1, 2, 3]],
                               [0.68626344, 2.37530133, 0.60060124]),
                              ([70, 15, 100, 500, [1, 2, 3]],
                               [0.72760378, 2.18082305, 0.29141979])])
    def test_map_E_to_height_scalar_location(self, in_args, test_mapped):
        """Test mapping of E-field to a specified height.

        Parameters
        ----------
        in_args : list
            List of input arguments
        test_mapped : list
            List of expected outputs

        """
        mapped = self.apex_out.map_E_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return

    @pytest.mark.parametrize('ev_flag, test_mapped',
                             [('E', [0.71152183, 2.35624876, 0.57260784]),
                              ('V', [0.81971957, 2.84512495, 0.69545001])])
    @pytest.mark.parametrize('arr_shape', [(2,), (5,)])
    @pytest.mark.parametrize('ivec', range(0, 5))
    def test_map_EV_to_height_array_location(self, ev_flag, test_mapped,
                                             arr_shape, ivec):
        """Test mapping of E-field/drift to a specified height with arrays.

        Parameters
        ----------
        ev_flag : str
            Character flag specifying whether to run 'E' or 'V' methods
        test_mapped : list
            List of expected outputs
        arr_shape : tuple
            Shape of the expected output
        ivec : int
            Index of the location input argument to vectorize; a value of 4
            leaves all location inputs scalar so only the E/V input is an
            array

        """
        # Set the base input and output values; the E/V input has shape
        # (3,) + arr_shape, with every column equal to [1, 2, 3].
        # `np.prod` replaces `np.product`, which was removed in NumPy 2.0.
        num_vals = int(np.prod(arr_shape))
        eshape = (3,) + tuple(arr_shape)
        edata = np.array([[1, 2, 3]] * num_vals).transpose()
        in_args = [60, 15, 100, 500, edata.reshape(eshape)]

        # Update inputs for one vectorized value if this is a location input
        if ivec < 4:
            in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Get the mapped output
        apex_method = getattr(self.apex_out,
                              "map_{:s}_to_height".format(ev_flag))
        mapped = apex_method(*in_args)

        # Test the results
        for i, test_val in enumerate(test_mapped):
            assert mapped[i].shape == arr_shape
            np.testing.assert_allclose(mapped[i], test_val, rtol=1e-5)
        return

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 500, [1, 2, 3]],
                               [0.81971957, 2.84512495, 0.69545001]),
                              ([60, 15, 100, 500, [2, 3, 4]],
                               [1.83027746, 4.14346436, 0.94764179]),
                              ([60, 15, 100, 1000, [1, 2, 3]],
                               [0.92457698, 3.14997661, 0.85135187]),
                              ([60, 15, 200, 500, [1, 2, 3]],
                               [0.80388262, 2.79321504, 0.68285158]),
                              ([60, 30, 100, 500, [1, 2, 3]],
                               [0.76141245, 2.87884673, 0.73655941]),
                              ([70, 15, 100, 500, [1, 2, 3]],
                               [0.84681866, 2.5925821, 0.34792655])])
    def test_map_V_to_height_scalar_location(self, in_args, test_mapped):
        """Test mapping of velocity to a specified height.

        Parameters
        ----------
        in_args : list
            List of input arguments
        test_mapped : list
            List of expected outputs

        """
        mapped = self.apex_out.map_V_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return
1344
1345
1346
class TestApexBasevectorMethods(object):
    """Test the Apex height base vector methods."""
    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.lat = 60
        self.lon = 15
        self.height = 100
        self.test_basevec = None

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.test_basevec, self.lat, self.lon, self.height

    def _get_geo_location(self, coords, precision):
        """Convert the class test location from `coords` to geodetic lat/lon.

        Parameters
        ----------
        coords : str
            Coordinate system of `self.lat`/`self.lon`, one of 'geo', 'apex',
            or 'qd'
        precision : float
            Precision passed to the `<coords>2geo` conversion method

        Returns
        -------
        glat, glon : float
            Geodetic latitude and longitude of the test location

        """
        if coords == "geo":
            return self.lat, self.lon

        apex_method = getattr(self.apex_out, "{:s}2geo".format(coords))
        glat, glon, _ = apex_method(self.lat, self.lon, self.height,
                                    precision=precision)
        return glat, glon

    @staticmethod
    def _get_bvectors_apex_reference(coords):
        """Get the hard-coded `bvectors_apex` results for the test location.

        These are set results for two identical locations that need to be
        updated with IGRF.

        Parameters
        ----------
        coords : str
            Input coordinate system; 'geo' or 'apex', with any other value
            returning the quasi-dipole ('qd') results

        Returns
        -------
        tuple
            Four arrays of expected `bvectors_apex` output

        """
        if coords == "geo":
            return (np.array([4.42368795e-05, 4.42368795e-05]),
                    np.array([[0.01047826, 0.01047826],
                              [0.33089194, 0.33089194],
                              [-1.04941, -1.04941]]),
                    np.array([5.3564698e-05, 5.3564698e-05]),
                    np.array([[0.00865356, 0.00865356],
                              [0.27327004, 0.27327004],
                              [-0.8666646, -0.8666646]]))

        if coords == "apex":
            return (np.array([4.48672735e-05, 4.48672735e-05]),
                    np.array([[-0.12510721, -0.12510721],
                              [0.28945938, 0.28945938],
                              [-1.1505738, -1.1505738]]),
                    np.array([6.38577444e-05, 6.38577444e-05]),
                    np.array([[-0.08790194, -0.08790194],
                              [0.2033779, 0.2033779],
                              [-0.808408, -0.808408]]))

        return (np.array([4.46348578e-05, 4.46348578e-05]),
                np.array([[-0.12642345, -0.12642345],
                          [0.29695055, 0.29695055],
                          [-1.1517885, -1.1517885]]),
                np.array([6.38626285e-05, 6.38626285e-05]),
                np.array([[-0.08835986, -0.08835986],
                          [0.20754464, 0.20754464],
                          [-0.8050078, -0.8050078]]))

    @staticmethod
    def _drop_uncompared_apex_vectors(basevec):
        """Remove f3, g1, g2, and g3 from `basevectors_apex` output.

        The comparison values from `_geo2apexall` only contain f1, f2, d1,
        d2, d3, e1, e2, and e3, so output indices 2 through 5 are dropped.

        Parameters
        ----------
        basevec : tuple
            Output of `basevectors_apex`

        Returns
        -------
        list
            Base vectors that can be compared to `_geo2apexall` output

        """
        basevec = list(basevec)
        del basevec[2:6]
        return basevec

    def get_comparison_results(self, bv_coord, coords, precision):
        """Get the base vector results using the hidden function for comparison.

        Sets `self.test_basevec` to the comparison values.

        Parameters
        ----------
        bv_coord : str
            Basevector coordinate scheme, expects one of 'apex', 'qd',
            or 'bvectors_apex'
        coords : str
            Expects one of 'geo', 'apex', or 'qd'
        precision : float
            Float specifying precision

        """
        if bv_coord not in ("qd", "apex"):
            # Hard-coded values are used for the `bvectors_apex` method
            self.test_basevec = self._get_bvectors_apex_reference(coords)
            return

        glat, glon = self._get_geo_location(coords, precision)

        if bv_coord == 'qd':
            self.test_basevec = self.apex_out._basevec(glat, glon, self.height)
        else:
            (_, _, _, _, f1, f2, _, d1, d2, d3, _, e1, e2,
             e3) = self.apex_out._geo2apexall(glat, glon, self.height)
            self.test_basevec = (f1, f2, d1, d2, d3, e1, e2, e3)
        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    @pytest.mark.parametrize("coords,precision",
                             [("geo", 1e-10), ("apex", 1.0e-2), ("qd", 1.0e-2)])
    def test_basevectors_scalar(self, bv_coord, coords, precision):
        """Test the base vector calculations with scalars.

        Parameters
        ----------
        bv_coord : str
            Base vector type, used to select the `basevectors_<bv_coord>`
            method
        coords : str
            Name of the input coordinate system
        precision : float
            Level of run precision requested

        """
        # Get the base vectors
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(self.lat, self.lon, self.height, coords=coords,
                              precision=precision)
        self.get_comparison_results(bv_coord, coords, precision)
        if bv_coord == "apex":
            basevec = self._drop_uncompared_apex_vectors(basevec)

        # Test the results
        for i, vec in enumerate(basevec):
            np.testing.assert_allclose(vec, self.test_basevec[i])
        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    def test_basevectors_scalar_shape(self, bv_coord):
        """Test the shape of the scalar output.

        Parameters
        ----------
        bv_coord : str
            Base vector type, used to select the `basevectors_<bv_coord>`
            method

        """
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(self.lat, self.lon, self.height)

        for i, vec in enumerate(basevec):
            if i < 2:
                assert vec.shape == (2,)
            else:
                assert vec.shape == (3,)
        return

    @pytest.mark.parametrize('arr_shape', [(2,), (5,)])
    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    @pytest.mark.parametrize("ivec", range(3))
    def test_basevectors_array(self, arr_shape, bv_coord, ivec):
        """Test the output shape for array inputs.

        Parameters
        ----------
        arr_shape : tuple
            Expected output shape
        bv_coord : str
            Base vector type, used to select the `basevectors_<bv_coord>`
            method
        ivec : int
            Index of the input argument (lat, lon, or height) to vectorize

        """
        # Define the input arguments
        in_args = [self.lat, self.lon, self.height]
        in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Get the basevectors
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(*in_args, coords='geo', precision=1e-10)
        self.get_comparison_results(bv_coord, "geo", 1e-10)
        if bv_coord == "apex":
            basevec = self._drop_uncompared_apex_vectors(basevec)

        # Evaluate the shape and the values
        for i, vec in enumerate(basevec):
            test_shape = (2 if i < 2 else 3,) + tuple(arr_shape)
            assert vec.shape == test_shape
            assert np.all(self.test_basevec[i][0] == vec[0])
            assert np.all(self.test_basevec[i][1] == vec[1])
        return

    @pytest.mark.parametrize("coords", ["geo", "apex", "qd"])
    def test_bvectors_apex(self, coords):
        """Test the bvectors_apex method.

        Parameters
        ----------
        coords : str
            Name of the coordinate system

        """
        in_args = [[self.lat, self.lat], [self.lon, self.lon],
                   [self.height, self.height]]
        self.get_comparison_results("bvectors_apex", coords, 1e-10)

        basevec = self.apex_out.bvectors_apex(*in_args, coords=coords,
                                              precision=1e-10)
        for i, vec in enumerate(basevec):
            np.testing.assert_array_almost_equal(vec, self.test_basevec[i],
                                                 decimal=5)
        return

    def test_basevectors_apex_extra_values(self):
        """Test specific values in the apex base vector output."""
        # Set the testing arrays for f3, g1, g2, and g3
        self.test_basevec = [np.array([0.092637, -0.245951, 0.938848]),
                             np.array([0.939012, 0.073416, -0.07342]),
                             np.array([0.055389, 1.004155, 0.257594]),
                             np.array([0, 0, 1.065135])]

        # Get the desired output
        basevec = self.apex_out.basevectors_apex(0, 15, 100, coords='geo')

        # Test the values not covered by `test_basevectors_scalar`
        for itest, ibase in enumerate(range(2, 6)):
            np.testing.assert_allclose(basevec[ibase],
                                       self.test_basevec[itest], rtol=1e-4)
        return

    @pytest.mark.parametrize("lat", range(0, 90, 10))
    @pytest.mark.parametrize("lon", range(0, 360, 15))
    def test_basevectors_apex_delta(self, lat, lon):
        """Test that vectors are calculated correctly.

        Parameters
        ----------
        lat : int or float
            Latitude in degrees N
        lon : int or float
            Longitude in degrees E

        """
        # Get the apex base vectors and sort them for easy testing
        (f1, f2, f3, g1, g2, g3, d1, d2, d3, e1, e2,
         e3) = self.apex_out.basevectors_apex(lat, lon, 500)
        fvec = [np.append(f1, 0), np.append(f2, 0), f3]
        gvec = [g1, g2, g3]
        dvec = [d1, d2, d3]
        evec = [e1, e2, e3]

        # The f/g and d/e vector sets should satisfy the Kronecker delta
        for idelta, jdelta in [(i, j) for i in range(3) for j in range(3)]:
            delta = 1 if idelta == jdelta else 0
            np.testing.assert_allclose(np.sum(fvec[idelta] * gvec[jdelta]),
                                       delta, rtol=0, atol=1e-5)
            np.testing.assert_allclose(np.sum(dvec[idelta] * evec[jdelta]),
                                       delta, rtol=0, atol=1e-5)
        return

    def test_basevectors_apex_invalid_scalar(self):
        """Test warning and fill values for base vectors with bad inputs."""
        self.apex_out = apexpy.Apex(date=2000, refh=10000)
        invalid = np.full(shape=(3,), fill_value=np.nan)

        # Get the output and the warnings; always record so a warning raised
        # earlier in the session is not suppressed
        with warnings.catch_warnings(record=True) as warn_rec:
            warnings.simplefilter("always")
            basevec = self.apex_out.basevectors_apex(0, 0, 0)

        for i, bvec in enumerate(basevec):
            if i < 2:
                assert not np.allclose(bvec, invalid[:2])
            else:
                np.testing.assert_allclose(bvec, invalid)

        # Other warning types (e.g., numpy RuntimeWarnings) may also be
        # recorded, so look for the expected message among UserWarnings
        user_msgs = [str(wrec.message) for wrec in warn_rec
                     if issubclass(wrec.category, UserWarning)]
        assert any('set to NaN where' in msg for msg in user_msgs)
        return
1602
1603
1604
class TestApexGetMethods(object):
1605
    """Test the Apex `get` methods."""
1606
    def setup_method(self):
        """Create a fresh Apex object before each test.

        The object uses a date of 2000 and a reference height of 300.

        """
        self.apex_out = apexpy.Apex(refh=300, date=2000)
        return
1609
1610
    def teardown_method(self):
        """Drop the `apex_out` attribute so no state leaks between tests."""
        delattr(self, 'apex_out')
        return
1613
1614
    @pytest.mark.parametrize("alat, aheight",
                             [(10, 507.409702543805),
                              (60, 20313.026999999987),
                              ([10, 60],
                               [507.409702543805, 20313.026999999987]),
                              ([[10], [60]],
                               [[507.409702543805], [20313.026999999987]])])
    def test_get_apex(self, alat, aheight):
        """Test the apex height retrieval results.

        Parameters
        ----------
        alat : int, float, or array-like
            Apex latitude(s) in degrees N
        aheight : float or array-like
            Expected apex height(s) in km, shaped like `alat`

        """
        np.testing.assert_allclose(self.apex_out.get_apex(alat), aheight)
        return
1635
1636
    @pytest.mark.parametrize("glat,glon,height,test_bmag",
                             [([80], [100], [300], 5.100682377815247e-05),
                              ([80, 80], [100], [300],
                               [5.100682377815247e-05, 5.100682377815247e-05]),
                              ([[80], [80]], [100], [300],
                               [[5.100682377815247e-05],
                                [5.100682377815247e-05]]),
                              (range(50, 90, 8), range(0, 360, 80), [300] * 5,
                               np.array([4.18657154e-05, 5.11118114e-05,
                                         4.91969854e-05, 5.10519207e-05,
                                         4.90054816e-05])),
                              (90.0, 0, 1000, 3.7834718823432923e-05)])
    def test_get_babs(self, glat, glon, height, test_bmag):
        """Test the method to get the magnitude of the magnetic field.

        Parameters
        ----------
        glat : float or array-like
            Latitude(s) in degrees N
        glon : float or array-like
            Longitude(s) in degrees E
        height : float or array-like
            Height(s) in km
        test_bmag : float or array-like
            Expected B field magnitude(s)

        """
        bmag = self.apex_out.get_babs(glat, glon, height)

        # The expected magnitudes are of order 1e-5, so the former absolute
        # tolerance of 1e-5 accepted errors of ~20%.  Compare relatively.
        np.testing.assert_allclose(bmag, test_bmag, rtol=1e-5, atol=0)
        return
1666
1667
    @pytest.mark.parametrize("bad_lat", [(91), (-91)])
    def test_get_apex_with_invalid_lat(self, bad_lat):
        """Test `get_apex` raises ValueError for invalid latitudes.

        Parameters
        ----------
        bad_lat : int or float
            Bad input latitude in degrees N

        """
        with pytest.raises(ValueError) as verr:
            self.apex_out.get_apex(bad_lat)

        # `str.find(...) > 0` would wrongly fail if the message began with the
        # expected text (index 0), so test for containment instead
        assert "must be in [-90, 90]" in str(verr.value)
        return
1683
1684
    @pytest.mark.parametrize("bad_lat", [(91), (-91)])
    def test_get_babs_with_invalid_lat(self, bad_lat):
        """Test `get_babs` raises ValueError for invalid latitudes.

        Parameters
        ----------
        bad_lat : int or float
            Bad input latitude in degrees N

        """
        with pytest.raises(ValueError) as verr:
            self.apex_out.get_babs(bad_lat, 15, 100)

        # `str.find(...) > 0` would wrongly fail if the message began with the
        # expected text (index 0), so test for containment instead
        assert "must be in [-90, 90]" in str(verr.value)
        return
1700
1701
    @pytest.mark.parametrize("bound_lat", [(90), (-90)])
    def test_get_at_lat_boundary(self, bound_lat):
        """Test `get_apex` slightly beyond the latitude boundary.

        A latitude that overshoots the boundary by 1e-5 degrees must still be
        accepted and must match the result at the boundary to within 1e-8.

        Parameters
        ----------
        bound_lat : int or float
            Boundary input latitude in degrees N

        """
        # Move 1e-5 degrees further away from the equator than the boundary
        direction = np.sign(bound_lat)
        excess_lat = direction * (abs(bound_lat) + 1.0e-5)

        # Evaluate at the boundary first, then just past it
        bound_out, excess_out = [self.apex_out.get_apex(test_lat)
                                 for test_lat in (bound_lat, excess_lat)]

        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return
1721
1722
    @pytest.mark.parametrize("apex_height", [-100, 0, 300, 10000])
    def test_get_height_at_equator(self, apex_height):
        """Test that `get_height` at zero latitude returns the apex height.

        Parameters
        ----------
        apex_height : int or float
            Apex height passed in, expected back exactly

        """
        equator_height = self.apex_out.get_height(0.0, apex_height)
        assert equator_height == apex_height
        return
1735
1736
    @pytest.mark.parametrize("lat, height", [
        (-90, -6371.009),
        (-80, -6088.438503309167),
        (-70, -5274.8091854339655),
        (-60, -4028.256749999999),
        (-50, -2499.1338178752017),
        (-40, -871.8751821247979),
        (-30, 657.2477500000014),
        (-20, 1903.8001854339655),
        (-10, 2717.4295033091657),
        (0, 3000.0),
        (10, 2717.4295033091657),
        (20, 1903.8001854339655),
        (30, 657.2477500000014),
        (40, -871.8751821247979),
        (50, -2499.1338178752017),
        (60, -4028.256749999999),
        (70, -5274.8091854339655),
        (80, -6088.438503309167)])
    def test_get_height_along_fieldline(self, lat, height):
        """Test `get_height` along a field line whose apex is at 3000 km.

        Parameters
        ----------
        lat : int or float
            Input latitude in degrees
        height : float
            Expected field-line height at `lat` for an apex of 3000 km

        """
        apex_alt = 3000.0
        fheight = self.apex_out.get_height(lat, apex_alt)
        diff = abs(height - fheight)
        assert diff < 1.0e-7, \
            "bad height calculation: {:.7f} != {:.7f}".format(height, fheight)
        return
1761