Passed
Pull Request — main (#103)
by Angeline
01:30
created

TestApexBasevectorMethods.get_comparison_results()   B

Complexity

Conditions 6

Size

Total Lines 62
Code Lines 43

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 43
dl 0
loc 62
rs 7.9146
c 0
b 0
f 0
cc 6
nop 4

How to fix: Long Method

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

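Commonly applied refactorings include Extract Method; if many parameters or temporary variables are involved, Replace Method with Method Object or Introduce Parameter Object are also common choices.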

1
# -*- coding: utf-8 -*-
2
"""Test the apexpy.Apex class
3
4
Notes
5
-----
6
Whenever function outputs are tested against hard-coded numbers, the test
7
results (numbers) were obtained by running the code that is tested.  Therefore,
8
these tests below only check that nothing changes when refactoring, etc., and
9
not if the results are actually correct.
10
11
These results are expected to change when IGRF is updated.
12
13
"""
14
15
import datetime as dt
16
import numpy as np
17
import os
18
import pytest
19
import shutil
20
import warnings
21
22
import apexpy
23
24
25
def _move_with_retries(src, dst, max_attempts):
    """Move a file, retrying failed attempts.

    Parameters
    ----------
    src : str
        Path of the file to move
    dst : str
        Destination path
    max_attempts : int
        Maximum number of move attempts (retries are needed for Windows)

    Raises
    ------
    OSError
        If the file could not be moved within `max_attempts` attempts; the
        error from the final attempt is chained as the cause.

    """
    last_err = None
    for _ in range(max_attempts):
        try:
            shutil.move(src, dst)
            return
        except OSError as err:
            # shutil.move reports failures as OSError (incl. shutil.Error)
            last_err = err

    raise OSError("unable to move {:} to {:} in {:} attempts".format(
        src, dst, max_attempts)) from last_err


@pytest.fixture()
def igrf_file(max_attempts=100):
    """A fixture for handling the coefficient file.

    Moves the IGRF coefficient file shipped with apexpy out of place, yields
    its original path, then moves it back after the test.

    Parameters
    ----------
    max_attempts : int
        Maximum rename attempts, needed for Windows (default=100)

    Yields
    ------
    original_file : str
        Path where the coefficient file is normally located

    Raises
    ------
    OSError
        If the coefficient file cannot be moved away or restored.  A failed
        restore leaves the file at 'temp_coeff.txt' in the current working
        directory.

    """
    # Ensure the coefficient file exists
    original_file = os.path.join(os.path.dirname(apexpy.helpers.__file__),
                                 'igrf13coeffs.txt')
    tmp_file = "temp_coeff.txt"
    assert os.path.isfile(original_file)

    # Move the coefficient file; fail loudly rather than silently running
    # the test with the file still in place
    _move_with_retries(original_file, tmp_file, max_attempts)

    yield original_file

    # Move the coefficient file back; a silent failure here would leave the
    # installed package without its coefficient file
    _move_with_retries(tmp_file, original_file, max_attempts)
    return
58
59
60
def test_set_epoch_file_error(igrf_file):
    """Ensure Apex initialization fails when the IGRF file is missing.

    Parameters
    ----------
    igrf_file : str
        Original coefficient file path, moved away by the `igrf_file` fixture

    """
    expected_start = "File {0} does not exist".format(igrf_file)

    # With the coefficient file moved away, initialization must raise
    with pytest.raises(OSError) as raised:
        apexpy.Apex()

    assert str(raised.value).startswith(expected_start)
    return
68
69
70
class TestApexInit(object):
    """Test class for the Apex class object."""

    def setup_method(self):
        """Initialize the default inputs shared by all tests."""
        self.apex_out = None
        self.test_date = dt.datetime.utcnow()
        self.test_refh = 0
        self.bad_file = 'foo/path/to/datafile.blah'

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.test_date, self.test_refh, self.bad_file

    def eval_date(self):
        """Evaluate the times in self.test_date and self.apex_out."""
        # dt.datetime is a subclass of dt.date, so one check covers both
        if isinstance(self.test_date, dt.date):
            self.test_date = apexpy.helpers.toYearFraction(self.test_date)

        # Assert the times are the same on the order of tens of seconds.
        # Necessary to evaluate the current UTC
        np.testing.assert_almost_equal(self.test_date, self.apex_out.year, 6)
        return

    def eval_refh(self):
        """Assert self.test_refh matches the Apex reference height."""
        eval_str = ("expected reference height [{:}] not equal to Apex "
                    "reference height [{:}]").format(self.test_refh,
                                                     self.apex_out.refh)
        assert self.test_refh == self.apex_out.refh, eval_str
        return

    def test_init_defaults(self):
        """Test Apex class default initialization."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("in_date",
                             [2015, 2015.5, dt.date(2015, 1, 1),
                              dt.datetime(2015, 6, 1, 18, 23, 45)])
    def test_init_date(self, in_date):
        """Test Apex class with date initialization.

        Parameters
        ----------
        in_date : int, float, dt.date, or dt.datetime
            Input date in a variety of formats

        """
        self.test_date = in_date
        self.apex_out = apexpy.Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("new_date", [2015, 2015.5])
    def test_set_epoch(self, new_date):
        """Test successful setting of Apex epoch after initialization.

        Parameters
        ----------
        new_date : int or float
            New date for the Apex class

        """
        # Evaluate the default initialization
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Update the epoch, keeping a copy rebuilt from the repr string
        # (which evaluates because `apexpy` is imported in this module)
        ref_apex = eval(repr(self.apex_out))
        self.apex_out.set_epoch(new_date)
        assert ref_apex != self.apex_out
        self.test_date = new_date
        self.eval_date()
        return

    @pytest.mark.parametrize("in_refh", [0.0, 300.0, 30000.0, -1.0])
    def test_init_refh(self, in_refh):
        """Test Apex class with reference height initialization.

        Parameters
        ----------
        in_refh : float
            Input reference height in km

        """
        self.test_refh = in_refh
        self.apex_out = apexpy.Apex(refh=self.test_refh)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("new_refh", [0.0, 300.0, 30000.0, -1.0])
    def test_set_refh(self, new_refh):
        """Test the method used to set the reference height after the init.

        Parameters
        ----------
        new_refh : float
            Reference height in km

        """
        # Verify the defaults are set
        self.apex_out = apexpy.Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()

        # Update to a new reference height and test
        ref_apex = eval(repr(self.apex_out))
        self.apex_out.set_refh(new_refh)

        if self.test_refh == new_refh:
            assert ref_apex == self.apex_out
        else:
            assert ref_apex != self.apex_out
            self.test_refh = new_refh
        self.eval_refh()
        return

    def test_init_with_bad_datafile(self):
        """Test raises IOError with non-existent datafile input."""
        with pytest.raises(IOError) as oerr:
            apexpy.Apex(datafile=self.bad_file)
        assert str(oerr.value).startswith('Data file does not exist')
        return

    def test_init_with_bad_fortranlib(self):
        """Test raises IOError with non-existent fortranlib input."""
        with pytest.raises(IOError) as oerr:
            apexpy.Apex(fortranlib=self.bad_file)
        assert str(oerr.value).startswith('Fortran library does not exist')
        return

    def test_repr_eval(self):
        """Test the Apex.__repr__ results."""
        # Initialize the apex object
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Get and test the repr string
        out_str = repr(self.apex_out)
        assert out_str.startswith("apexpy.Apex(")

        # Test the ability to re-create the apex object from the repr string
        new_apex = eval(out_str)
        assert new_apex == self.apex_out
        return

    def test_ne_other_class(self):
        """Test Apex class inequality to a different class."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        assert self.apex_out != self.test_date
        return

    def test_ne_missing_attr(self):
        """Test Apex class inequality when attributes are missing from one."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        ref_apex = eval(repr(self.apex_out))
        del ref_apex.RE

        assert ref_apex != self.apex_out
        assert self.apex_out != ref_apex
        return

    def test_eq_missing_attr(self):
        """Test Apex class equality when attributes are missing from both."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        ref_apex = eval(repr(self.apex_out))
        del ref_apex.RE, self.apex_out.RE

        assert ref_apex == self.apex_out
        return

    def test_str_eval(self):
        """Test the Apex.__str__ results."""
        # Initialize the apex object
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Get and test the printed string
        out_str = str(self.apex_out)
        assert out_str.find("Decimal year") > 0
        return
266
267
268
class TestApexMethod(object):
    """Test the Apex methods."""

    # Apex method, equivalent Fortran function, and the slice selecting the
    # comparable Fortran outputs; shared by the Fortran interface tests
    _fortran_comparisons = [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                            ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                            ("_qd2geo", "apxq2g", slice(None)),
                            ("_basevec", "apxg2q", slice(2, 4, 1))]

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.in_lat = 60
        self.in_lon = 15
        self.in_alt = 100

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.in_lat, self.in_lon, self.in_alt

    def get_input_args(self, method_name, precision=0.0):
        """Set the input arguments for the different Apex methods.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method
        precision : float
            Value for the precision (default=0.0)

        Returns
        -------
        in_args : list
            List of the appropriate input arguments

        """
        in_args = [self.in_lat, self.in_lon, self.in_alt]

        # Add precision, if needed
        if method_name in ("_qd2geo", "apxq2g", "apex2geo", "qd2geo",
                           "_apex2geo"):
            in_args.append(precision)

        # Add the reference height of the Apex object under test, if needed
        if method_name == "apxg2all":
            in_args.append(self.apex_out.refh)

        # Add a vector flag, if needed
        if method_name in ("apxg2all", "apxg2q"):
            in_args.append(1)

        return in_args

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             _fortran_comparisons)
    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon", [-179, -90, 0, 90, 180])
    def test_fortran_scalar_input(self, apex_method, fortran_method, fslice,
                                  lat, lon):
        """Tests Apex/fortran interface consistency for scalars.

        Parameters
        ----------
        apex_method : str
            Name of the Apex class method to test
        fortran_method : str
            Name of the Fortran function to test
        fslice : slice
            Slice used to select the appropriate Fortran outputs
        lat : int or float
            Latitude in degrees N
        lon : int or float
            Longitude in degrees E

        """
        # Set the input coordinates
        self.in_lat = lat
        self.in_lon = lon

        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Get the appropriate input arguments
        apex_args = self.get_input_args(apex_method)
        fortran_args = self.get_input_args(fortran_method)

        # Evaluate the equivalent function calls
        np.testing.assert_allclose(apex_func(*apex_args),
                                   fortran_func(*fortran_args)[fslice])
        return

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             _fortran_comparisons)
    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon1,lon2", [(180, 180), (-180, -180),
                                           (180, -180), (-180, 180),
                                           (-345, 15), (375, 15)])
    def test_fortran_longitude_rollover(self, apex_method, fortran_method,
                                        fslice, lat, lon1, lon2):
        """Tests Apex/fortran interface consistency for longitude rollover.

        Parameters
        ----------
        apex_method : str
            Name of the Apex class method to test
        fortran_method : str
            Name of the Fortran function to test
        fslice : slice
            Slice used to select the appropriate Fortran outputs
        lat : int or float
            Latitude in degrees N
        lon1 : int or float
            Longitude in degrees E
        lon2 : int or float
            Equivalent longitude in degrees E

        """
        # Set the fixed input coordinate
        self.in_lat = lat

        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Get the appropriate input arguments
        self.in_lon = lon1
        apex_args = self.get_input_args(apex_method)

        self.in_lon = lon2
        fortran_args = self.get_input_args(fortran_method)

        # Evaluate the equivalent function calls
        np.testing.assert_allclose(apex_func(*apex_args),
                                   fortran_func(*fortran_args)[fslice])
        return

    @pytest.mark.parametrize("arr_shape", [(2, 2), (4,), (1, 4)])
    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             _fortran_comparisons)
    def test_fortran_array_input(self, arr_shape, apex_method, fortran_method,
                                 fslice):
        """Tests Apex/fortran interface consistency for array input.

        Parameters
        ----------
        arr_shape : tuple
            Expected output shape
        apex_method : str
            Name of the Apex class method to test
        fortran_method : str
            Name of the Fortran function to test
        fslice : slice
            Slice used to select the appropriate Fortran outputs

        """
        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Set up the input arrays
        ref_lat = np.array([0, 30, 60, 90])
        ref_alt = np.array([100, 200, 300, 400])
        self.in_lat = ref_lat.reshape(arr_shape)
        self.in_alt = ref_alt.reshape(arr_shape)
        apex_args = self.get_input_args(apex_method)

        # Get the Apex class results
        aret = apex_func(*apex_args)

        # Get the fortran function results, one lat/alt pair at a time
        flats = list()
        flons = list()
        for lat, alt in zip(ref_lat, ref_alt):
            self.in_lat = lat
            self.in_alt = alt
            fortran_args = self.get_input_args(fortran_method)
            fret = fortran_func(*fortran_args)[fslice]
            flats.append(fret[0])
            flons.append(fret[1])

        flats = np.array(flats)
        flons = np.array(flons)

        # Evaluate results
        try:
            # This returned value is array of floats
            np.testing.assert_allclose(aret[0].astype(float),
                                       flats.reshape(arr_shape).astype(float))
            np.testing.assert_allclose(aret[1].astype(float),
                                       flons.reshape(arr_shape).astype(float))
        except ValueError:
            # This returned value is array of arrays
            alats = aret[0].reshape(ref_lat.shape)
            alons = aret[1].reshape(ref_lat.shape)
            for alat, alon, flat, flon in zip(alats, alons, flats, flons):
                np.testing.assert_array_almost_equal(alat, flat, 2)
                np.testing.assert_array_almost_equal(alon, flon, 2)

        return

    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon", [-179, -90, 0, 90, 180])
    def test_geo2apexall_scalar(self, lat, lon):
        """Test Apex/fortran geo2apexall interface consistency for scalars.

        Parameters
        ----------
        lat : int or float
            Latitude in degrees N
        lon : int or float
            Longitude in degrees E

        """
        # Get the Apex and Fortran results
        aret = self.apex_out._geo2apexall(lat, lon, self.in_alt)
        fret = apexpy.fortranapex.apxg2all(lat, lon, self.in_alt,
                                           self.apex_out.refh, 1)

        # Evaluate each element in the results
        for aval, fval in zip(aret, fret):
            np.testing.assert_allclose(aval, fval)
        return

    @pytest.mark.parametrize("arr_shape", [(2, 2), (4,), (1, 4)])
    def test_geo2apexall_array(self, arr_shape):
        """Test Apex/fortran geo2apexall interface consistency for arrays.

        Parameters
        ----------
        arr_shape : tuple
            Expected output shape

        """
        # Set the input
        self.in_lat = np.array([0, 30, 60, 90])
        self.in_alt = np.array([100, 200, 300, 400])

        # Get the Apex class results
        aret = self.apex_out._geo2apexall(self.in_lat.reshape(arr_shape),
                                          self.in_lon,
                                          self.in_alt.reshape(arr_shape))

        # For each lat/alt pair, get the Fortran results
        fret = [apexpy.fortranapex.apxg2all(lat, self.in_lon, alt,
                                            self.apex_out.refh, 1)
                for lat, alt in zip(self.in_lat, self.in_alt)]

        # Cycle through all returned values
        for i, ret in enumerate(aret):
            try:
                # This returned value is array of floats
                fret_test = np.array([single_fret[i] for single_fret
                                      in fret]).reshape(arr_shape)
                np.testing.assert_allclose(ret.astype(float),
                                           fret_test.astype(float))
            except ValueError:
                # This returned value is array of arrays
                for aval, single_fret in zip(ret.reshape((len(fret),)), fret):
                    np.testing.assert_allclose(aval, single_fret[i])
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_consistency(self, in_coord, out_coord):
        """Test the self-consistency of the Apex convert method.

        Parameters
        ----------
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Define the method name
        method_name = "2".join([in_coord, out_coord])

        # Get the method and method inputs
        convert_kwargs = {'height': self.in_alt, 'precision': 0.0}
        apex_args = self.get_input_args(method_name)
        apex_method = getattr(self.apex_out, method_name)

        # Define the slice needed to get equivalent output from the named method
        mslice = slice(0, -1, 1) if out_coord == "geo" else slice(None)

        # Get output using convert and named method
        convert_out = self.apex_out.convert(self.in_lat, self.in_lon, in_coord,
                                            out_coord, **convert_kwargs)
        method_out = apex_method(*apex_args)[mslice]

        # Compare both outputs, should be identical
        np.testing.assert_allclose(convert_out, method_out)
        return

    @pytest.mark.parametrize("bound_lat", [90, -90])
    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_at_lat_boundary(self, bound_lat, in_coord, out_coord):
        """Test the conversion at the latitude boundary, with allowed excess.

        Parameters
        ----------
        bound_lat : int or float
            Boundary latitude in degrees N
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system

        """
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = self.apex_out.convert(bound_lat, 0, in_coord, out_coord)
        excess_out = self.apex_out.convert(excess_lat, 0, in_coord, out_coord)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    def test_convert_qd2apex_at_equator(self):
        """Test the quasi-dipole to apex conversion at the magnetic equator."""
        eq_out = self.apex_out.convert(lat=0.0, lon=0, source='qd', dest='apex',
                                       height=320.0)
        close_out = self.apex_out.convert(lat=0.001, lon=0, source='qd',
                                          dest='apex', height=320.0)
        np.testing.assert_allclose(eq_out, close_out, atol=1e-4)
        return

    @pytest.mark.parametrize("src", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("dest", ["geo", "apex", "qd"])
    def test_convert_withnan(self, src, dest):
        """Test Apex.convert success with NaN input.

        Parameters
        ----------
        src : str
            Input coordinate system
        dest : str
            Output coordinate system

        """
        if src == dest:
            pytest.skip("Test not needed for same src and dest coordinates")

        num_nans = 5
        in_loc = np.arange(0, 10, dtype=float)
        in_loc[:num_nans] = np.nan

        out_loc = self.apex_out.convert(in_loc, in_loc, src, dest, height=320)

        for out in out_loc:
            assert np.all(np.isnan(out[:num_nans])), "NaN output expected"
            assert np.all(np.isfinite(out[num_nans:])), "Finite output expected"

        return

    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_convert_invalid_lat(self, bad_lat):
        """Test convert raises ValueError for invalid latitudes.

        Parameters
        ----------
        bad_lat : int or float
            Latitude outside the supported range in degrees N

        """

        with pytest.raises(ValueError) as verr:
            self.apex_out.convert(bad_lat, 0, 'geo', 'geo')

        assert "must be in [-90, 90]" in str(verr.value)
        return

    @pytest.mark.parametrize("coords", [("foobar", "geo"), ("geo", "foobar"),
                                        ("geo", "mlt")])
    def test_convert_invalid_transformation(self, coords):
        """Test raises ValueError for bad coordinates.

        Parameters
        ----------
        coords : tuple
            Tuple specifying the input and output coordinate systems

        """
        if "mlt" in coords:
            estr = "datetime must be given for MLT calculations"
        else:
            estr = "Unknown coordinate transformation"

        with pytest.raises(ValueError) as verr:
            self.apex_out.convert(0, 0, *coords)

        # Check the exception message itself, not the ExceptionInfo wrapper
        assert estr in str(verr.value)
        return

    @pytest.mark.parametrize("method_name, out_comp",
                             [("geo2apex",
                               (55.94841766357422, 94.10684204101562)),
                              ("apex2geo",
                               (51.476322174072266, -66.22817993164062,
                                5.727287771151168e-06)),
                              ("geo2qd",
                               (56.531288146972656, 94.10684204101562)),
                              ("apex2qd", (60.498401178276744, 15.0)),
                              ("qd2apex", (59.49138097045895, 15.0))])
    def test_method_scalar_input(self, method_name, out_comp):
        """Test the user method against set values with scalars.

        Parameters
        ----------
        method_name : str
            Apex class method to be tested
        out_comp : tuple of floats
            Expected output values

        """
        # Get the desired methods
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(self.in_lat, self.in_lon, self.in_alt)

        # Evaluate the user output
        np.testing.assert_allclose(user_out, out_comp, rtol=1e-5, atol=1e-5)

        for out_val in user_out:
            assert np.asarray(out_val).shape == (), "output is not a scalar"
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("method_args, out_shape",
                             [([[60, 60], 15, 100], (2,)),
                              ([60, [15, 15], 100], (2,)),
                              ([60, 15, [100, 100]], (2,)),
                              ([[50, 60], [15, 16], [100, 200]], (2,))])
    def test_method_broadcast_input(self, in_coord, out_coord, method_args,
                                    out_shape):
        """Test the user method with inputs that require some broadcasting.

        Parameters
        ----------
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system
        method_args : list
            List of input arguments
        out_shape : tuple
            Expected shape of output values

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(*method_args)

        # Evaluate the user output
        for out_val in user_out:
            assert hasattr(out_val, 'shape'), "output coordinate isn't np.array"
            assert out_val.shape == out_shape
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_method_invalid_lat(self, in_coord, out_coord, bad_lat):
        """Test user methods raise ValueError for invalid latitudes.

        Parameters
        ----------
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system
        bad_lat : int
            Latitude in degrees N that is out of bounds

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        with pytest.raises(ValueError) as verr:
            user_method(bad_lat, 15, 100)

        assert "must be in [-90, 90]" in str(verr.value)
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("bound_lat", [90, -90])
    def test_method_at_lat_boundary(self, in_coord, out_coord, bound_lat):
        """Test user methods at the latitude boundary, with allowed excess.

        Parameters
        ----------
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system
        bound_lat : int
            Latitude in degrees N that is at the limits of the boundary

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        # Get a latitude just beyond the limit
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = user_method(bound_lat, 0, 100)
        excess_out = user_method(excess_lat, 0, 100)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    def test_geo2apex_undefined_warning(self):
        """Test geo2apex warning and fill values for an undefined location."""

        # Update the apex object
        self.apex_out = apexpy.Apex(date=2000, refh=10000)

        # Get the output and the warnings
        with warnings.catch_warnings(record=True) as warn_rec:
            user_lat, user_lon = self.apex_out.geo2apex(0, 0, 0)

        assert np.isnan(user_lat)
        assert np.isfinite(user_lon)
        assert len(warn_rec) == 1
        assert issubclass(warn_rec[-1].category, UserWarning)
        assert 'latitude set to NaN where' in str(warn_rec[-1].message)
        return

    @pytest.mark.parametrize("method_name", ["apex2qd", "qd2apex"])
    @pytest.mark.parametrize("delta_h", [1.0e-6, -1.0e-6])
    def test_quasidipole_apexheight_close(self, method_name, delta_h):
        """Test quasi-dipole success with a height close to the reference.

        Parameters
        ----------
        method_name : str
            Apex class method name to be tested
        delta_h : float
            Offset from the reference height in km

        """
        qd_method = getattr(self.apex_out, method_name)
        in_args = [0, 15, self.apex_out.refh + delta_h]
        out_coords = qd_method(*in_args)

        # Each output coordinate should match the corresponding input
        for out_val, in_val in zip(out_coords, in_args):
            np.testing.assert_almost_equal(out_val, in_val, decimal=3)
        return

    @pytest.mark.parametrize("method_name, hinc, msg",
                             [("apex2qd", 1.0, "is > apex height"),
                              ("qd2apex", -1.0, "is < reference height")])
    def test_quasidipole_raises_apexheight(self, method_name, hinc, msg):
        """Test quasi-dipole methods raise ApexHeightError for bad heights.

        Parameters
        ----------
        method_name : str
            Apex class method name to be tested
        hinc : float
            Height increment in km relative to the reference height
        msg : str
            Expected output message

        """
        qd_method = getattr(self.apex_out, method_name)

        with pytest.raises(apexpy.ApexHeightError) as aerr:
            qd_method(0, 15, self.apex_out.refh + hinc)

        # Check the exception message itself, not the ExceptionInfo wrapper
        assert msg in str(aerr.value)
        return
868
869
870
class TestApexMLTMethods(object):
    """Test the Apex Magnetic Local Time (MLT) methods."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.in_time = dt.datetime(2000, 2, 3, 4, 5, 6)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.in_time

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    def test_convert_to_mlt(self, in_coord):
        """Test the conversions to MLT using Apex convert.

        Parameters
        ----------
        in_coord : str
            Input coordinate system

        """
        # Get the magnetic longitude.  Only geographic input is converted;
        # for apex and quasi-dipole input the test uses the input longitude
        # directly.
        if in_coord == "geo":
            mlon = self.apex_out.geo2apex(60, 15, 100)[1]
        else:
            mlon = 15

        # Get the output MLT values
        convert_mlt = self.apex_out.convert(60, 15, in_coord, 'mlt',
                                            height=100, ssheight=2e5,
                                            datetime=self.in_time)[1]
        method_mlt = self.apex_out.mlon2mlt(mlon, self.in_time, ssheight=2e5)

        # Test the outputs
        np.testing.assert_allclose(convert_mlt, method_mlt)
        return

    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_mlt_to_lon(self, out_coord):
        """Test the conversions from MLT using Apex convert.

        Parameters
        ----------
        out_coord : str
            Output coordinate system

        """
        # Get the output longitudes
        convert_out = self.apex_out.convert(60, 15, 'mlt', out_coord,
                                            height=100, ssheight=2e5,
                                            datetime=self.in_time,
                                            precision=1e-2)
        mlon = self.apex_out.mlt2mlon(15, self.in_time, ssheight=2e5)

        # Build the expected output; the last `apex2geo` output is dropped so
        # only latitude and longitude are compared
        if out_coord == "geo":
            method_out = self.apex_out.apex2geo(60, mlon, 100,
                                                precision=1e-2)[:-1]
        elif out_coord == "qd":
            method_out = self.apex_out.apex2qd(60, mlon, 100)
        else:
            method_out = (60, mlon)

        # Evaluate the outputs
        np.testing.assert_allclose(convert_out, method_out)
        return

    def test_convert_geo2mlt_nodate(self):
        """Test convert from geo to MLT raises ValueError with no datetime."""
        with pytest.raises(ValueError):
            self.apex_out.convert(60, 15, 'geo', 'mlt')
        return

    @pytest.mark.parametrize("mlon_kwargs,test_mlt",
                             [({}, 23.019629923502603),
                              ({"ssheight": 100000}, 23.026712036132814)])
    def test_mlon2mlt_scalar_inputs(self, mlon_kwargs, test_mlt):
        """Test mlon2mlt with scalar inputs.

        Parameters
        ----------
        mlon_kwargs : dict
            Input kwargs
        test_mlt : float
            Expected output MLT in hours

        """
        mlt = self.apex_out.mlon2mlt(0, self.in_time, **mlon_kwargs)

        np.testing.assert_allclose(mlt, test_mlt)
        assert np.asarray(mlt).shape == ()
        return

    @pytest.mark.parametrize("mlt_kwargs,test_mlon",
                             [({}, 14.705535888671875),
                              ({"ssheight": 100000}, 14.599319458007812)])
    def test_mlt2mlon_scalar_inputs(self, mlt_kwargs, test_mlon):
        """Test mlt2mlon with scalar inputs.

        Parameters
        ----------
        mlt_kwargs : dict
            Input kwargs
        test_mlon : float
            Expected output longitude in degrees E

        """
        mlon = self.apex_out.mlt2mlon(0, self.in_time, **mlt_kwargs)

        np.testing.assert_allclose(mlon, test_mlon)
        assert np.asarray(mlon).shape == ()
        return

    @pytest.mark.parametrize("mlon,test_mlt",
                             [([0, 180], [23.019261, 11.019261]),
                              (np.array([0, 180]), [23.019261, 11.019261]),
                              (np.array([[0], [180]]),
                               np.array([[23.019261], [11.019261]])),
                              ([[0, 180], [0, 180]], [[23.019261, 11.019261],
                                                      [23.019261, 11.019261]]),
                              (range(0, 361, 30),
                               [23.01963, 1.01963, 3.01963, 5.01963, 7.01963,
                                9.01963, 11.01963, 13.01963, 15.01963, 17.01963,
                                19.01963, 21.01963, 23.01963])])
    def test_mlon2mlt_array(self, mlon, test_mlt):
        """Test mlon2mlt with array inputs.

        Parameters
        ----------
        mlon : array-like
            Input longitudes in degrees E
        test_mlt : array-like
            Expected output MLT in hours, with the same shape as `mlon`

        """
        mlt = self.apex_out.mlon2mlt(mlon, self.in_time)

        assert mlt.shape == np.asarray(test_mlt).shape
        np.testing.assert_allclose(mlt, test_mlt, rtol=1e-4)
        return

    @pytest.mark.parametrize("mlt,test_mlon",
                             [([0, 12], [14.705551, 194.705551]),
                              (np.array([0, 12]), [14.705551, 194.705551]),
                              (np.array([[0], [12]]),
                               np.array([[14.705551], [194.705551]])),
                              ([[0, 12], [0, 12]], [[14.705551, 194.705551],
                                                    [14.705551, 194.705551]]),
                              (range(0, 25, 2),
                               [14.705551, 44.705551, 74.705551, 104.705551,
                                134.705551, 164.705551, 194.705551, 224.705551,
                                254.705551, 284.705551, 314.705551, 344.705551,
                                14.705551])])
    def test_mlt2mlon_array(self, mlt, test_mlon):
        """Test mlt2mlon with array inputs.

        Parameters
        ----------
        mlt : array-like
            Input MLT in hours
        test_mlon : array-like
            Expected output longitude in degrees E, with the same shape as
            `mlt`

        """
        mlon = self.apex_out.mlt2mlon(mlt, self.in_time)

        assert mlon.shape == np.asarray(test_mlon).shape
        np.testing.assert_allclose(mlon, test_mlon, rtol=1e-4)
        return

    @pytest.mark.parametrize("method_name", ["mlon2mlt", "mlt2mlon"])
    def test_mlon2mlt_diffdates(self, method_name):
        """Test that the MLT conversions vary with universal time.

        Parameters
        ----------
        method_name : str
            Name of Apex class method to be tested

        """
        apex_method = getattr(self.apex_out, method_name)

        # The outputs are MLT or magnetic longitude, depending on the method
        out1 = apex_method(0, self.in_time)
        out2 = apex_method(0, self.in_time + dt.timedelta(hours=1))

        assert out1 != out2
        return

    @pytest.mark.parametrize("mlt_offset", [1.0, 10.0])
    def test_mlon2mlt_offset(self, mlt_offset):
        """Test the time wrapping logic for the MLT.

        Parameters
        ----------
        mlt_offset : float
            MLT offset in hours

        """
        # Shifting longitude by -15 deg per hour should shift MLT by -1 hour
        mlt1 = self.apex_out.mlon2mlt(0.0, self.in_time)
        mlt2 = self.apex_out.mlon2mlt(-15.0 * mlt_offset,
                                      self.in_time) + mlt_offset

        np.testing.assert_allclose(mlt1, mlt2)
        return

    @pytest.mark.parametrize("mlon_offset", [15.0, 150.0])
    def test_mlt2mlon_offset(self, mlon_offset):
        """Test the time wrapping logic for the magnetic longitude.

        Parameters
        ----------
        mlon_offset : float
            Magnetic longitude offset in degrees

        """
        # Shifting MLT by 1 hour per 15 deg should shift longitude by 15 deg
        mlon1 = self.apex_out.mlt2mlon(0, self.in_time)
        mlon2 = self.apex_out.mlt2mlon(mlon_offset / 15.0,
                                       self.in_time) - mlon_offset

        np.testing.assert_allclose(mlon1, mlon2)
        return

    @pytest.mark.parametrize("order", [["mlt", "mlon"], ["mlon", "mlt"]])
    @pytest.mark.parametrize("start_val", [0, 6, 12, 18, 22])
    def test_convert_and_return(self, order, start_val):
        """Test the conversion to magnetic longitude or MLT and back again.

        Parameters
        ----------
        order : list
            List of strings specifying the order to run functions
        start_val : int or float
            Input value

        """
        # Build e.g. `mlt2mlon` and its inverse `mlon2mlt`
        first_method = getattr(self.apex_out, "2".join(order))
        second_method = getattr(self.apex_out, "2".join(order[::-1]))

        middle_val = first_method(start_val, self.in_time)
        end_val = second_method(middle_val, self.in_time)

        np.testing.assert_allclose(start_val, end_val)
        return
1113
1114
1115
class TestApexMapMethods(object):
    """Test the Apex height mapping methods."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 10000],
                               [31.841466903686523, 17.916635513305664,
                                1.7075473124350538e-6]),
                              ([30, 170, 100, 500, False, 1e-2],
                               [25.727270126342773, 169.60546875,
                                0.00017573432705830783]),
                              ([60, 15, 100, 10000, True],
                               [-25.424888610839844, 27.310426712036133,
                                1.2074182222931995e-6]),
                              ([30, 170, 100, 500, True, 1e-2],
                               [-13.76642894744873, 164.24259948730469,
                                0.00056820799363777041])])
    def test_map_to_height(self, in_args, test_mapped):
        """Test the map_to_height function.

        Parameters
        ----------
        in_args : list
            List of input arguments
        test_mapped : list
            List of expected outputs

        """
        mapped = self.apex_out.map_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5, atol=1e-5)
        return

    def test_map_to_height_same_height(self):
        """Test the map_to_height function when mapping to same height."""
        mapped = self.apex_out.map_to_height(60, 15, 100, 100, conjugate=False,
                                             precision=1e-10)
        np.testing.assert_allclose(mapped, (60.0, 15.000003814697266, 0.0),
                                   rtol=1e-5, atol=1e-5)
        return

    @pytest.mark.parametrize('arr_shape', [(2,), (2, 2), (1, 4)])
    @pytest.mark.parametrize('ivec', range(0, 4))
    def test_map_to_height_array_location(self, arr_shape, ivec):
        """Test map_to_height with array input.

        Parameters
        ----------
        arr_shape : tuple
            Shape of the vectorized input, and so of each output
        ivec : int
            Input argument index for vectorized input

        """
        # Set the base input and output values
        in_args = [60, 15, 100, 100]
        test_mapped = [60, 15.00000381, 0.0]

        # Update inputs for one vectorized value
        in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Calculate and test function
        mapped = self.apex_out.map_to_height(*in_args)
        for i, test_val in enumerate(test_mapped):
            assert mapped[i].shape == arr_shape
            np.testing.assert_allclose(mapped[i], test_val, rtol=1e-5,
                                       atol=1e-5)
        return

    @pytest.mark.parametrize("method_name,in_args",
                             [("map_to_height", [0, 15, 100, 10000]),
                              ("map_E_to_height",
                               [0, 15, 100, 10000, [1, 2, 3]]),
                              ("map_V_to_height",
                               [0, 15, 100, 10000, [1, 2, 3]])])
    def test_mapping_height_raises_ApexHeightError(self, method_name, in_args):
        """Test the height mapping methods raise ApexHeightError.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method to test
        in_args : list
            List of input arguments

        """
        apex_method = getattr(self.apex_out, method_name)

        with pytest.raises(apexpy.ApexHeightError) as aerr:
            apex_method(*in_args)

        assert aerr.match("is > apex height")
        return

    @pytest.mark.parametrize("method_name",
                             ["map_E_to_height", "map_V_to_height"])
    @pytest.mark.parametrize("ev_input", [[1, 2, 3, 4, 5],
                                          [[1, 2], [3, 4], [5, 6], [7, 8]]])
    def test_mapping_EV_bad_shape(self, method_name, ev_input):
        """Test height mapping of E/V with badly shaped input raises Error.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method to test
        ev_input : list
            E/V input arguments

        """
        apex_method = getattr(self.apex_out, method_name)
        in_args = [60, 15, 100, 500, ev_input]
        with pytest.raises(ValueError) as verr:
            apex_method(*in_args)

        assert "must be (3, N) or (3,) ndarray" in str(verr.value)
        return

    def test_mapping_EV_bad_flag(self):
        """Test _map_EV_to_height raises error for bad data type flag."""
        with pytest.raises(ValueError) as verr:
            self.apex_out._map_EV_to_height(60, 15, 100, 500, [1, 2, 3], "P")

        assert "unknown electric field/drift flag" in str(verr.value)
        return

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 500, [1, 2, 3]],
                               [0.71152183, 2.35624876, 0.57260784]),
                              ([60, 15, 100, 500, [2, 3, 4]],
                               [1.56028502, 3.43916636, 0.78235384]),
                              ([60, 15, 100, 1000, [1, 2, 3]],
                               [0.67796492, 2.08982134, 0.55860785]),
                              ([60, 15, 200, 500, [1, 2, 3]],
                               [0.72377397, 2.42737471, 0.59083726]),
                              ([60, 30, 100, 500, [1, 2, 3]],
                               [0.68626344, 2.37530133, 0.60060124]),
                              ([70, 15, 100, 500, [1, 2, 3]],
                               [0.72760378, 2.18082305, 0.29141979])])
    def test_map_E_to_height_scalar_location(self, in_args, test_mapped):
        """Test mapping of E-field to a specified height.

        Parameters
        ----------
        in_args : list
            List of input arguments
        test_mapped : list
            List of expected outputs

        """
        mapped = self.apex_out.map_E_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return

    @pytest.mark.parametrize('ev_flag, test_mapped',
                             [('E', [0.71152183, 2.35624876, 0.57260784]),
                              ('V', [0.81971957, 2.84512495, 0.69545001])])
    @pytest.mark.parametrize('arr_shape', [(2,), (5,)])
    @pytest.mark.parametrize('ivec', range(0, 5))
    def test_map_EV_to_height_array_location(self, ev_flag, test_mapped,
                                             arr_shape, ivec):
        """Test mapping of E-field/drift to a specified height with arrays.

        Parameters
        ----------
        ev_flag : str
            Character flag specifying whether to run 'E' or 'V' methods
        test_mapped : list
            List of expected outputs
        arr_shape : tuple
            Shape of the vectorized inputs and of each output
        ivec : int
            Index of the location argument to vectorize; 4 leaves every
            location scalar so only the E/V input is an array

        """
        # Build an E/V input of shape (3, *arr_shape) with [1, 2, 3] columns.
        # `np.prod` replaces `np.product`, which was removed in NumPy 2.0.
        eshape = (3,) + tuple(arr_shape)
        edata = np.array([[1, 2, 3]] * int(np.prod(arr_shape))).transpose()
        in_args = [60, 15, 100, 500, edata.reshape(eshape)]

        # Update inputs for one vectorized value if this is a location input
        if ivec < 4:
            in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Get the mapped output
        apex_method = getattr(self.apex_out,
                              "map_{:s}_to_height".format(ev_flag))
        mapped = apex_method(*in_args)

        # Test the results
        for i, test_val in enumerate(test_mapped):
            assert mapped[i].shape == arr_shape
            np.testing.assert_allclose(mapped[i], test_val, rtol=1e-5)
        return

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 500, [1, 2, 3]],
                               [0.81971957, 2.84512495, 0.69545001]),
                              ([60, 15, 100, 500, [2, 3, 4]],
                               [1.83027746, 4.14346436, 0.94764179]),
                              ([60, 15, 100, 1000, [1, 2, 3]],
                               [0.92457698, 3.14997661, 0.85135187]),
                              ([60, 15, 200, 500, [1, 2, 3]],
                               [0.80388262, 2.79321504, 0.68285158]),
                              ([60, 30, 100, 500, [1, 2, 3]],
                               [0.76141245, 2.87884673, 0.73655941]),
                              ([70, 15, 100, 500, [1, 2, 3]],
                               [0.84681866, 2.5925821, 0.34792655])])
    def test_map_V_to_height_scalar_location(self, in_args, test_mapped):
        """Test mapping of velocity to a specified height.

        Parameters
        ----------
        in_args : list
            List of input arguments
        test_mapped : list
            List of expected outputs

        """
        mapped = self.apex_out.map_V_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return
1342
1343
1344
class TestApexBasevectorMethods(object):
    """Test the Apex height base vector methods."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.lat = 60
        self.lon = 15
        self.height = 100
        self.test_basevec = None

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.test_basevec, self.lat, self.lon, self.height

    def _get_geo_coords(self, coords, precision):
        """Convert the test location to geographic coordinates.

        Parameters
        ----------
        coords : str
            Coordinate system of `self.lat` and `self.lon`; expects one of
            'geo', 'apex', or 'qd'
        precision : float
            Precision passed to the `<coords>2geo` conversion method

        Returns
        -------
        glat : float
            Geographic latitude in degrees N
        glon : float
            Geographic longitude in degrees E

        """
        if coords == "geo":
            return self.lat, self.lon

        # Use `apex2geo` or `qd2geo`, discarding the third output
        geo_method = getattr(self.apex_out, "{:s}2geo".format(coords))
        glat, glon, _ = geo_method(self.lat, self.lon, self.height,
                                   precision=precision)
        return glat, glon

    @staticmethod
    def _get_bvectors_apex_expected(coords):
        """Get the hard-coded `bvectors_apex` results for a coordinate system.

        Parameters
        ----------
        coords : str
            Expects one of 'geo', 'apex', or 'qd'; any other value falls back
            to the 'qd' results

        Returns
        -------
        tuple
            Expected `bvectors_apex` outputs for two identical test locations

        Notes
        -----
        These are set results that need to be updated with IGRF.

        """
        expected = {
            "geo": (
                np.array([4.42368795e-05, 4.42368795e-05]),
                np.array([[0.01047826, 0.01047826],
                          [0.33089194, 0.33089194],
                          [-1.04941, -1.04941]]),
                np.array([5.3564698e-05, 5.3564698e-05]),
                np.array([[0.00865356, 0.00865356],
                          [0.27327004, 0.27327004],
                          [-0.8666646, -0.8666646]])),
            "apex": (
                np.array([4.48672735e-05, 4.48672735e-05]),
                np.array([[-0.12510721, -0.12510721],
                          [0.28945938, 0.28945938],
                          [-1.1505738, -1.1505738]]),
                np.array([6.38577444e-05, 6.38577444e-05]),
                np.array([[-0.08790194, -0.08790194],
                          [0.2033779, 0.2033779],
                          [-0.808408, -0.808408]])),
            "qd": (
                np.array([4.46348578e-05, 4.46348578e-05]),
                np.array([[-0.12642345, -0.12642345],
                          [0.29695055, 0.29695055],
                          [-1.1517885, -1.1517885]]),
                np.array([6.38626285e-05, 6.38626285e-05]),
                np.array([[-0.08835986, -0.08835986],
                          [0.20754464, 0.20754464],
                          [-0.8050078, -0.8050078]]))}

        return expected.get(coords, expected["qd"])

    @staticmethod
    def _trim_apex_basevec(basevec):
        """Drop the `basevectors_apex` outputs that have no comparison value.

        Parameters
        ----------
        basevec : tuple
            Output of `basevectors_apex`: f1, f2, f3, g1, g2, g3, d1, d2, d3,
            e1, e2, e3

        Returns
        -------
        list
            The outputs without f3, g1, g2, and g3 (indices 2-5), matching
            the order set by `get_comparison_results` for 'apex'

        """
        trimmed = list(basevec)
        del trimmed[2:6]
        return trimmed

    def get_comparison_results(self, bv_coord, coords, precision):
        """Set `self.test_basevec` to the expected base vector results.

        Parameters
        ----------
        bv_coord : str
            Basevector coordinate scheme, expects one of 'apex', 'qd',
            or 'bvectors_apex'
        coords : str
            Coordinate system of the test location; expects one of 'geo',
            'apex', or 'qd'
        precision : float
            Precision used when converting the test location to geographic
            coordinates

        """
        if bv_coord not in ('qd', 'apex'):
            # `bvectors_apex` results are hard-coded rather than computed
            self.test_basevec = self._get_bvectors_apex_expected(coords)
            return

        # Compute the comparison values with the hidden functions
        glat, glon = self._get_geo_coords(coords, precision)
        if bv_coord == 'qd':
            self.test_basevec = self.apex_out._basevec(glat, glon, self.height)
        else:
            (_, _, _, _, f1, f2, _, d1, d2, d3, _, e1, e2,
             e3) = self.apex_out._geo2apexall(glat, glon, self.height)
            self.test_basevec = (f1, f2, d1, d2, d3, e1, e2, e3)
        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    @pytest.mark.parametrize("coords,precision",
                             [("geo", 1e-10), ("apex", 1.0e-2), ("qd", 1.0e-2)])
    def test_basevectors_scalar(self, bv_coord, coords, precision):
        """Test the base vector calculations with scalars.

        Parameters
        ----------
        bv_coord : str
            Base vector scheme, selecting `basevectors_qd` or
            `basevectors_apex`
        coords : str
            Coordinate system of the input location
        precision : float
            Level of run precision requested

        """
        # Get the base vectors
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(self.lat, self.lon, self.height, coords=coords,
                              precision=precision)
        self.get_comparison_results(bv_coord, coords, precision)
        if bv_coord == "apex":
            basevec = self._trim_apex_basevec(basevec)

        # Test the results
        for i, vec in enumerate(basevec):
            np.testing.assert_allclose(vec, self.test_basevec[i])
        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    def test_basevectors_scalar_shape(self, bv_coord):
        """Test the shape of the scalar output.

        Parameters
        ----------
        bv_coord : str
            Base vector scheme, selecting `basevectors_qd` or
            `basevectors_apex`

        """
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(self.lat, self.lon, self.height)

        # f1 and f2 have two components; every later vector has three
        for i, vec in enumerate(basevec):
            assert vec.shape == ((2,) if i < 2 else (3,))
        return

    @pytest.mark.parametrize('arr_shape', [(2,), (5,)])
    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    @pytest.mark.parametrize("ivec", range(3))
    def test_basevectors_array(self, arr_shape, bv_coord, ivec):
        """Test the output shape and values for array inputs.

        Parameters
        ----------
        arr_shape : tuple
            Shape of the vectorized input
        bv_coord : str
            Base vector scheme, selecting `basevectors_qd` or
            `basevectors_apex`
        ivec : int
            Index of the input argument (lat, lon, height) to vectorize

        """
        # Define the input arguments
        in_args = [self.lat, self.lon, self.height]
        in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Get the basevectors
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(*in_args, coords='geo', precision=1e-10)
        self.get_comparison_results(bv_coord, "geo", 1e-10)
        if bv_coord == "apex":
            basevec = self._trim_apex_basevec(basevec)

        # Evaluate the shape and the values
        for i, vec in enumerate(basevec):
            test_shape = (2 if i < 2 else 3,) + tuple(arr_shape)
            assert vec.shape == test_shape

            # All locations are identical, so every column should match the
            # scalar comparison vector in all of its components
            expected = np.broadcast_to(
                np.reshape(self.test_basevec[i], (-1, 1)), vec.shape)
            np.testing.assert_allclose(vec, expected)
        return

    @pytest.mark.parametrize("coords", ["geo", "apex", "qd"])
    def test_bvectors_apex(self, coords):
        """Test the bvectors_apex method.

        Parameters
        ----------
        coords : str
            Name of the coordinate system

        """
        in_args = [[self.lat, self.lat], [self.lon, self.lon],
                   [self.height, self.height]]
        self.get_comparison_results("bvectors_apex", coords, 1e-10)

        basevec = self.apex_out.bvectors_apex(*in_args, coords=coords,
                                              precision=1e-10)
        for i, vec in enumerate(basevec):
            np.testing.assert_array_almost_equal(vec, self.test_basevec[i],
                                                 decimal=5)
        return

    def test_basevectors_apex_extra_values(self):
        """Test f3, g1, g2, and g3 in the apex base vector output."""
        # Set the testing arrays
        self.test_basevec = [np.array([0.092637, -0.245951, 0.938848]),
                             np.array([0.939012, 0.073416, -0.07342]),
                             np.array([0.055389, 1.004155, 0.257594]),
                             np.array([0, 0, 1.065135])]

        # Get the desired output
        basevec = self.apex_out.basevectors_apex(0, 15, 100, coords='geo')

        # Test the values (indices 2-5) not covered by
        # `test_basevectors_scalar`
        for itest, ibase in enumerate(range(2, 6)):
            np.testing.assert_allclose(basevec[ibase],
                                       self.test_basevec[itest], rtol=1e-4)
        return

    @pytest.mark.parametrize("lat", range(0, 90, 10))
    @pytest.mark.parametrize("lon", range(0, 360, 15))
    def test_basevectors_apex_delta(self, lat, lon):
        """Test that the f/g and d/e base vector dot products form a delta.

        Parameters
        ----------
        lat : int or float
            Latitude in degrees N
        lon : int or float
            Longitude in degrees E

        """
        # Get the apex base vectors; f1 and f2 are padded with a zero third
        # component so every vector has three elements
        (f1, f2, f3, g1, g2, g3, d1, d2, d3, e1, e2,
         e3) = self.apex_out.basevectors_apex(lat, lon, 500)
        fvec = [np.append(f1, 0), np.append(f2, 0), f3]
        gvec = [g1, g2, g3]
        dvec = [d1, d2, d3]
        evec = [e1, e2, e3]

        # f_i . g_j and d_i . e_j should equal the Kronecker delta
        for idelta in range(3):
            for jdelta in range(3):
                delta = 1 if idelta == jdelta else 0
                np.testing.assert_allclose(
                    np.sum(fvec[idelta] * gvec[jdelta]), delta, rtol=0,
                    atol=1e-5)
                np.testing.assert_allclose(
                    np.sum(dvec[idelta] * evec[jdelta]), delta, rtol=0,
                    atol=1e-5)
        return

    def test_basevectors_apex_invalid_scalar(self):
        """Test warning and fill values for base vectors with bad inputs."""
        self.apex_out = apexpy.Apex(date=2000, refh=10000)
        invalid = np.full(shape=(3,), fill_value=np.nan)

        # Get the output and the warnings; record every warning so a repeat
        # is not suppressed by the active filters
        with warnings.catch_warnings(record=True) as warn_rec:
            warnings.simplefilter("always")
            basevec = self.apex_out.basevectors_apex(0, 0, 0)

        for i, bvec in enumerate(basevec):
            if i < 2:
                # f1 and f2 should not be filled with NaN.  The previous
                # `not np.allclose(bvec, nan)` check was always True, since
                # comparisons against NaN never succeed.
                assert not np.any(np.isnan(bvec))
            else:
                np.testing.assert_allclose(bvec, invalid)

        # Look through all recorded warnings rather than only the last one,
        # which avoids an IndexError or a miss if other warnings were raised
        assert any(issubclass(rec.category, UserWarning)
                   and 'set to NaN where' in str(rec.message)
                   for rec in warn_rec)
        return
1600
1601
1602
class TestApexGetMethods(object):
1603
    """Test the Apex `get` methods."""
1604
    def setup_method(self):
        """Build a fresh `Apex` object (date=2000, refh=300) before a test."""
        self.apex_out = apexpy.Apex(refh=300, date=2000)
1607
1608
    def teardown_method(self):
        """Remove the `apex_out` attribute created for the finished test."""
        delattr(self, 'apex_out')
1611
1612
    @pytest.mark.parametrize("alat, aheight",
1613
                             [(10, 507.409702543805),
1614
                              (60, 20313.026999999987),
1615
                              ([10, 60],
1616
                               [507.409702543805, 20313.026999999987]),
1617
                              ([[10], [60]],
1618
                               [[507.409702543805], [20313.026999999987]])])
1619
    def test_get_apex(self, alat, aheight):
1620
        """Test the apex height retrieval results.
1621
1622
        Parameters
1623
        ----------
1624
        alat : int or float
1625
            Apex latitude in degrees N
1626
        aheight : int or float
1627
            Apex height in km
1628
1629
        """
1630
        alt = self.apex_out.get_apex(alat)
1631
        np.testing.assert_allclose(alt, aheight)
1632
        return
1633
1634
    @pytest.mark.parametrize("glat,glon,height,test_bmag",
1635
                             [([80], [100], [300], 5.100682377815247e-05),
1636
                              ([80, 80], [100], [300],
1637
                               [5.100682377815247e-05, 5.100682377815247e-05]),
1638
                              ([[80], [80]], [100], [300],
1639
                               [[5.100682377815247e-05],
1640
                                [5.100682377815247e-05]]),
1641
                              (range(50, 90, 8), range(0, 360, 80), [300] * 5,
1642
                               np.array([4.18657154e-05, 5.11118114e-05,
1643
                                         4.91969854e-05, 5.10519207e-05,
1644
                                         4.90054816e-05])),
1645
                              (90.0, 0, 1000, 3.7834718823432923e-05)])
1646
    def test_get_babs(self, glat, glon, height, test_bmag):
1647
        """Test the method to get the magnitude of the magnetic field.
1648
1649
        Parameters
1650
        ----------
1651
        glat : list
1652
            List of latitudes in degrees N
1653
        glon : list
1654
            List of longitudes in degrees E
1655
        height : list
1656
            List of heights in km
1657
        test_bmag : float
1658
            Expected B field magnitude
1659
1660
        """
1661
        bmag = self.apex_out.get_babs(glat, glon, height)
1662
        np.testing.assert_allclose(bmag, test_bmag, rtol=0, atol=1e-5)
1663
        return
1664
1665
    @pytest.mark.parametrize("bad_lat", [(91), (-91)])
1666
    def test_get_apex_with_invalid_lat(self, bad_lat):
1667
        """Test get methods raise ValueError for invalid latitudes.
1668
1669
        Parameters
1670
        ----------
1671
        bad_lat : int or float
1672
            Bad input latitude in degrees N
1673
1674
        """
1675
1676
        with pytest.raises(ValueError) as verr:
1677
            self.apex_out.get_apex(bad_lat)
1678
1679
        assert str(verr.value).find("must be in [-90, 90]") > 0
1680
        return
1681
1682
    @pytest.mark.parametrize("bad_lat", [(91), (-91)])
1683
    def test_get_babs_with_invalid_lat(self, bad_lat):
1684
        """Test get methods raise ValueError for invalid latitudes.
1685
1686
        Parameters
1687
        ----------
1688
        bad_lat : int or float
1689
            Bad input latitude in degrees N
1690
1691
        """
1692
1693
        with pytest.raises(ValueError) as verr:
1694
            self.apex_out.get_babs(bad_lat, 15, 100)
1695
1696
        assert str(verr.value).find("must be in [-90, 90]") > 0
1697
        return
1698
1699
    @pytest.mark.parametrize("bound_lat", [(90), (-90)])
1700
    def test_get_at_lat_boundary(self, bound_lat):
1701
        """Test get methods at the latitude boundary, with allowed excess.
1702
1703
        Parameters
1704
        ----------
1705
        bound_lat : int or float
1706
            Boundary input latitude in degrees N
1707
1708
        """
1709
        # Get a latitude just beyond the limit
1710
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)
1711
1712
        # Get the two outputs, slight tolerance outside of boundary allowed
1713
        bound_out = self.apex_out.get_apex(bound_lat)
1714
        excess_out = self.apex_out.get_apex(excess_lat)
1715
1716
        # Test the outputs
1717
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
1718
        return
1719
1720
    @pytest.mark.parametrize("apex_height", [-100, 0, 300, 10000])
1721
    def test_get_height_at_equator(self, apex_height):
1722
        """Test that `get_height` returns apex height at equator.
1723
1724
        Parameters
1725
        ----------
1726
        apex_height : float
1727
            Apex height
1728
1729
        """
1730
1731
        assert apex_height == self.apex_out.get_height(0.0, apex_height)
1732
        return
1733
1734
    @pytest.mark.parametrize("lat, height", [
1735
        (-90, -6371.009), (-80, -6088.438503309167), (-70, -5274.8091854339655),
1736
        (-60, -4028.256749999999), (-50, -2499.1338178752017),
1737
        (-40, -871.8751821247979), (-30, 657.2477500000014),
1738
        (-20, 1903.8001854339655), (-10, 2717.4295033091657), (0, 3000.0),
1739
        (10, 2717.4295033091657), (20, 1903.8001854339655),
1740
        (30, 657.2477500000014), (40, -871.8751821247979),
1741
        (50, -2499.1338178752017), (60, -4028.256749999999),
1742
        (70, -5274.8091854339655), (80, -6088.438503309167)])
1743
    def test_get_height_along_fieldline(self, lat, height):
1744
        """Test that `get_height` returns expected height of field line.
1745
1746
        Parameters
1747
        ----------
1748
        lat : float
1749
            Input latitude
1750
        height : float
1751
            Output field-line height for line with apex of 3000 km
1752
1753
        """
1754
1755
        fheight = self.apex_out.get_height(lat, 3000.0)
1756
        assert abs(height - fheight) < 1.0e-7, \
1757
            "bad height calculation: {:.7f} != {:.7f}".format(height, fheight)
1758
        return
1759