Completed
Push — main ( a1e2ce...09bd28 )
by Angeline
16s queued 14s
created

test_Apex.TestApexMLTMethods.test_convert_to_mlt()   A

Complexity

Conditions 2

Size

Total Lines 27
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 12
nop 2
dl 0
loc 27
rs 9.8
c 0
b 0
f 0
1
# -*- coding: utf-8 -*-
2
"""Test the apexpy.Apex class
3
4
Notes
5
-----
6
Whenever function outputs are tested against hard-coded numbers, the test
7
results (numbers) were obtained by running the code that is tested.  Therefore,
8
these tests below only check that nothing changes when refactoring, etc., and
9
not if the results are actually correct.
10
11
These results are expected to change when IGRF is updated.
12
13
"""
14
15
import datetime as dt
16
import numpy as np
17
import os
18
import pytest
19
import shutil
20
import warnings
21
22
import apexpy
23
24
25
def _move_with_retries(src, dst, max_attempts):
    """Move a file, retrying a limited number of times before giving up.

    Parameters
    ----------
    src : str
        Path of the file to move
    dst : str
        Destination path
    max_attempts : int
        Maximum number of move attempts (repeated attempts are needed for
        Windows)

    Raises
    ------
    OSError
        If the file could not be moved within `max_attempts` attempts

    """
    last_err = None
    for _ in range(max_attempts):
        try:
            shutil.move(src, dst)
            return
        except OSError as err:
            # Remember the failure and retry; only give up after the limit
            last_err = err

    raise OSError("unable to move {:} to {:} after {:d} attempts".format(
        src, dst, max_attempts)) from last_err


@pytest.fixture()
def igrf_file(max_attempts=100):
    """A fixture for handling the coefficient file.

    Moves the IGRF coefficient file out of the apexpy package for the
    duration of a test, then moves it back.

    Parameters
    ----------
    max_attempts : int
        Maximum rename attempts, needed for Windows (default=100)

    Yields
    ------
    original_file : str
        Full path of the coefficient file, which is absent during the test

    Raises
    ------
    OSError
        If the coefficient file cannot be moved aside or restored

    """
    # Ensure the coefficient file exists
    original_file = os.path.join(os.path.dirname(apexpy.helpers.__file__),
                                 'igrf13coeffs.txt')
    tmp_file = "temp_coeff.txt"
    assert os.path.isfile(original_file)

    # Move the coefficient file.  Fail loudly: silently continuing would run
    # the test with the file still in place.
    _move_with_retries(original_file, tmp_file, max_attempts)
    yield original_file

    # Move the coefficient file back.  Fail loudly: silently continuing would
    # leave the package without its coefficient file.
    _move_with_retries(tmp_file, original_file, max_attempts)
    return
58
59
60
def test_set_epoch_file_error(igrf_file):
    """Test raises OSError when IGRF coefficient file is missing."""
    expected_start = "File {:} does not exist".format(igrf_file)

    # Without the coefficient file, initializing Apex must fail
    with pytest.raises(OSError) as raised:
        apexpy.Apex()

    assert str(raised.value).startswith(expected_start)
    return
68
69
70
class TestApexInit(object):
    """Test class for the Apex class object."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = None
        # Naive UTC time, equivalent to the deprecated dt.datetime.utcnow()
        self.test_date = dt.datetime.now(dt.timezone.utc).replace(tzinfo=None)
        self.test_refh = 0
        self.bad_file = 'foo/path/to/datafile.blah'

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.test_date, self.test_refh, self.bad_file

    def eval_date(self):
        """Evaluate the times in self.test_date and self.apex_out."""
        # dt.datetime is a subclass of dt.date, so this covers both types
        if isinstance(self.test_date, dt.date):
            self.test_date = apexpy.helpers.toYearFraction(self.test_date)

        # Assert the times are the same on the order of tens of seconds.
        # Necessary to evaluate the current UTC
        np.testing.assert_almost_equal(self.test_date, self.apex_out.year, 6)
        return

    def eval_refh(self):
        """Evaluate the reference height in test_refh and apex_out."""
        eval_str = "".join(["expected reference height [",
                            "{:}] not equal to Apex ".format(self.test_refh),
                            "reference height ",
                            "[{:}]".format(self.apex_out.refh)])
        assert self.test_refh == self.apex_out.refh, eval_str
        return

    def test_init_defaults(self):
        """Test Apex class default initialization."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        return

    def test_init_today(self):
        """Test Apex class initialization with today's date."""
        self.apex_out = apexpy.Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("in_date",
                             [2015, 2015.5, dt.date(2015, 1, 1),
                              dt.datetime(2015, 6, 1, 18, 23, 45)])
    def test_init_date(self, in_date):
        """Test Apex class with date initialization.

        Parameters
        ----------
        in_date : int, float, dt.date, or dt.datetime
            Input date in a variety of formats

        """
        self.test_date = in_date
        self.apex_out = apexpy.Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("new_date", [2015, 2015.5])
    def test_set_epoch(self, new_date):
        """Test successful setting of Apex epoch after initialization.

        Parameters
        ----------
        new_date : int or float
            New date for the Apex class

        """
        # Evaluate the default initialization
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Update the epoch; the repr round-trip gives an independent copy
        ref_apex = eval(self.apex_out.__repr__())
        self.apex_out.set_epoch(new_date)
        assert ref_apex != self.apex_out
        self.test_date = new_date
        self.eval_date()
        return

    @pytest.mark.parametrize("in_refh", [0.0, 300.0, 30000.0, -1.0])
    def test_init_refh(self, in_refh):
        """Test Apex class with reference height initialization.

        Parameters
        ----------
        in_refh : float
            Input reference height in km

        """
        self.test_refh = in_refh
        self.apex_out = apexpy.Apex(refh=self.test_refh)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("new_refh", [0.0, 300.0, 30000.0, -1.0])
    def test_set_refh(self, new_refh):
        """Test the method used to set the reference height after the init.

        Parameters
        ----------
        new_refh : float
            Reference height in km

        """
        # Verify the defaults are set
        self.apex_out = apexpy.Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()

        # Update to a new reference height and test
        ref_apex = eval(self.apex_out.__repr__())
        self.apex_out.set_refh(new_refh)

        if self.test_refh == new_refh:
            assert ref_apex == self.apex_out
        else:
            assert ref_apex != self.apex_out
            self.test_refh = new_refh
        self.eval_refh()
        return

    def test_init_with_bad_datafile(self):
        """Test raises IOError with non-existent datafile input."""
        with pytest.raises(IOError) as oerr:
            apexpy.Apex(datafile=self.bad_file)
        assert str(oerr.value).startswith('Data file does not exist')
        return

    def test_init_with_bad_fortranlib(self):
        """Test raises IOError with non-existent fortranlib input."""
        with pytest.raises(IOError) as oerr:
            apexpy.Apex(fortranlib=self.bad_file)
        assert str(oerr.value).startswith('Fortran library does not exist')
        return

    def test_repr_eval(self):
        """Test the Apex.__repr__ results."""
        # Initialize the apex object
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Get and test the repr string
        out_str = self.apex_out.__repr__()
        assert out_str.find("apexpy.Apex(") == 0

        # Test the ability to re-create the apex object from the repr string
        new_apex = eval(out_str)
        assert new_apex == self.apex_out
        return

    def test_ne_other_class(self):
        """Test Apex class inequality to a different class."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        assert self.apex_out != self.test_date
        return

    def test_ne_missing_attr(self):
        """Test Apex class inequality when attributes are missing from one."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        ref_apex = eval(self.apex_out.__repr__())
        del ref_apex.RE

        assert ref_apex != self.apex_out
        assert self.apex_out != ref_apex
        return

    def test_eq_missing_attr(self):
        """Test Apex class equality when attributes are missing from both."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        ref_apex = eval(self.apex_out.__repr__())
        del ref_apex.RE, self.apex_out.RE

        assert ref_apex == self.apex_out
        return

    def test_str_eval(self):
        """Test the Apex.__str__ results."""
        # Initialize the apex object
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Get and test the printed string
        out_str = self.apex_out.__str__()
        assert out_str.find("Decimal year") > 0
        return
275
276
277
class TestApexMethod(object):
    """Test the Apex methods."""
    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.in_lat = 60
        self.in_lon = 15
        self.in_alt = 100

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.in_lat, self.in_lon, self.in_alt

    def get_input_args(self, method_name, precision=0.0):
        """Set the input arguments for the different Apex methods.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method
        precision : float
            Value for the precision (default=0.0)

        Returns
        -------
        in_args : list
            List of the appropriate input arguments

        """
        in_args = [self.in_lat, self.in_lon, self.in_alt]

        # Add precision, if needed
        if method_name in ["_qd2geo", "apxq2g", "apex2geo", "qd2geo",
                           "_apex2geo"]:
            in_args.append(precision)

        # Add a reference height, if needed
        if method_name in ["apxg2all"]:
            in_args.append(300)

        # Add a vector flag, if needed
        if method_name in ["apxg2all", "apxg2q"]:
            in_args.append(1)

        return in_args

    def test_apex_conversion_today(self):
        """Test Apex class conversion with today's date."""
        # Naive UTC time, equivalent to the deprecated dt.datetime.utcnow()
        today = dt.datetime.now(dt.timezone.utc).replace(tzinfo=None)
        self.apex_out = apexpy.Apex(date=today, refh=300)
        assert not np.isnan(self.apex_out.geo2apex(self.in_lat, self.in_lon,
                                                   self.in_alt)).any()
        return

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                              ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                              ("_qd2geo", "apxq2g", slice(None)),
                              ("_basevec", "apxg2q", slice(2, 4, 1))])
    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon", [-179, -90, 0, 90, 180])
    def test_fortran_scalar_input(self, apex_method, fortran_method, fslice,
                                  lat, lon):
        """Tests Apex/fortran interface consistency for scalars.

        Parameters
        ----------
        apex_method : str
            Name of the Apex class method to test
        fortran_method : str
            Name of the Fortran function to test
        fslice : slice
            Slice used to select the appropriate Fortran outputs
        lat : int or float
            Latitude in degrees N
        lon : int or float
            Longitude in degrees E

        """
        # Set the input coordinates
        self.in_lat = lat
        self.in_lon = lon

        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Get the appropriate input arguments
        apex_args = self.get_input_args(apex_method)
        fortran_args = self.get_input_args(fortran_method)

        # Evaluate the equivalent function calls
        np.testing.assert_allclose(apex_func(*apex_args),
                                   fortran_func(*fortran_args)[fslice])
        return

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                              ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                              ("_qd2geo", "apxq2g", slice(None)),
                              ("_basevec", "apxg2q", slice(2, 4, 1))])
    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon1,lon2", [(180, 180), (-180, -180),
                                           (180, -180), (-180, 180),
                                           (-345, 15), (375, 15)])
    def test_fortran_longitude_rollover(self, apex_method, fortran_method,
                                        fslice, lat, lon1, lon2):
        """Tests Apex/fortran interface consistency for longitude rollover.

        Parameters
        ----------
        apex_method : str
            Name of the Apex class method to test
        fortran_method : str
            Name of the Fortran function to test
        fslice : slice
            Slice used to select the appropriate Fortran outputs
        lat : int or float
            Latitude in degrees N
        lon1 : int or float
            Longitude in degrees E
        lon2 : int or float
            Equivalent longitude in degrees E

        """
        # Set the fixed input coordinate
        self.in_lat = lat

        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Get the appropriate input arguments
        self.in_lon = lon1
        apex_args = self.get_input_args(apex_method)

        self.in_lon = lon2
        fortran_args = self.get_input_args(fortran_method)

        # Evaluate the equivalent function calls
        np.testing.assert_allclose(apex_func(*apex_args),
                                   fortran_func(*fortran_args)[fslice])
        return

    @pytest.mark.parametrize("arr_shape", [(2, 2), (4,), (1, 4)])
    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                              ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                              ("_qd2geo", "apxq2g", slice(None)),
                              ("_basevec", "apxg2q", slice(2, 4, 1))])
    def test_fortran_array_input(self, arr_shape, apex_method, fortran_method,
                                 fslice):
        """Tests Apex/fortran interface consistency for array input.

        Parameters
        ----------
        arr_shape : tuple
            Expected output shape
        apex_method : str
            Name of the Apex class method to test
        fortran_method : str
            Name of the Fortran function to test
        fslice : slice
            Slice used to select the appropriate Fortran outputs

        """
        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Set up the input arrays
        ref_lat = np.array([0, 30, 60, 90])
        ref_alt = np.array([100, 200, 300, 400])
        self.in_lat = ref_lat.reshape(arr_shape)
        self.in_alt = ref_alt.reshape(arr_shape)
        apex_args = self.get_input_args(apex_method)

        # Get the Apex class results
        aret = apex_func(*apex_args)

        # Get the fortran function results
        flats = list()
        flons = list()

        for i, lat in enumerate(ref_lat):
            self.in_lat = lat
            self.in_alt = ref_alt[i]
            fortran_args = self.get_input_args(fortran_method)
            fret = fortran_func(*fortran_args)[fslice]
            flats.append(fret[0])
            flons.append(fret[1])

        flats = np.array(flats)
        flons = np.array(flons)

        # Evaluate results
        try:
            # This returned value is array of floats
            np.testing.assert_allclose(aret[0].astype(float),
                                       flats.reshape(arr_shape).astype(float))
            np.testing.assert_allclose(aret[1].astype(float),
                                       flons.reshape(arr_shape).astype(float))
        except ValueError:
            # This returned value is array of arrays
            alats = aret[0].reshape((4,))
            alons = aret[1].reshape((4,))
            for i, flat in enumerate(flats):
                np.testing.assert_array_almost_equal(alats[i], flat, 2)
                np.testing.assert_array_almost_equal(alons[i], flons[i], 2)

        return

    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon", [-179, -90, 0, 90, 180])
    def test_geo2apexall_scalar(self, lat, lon):
        """Test Apex/fortran geo2apexall interface consistency for scalars.

        Parameters
        ----------
        lat : int or float
            Latitude in degrees N
        lon : int or float
            Longitude in degrees E

        """
        # Get the Apex and Fortran results
        aret = self.apex_out._geo2apexall(lat, lon, self.in_alt)
        fret = apexpy.fortranapex.apxg2all(lat, lon, self.in_alt, 300, 1)

        # Evaluate each element in the results
        for aval, fval in zip(aret, fret):
            np.testing.assert_allclose(aval, fval)
        return

    @pytest.mark.parametrize("arr_shape", [(2, 2), (4,), (1, 4)])
    def test_geo2apexall_array(self, arr_shape):
        """Test Apex/fortran geo2apexall interface consistency for arrays.

        Parameters
        ----------
        arr_shape : tuple
            Expected output shape

        """
        # Set the input
        self.in_lat = np.array([0, 30, 60, 90])
        self.in_alt = np.array([100, 200, 300, 400])

        # Get the Apex class results
        aret = self.apex_out._geo2apexall(self.in_lat.reshape(arr_shape),
                                          self.in_lon,
                                          self.in_alt.reshape(arr_shape))

        # For each lat/alt pair, get the Fortran results
        fret = list()
        for i, lat in enumerate(self.in_lat):
            fret.append(apexpy.fortranapex.apxg2all(lat, self.in_lon,
                                                    self.in_alt[i], 300, 1))

        # Cycle through all returned values
        for i, ret in enumerate(aret):
            try:
                # This returned value is array of floats
                fret_test = np.array([fret[0][i], fret[1][i], fret[2][i],
                                      fret[3][i]]).reshape(arr_shape)
                np.testing.assert_allclose(ret.astype(float),
                                           fret_test.astype(float))
            except ValueError:
                # This returned value is array of arrays
                ret = ret.reshape((4,))
                for j, single_fret in enumerate(fret):
                    np.testing.assert_allclose(ret[j], single_fret[i])
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_consistency(self, in_coord, out_coord):
        """Test the self-consistency of the Apex convert method.

        Parameters
        ----------
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Define the method name
        method_name = "2".join([in_coord, out_coord])

        # Get the method and method inputs
        convert_kwargs = {'height': self.in_alt, 'precision': 0.0}
        apex_args = self.get_input_args(method_name)
        apex_method = getattr(self.apex_out, method_name)

        # Define the slice needed to get equivalent output from the named method
        mslice = slice(0, -1, 1) if out_coord == "geo" else slice(None)

        # Get output using convert and named method
        convert_out = self.apex_out.convert(self.in_lat, self.in_lon, in_coord,
                                            out_coord, **convert_kwargs)
        method_out = apex_method(*apex_args)[mslice]

        # Compare both outputs, should be identical
        np.testing.assert_allclose(convert_out, method_out)
        return

    @pytest.mark.parametrize("bound_lat", [90, -90])
    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_at_lat_boundary(self, bound_lat, in_coord, out_coord):
        """Test the conversion at the latitude boundary, with allowed excess.

        Parameters
        ----------
        bound_lat : int or float
            Boundary latitude in degrees N
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system

        """
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = self.apex_out.convert(bound_lat, 0, in_coord, out_coord)
        excess_out = self.apex_out.convert(excess_lat, 0, in_coord, out_coord)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    def test_convert_qd2apex_at_equator(self):
        """Test the quasi-dipole to apex conversion at the magnetic equator."""
        eq_out = self.apex_out.convert(lat=0.0, lon=0, source='qd', dest='apex',
                                       height=320.0)
        close_out = self.apex_out.convert(lat=0.001, lon=0, source='qd',
                                          dest='apex', height=320.0)
        np.testing.assert_allclose(eq_out, close_out, atol=1e-4)
        return

    @pytest.mark.parametrize("src", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("dest", ["geo", "apex", "qd"])
    def test_convert_withnan(self, src, dest):
        """Test Apex.convert success with NaN input.

        Parameters
        ----------
        src : str
            Input coordinate system
        dest : str
            Output coordinate system

        """
        if src == dest:
            pytest.skip("Test not needed for same src and dest coordinates")

        num_nans = 5
        in_loc = np.arange(0, 10, dtype=float)
        in_loc[:num_nans] = np.nan

        out_loc = self.apex_out.convert(in_loc, in_loc, src, dest, height=320)

        for out in out_loc:
            assert np.all(np.isnan(out[:num_nans])), "NaN output expected"
            assert np.all(np.isfinite(out[num_nans:])), "Finite output expected"

        return

    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_convert_invalid_lat(self, bad_lat):
        """Test convert raises ValueError for invalid latitudes.

        Parameters
        ----------
        bad_lat : int or float
            Latitude outside the supported range in degrees N

        """

        with pytest.raises(ValueError) as verr:
            self.apex_out.convert(bad_lat, 0, 'geo', 'geo')

        assert str(verr.value).find("must be in [-90, 90]") > 0
        return

    @pytest.mark.parametrize("coords", [("foobar", "geo"), ("geo", "foobar"),
                                        ("geo", "mlt")])
    def test_convert_invalid_transformation(self, coords):
        """Test raises ValueError for bad coordinates.

        Parameters
        ----------
        coords : tuple
            Tuple specifying the input and output coordinate systems

        """
        if "mlt" in coords:
            estr = "datetime must be given for MLT calculations"
        else:
            estr = "Unknown coordinate transformation"

        with pytest.raises(ValueError) as verr:
            self.apex_out.convert(0, 0, *coords)

        # Inspect the exception message, not the ExceptionInfo repr
        assert estr in str(verr.value)
        return

    # NOTE(review): a code-quality scan flagged this parametrization as
    # duplicated elsewhere in the project.
    @pytest.mark.parametrize("method_name, out_comp",
                             [("geo2apex",
                               (55.94841766357422, 94.10684204101562)),
                              ("apex2geo",
                               (51.476322174072266, -66.22817993164062,
                                5.727287771151168e-06)),
                              ("geo2qd",
                               (56.531288146972656, 94.10684204101562)),
                              ("apex2qd", (60.498401178276744, 15.0)),
                              ("qd2apex", (59.49138097045895, 15.0))])
    def test_method_scalar_input(self, method_name, out_comp):
        """Test the user method against set values with scalars.

        Parameters
        ----------
        method_name : str
            Apex class method to be tested
        out_comp : tuple of floats
            Expected output values

        """
        # Get the desired methods
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(self.in_lat, self.in_lon, self.in_alt)

        # Evaluate the user output
        np.testing.assert_allclose(user_out, out_comp, rtol=1e-5, atol=1e-5)

        for out_val in user_out:
            assert np.asarray(out_val).shape == (), "output is not a scalar"
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("method_args, out_shape",
                             [([[60, 60], 15, 100], (2,)),
                              ([60, [15, 15], 100], (2,)),
                              ([60, 15, [100, 100]], (2,)),
                              ([[50, 60], [15, 16], [100, 200]], (2,))])
    def test_method_broadcast_input(self, in_coord, out_coord, method_args,
                                    out_shape):
        """Test the user method with inputs that require some broadcasting.

        Parameters
        ----------
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system
        method_args : list
            List of input arguments
        out_shape : tuple
            Expected shape of output values

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(*method_args)

        # Evaluate the user output
        for out_val in user_out:
            assert hasattr(out_val, 'shape'), "output coordinate isn't np.array"
            assert out_val.shape == out_shape
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_method_invalid_lat(self, in_coord, out_coord, bad_lat):
        """Test convert raises ValueError for invalid latitudes.

        Parameters
        ----------
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system
        bad_lat : int
            Latitude in degrees N that is out of bounds

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        with pytest.raises(ValueError) as verr:
            user_method(bad_lat, 15, 100)

        assert str(verr.value).find("must be in [-90, 90]") > 0
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("bound_lat", [90, -90])
    def test_method_at_lat_boundary(self, in_coord, out_coord, bound_lat):
        """Test user methods at the latitude boundary, with allowed excess.

        Parameters
        ----------
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system
        bound_lat : int
            Latitude in degrees N that is at the limits of the boundary

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        # Get a latitude just beyond the limit
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = user_method(bound_lat, 0, 100)
        excess_out = user_method(excess_lat, 0, 100)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    def test_geo2apex_undefined_warning(self):
        """Test geo2apex warning and fill values for an undefined location."""

        # Update the apex object
        self.apex_out = apexpy.Apex(date=2000, refh=10000)

        # Get the output and the warnings
        with warnings.catch_warnings(record=True) as warn_rec:
            user_lat, user_lon = self.apex_out.geo2apex(0, 0, 0)

        assert np.isnan(user_lat)
        assert np.isfinite(user_lon)
        assert len(warn_rec) == 1
        assert issubclass(warn_rec[-1].category, UserWarning)
        assert 'latitude set to NaN where' in str(warn_rec[-1].message)
        return

    @pytest.mark.parametrize("method_name", ["apex2qd", "qd2apex"])
    @pytest.mark.parametrize("delta_h", [1.0e-6, -1.0e-6])
    def test_quasidipole_apexheight_close(self, method_name, delta_h):
        """Test quasi-dipole success with a height close to the reference.

        Parameters
        ----------
        method_name : str
            Apex class method name to be tested
        delta_h : float
            tolerance for height in km

        """
        qd_method = getattr(self.apex_out, method_name)
        in_args = [0, 15, self.apex_out.refh + delta_h]
        out_coords = qd_method(*in_args)

        for i, out_val in enumerate(out_coords):
            np.testing.assert_almost_equal(out_val, in_args[i], decimal=3)
        return

    @pytest.mark.parametrize("method_name, hinc, msg",
                             [("apex2qd", 1.0, "is > apex height"),
                              ("qd2apex", -1.0, "is < reference height")])
    def test_quasidipole_raises_apexheight(self, method_name, hinc, msg):
        """Test quasi-dipole methods raise ApexHeightError for bad heights.

        Parameters
        ----------
        method_name : str
            Apex class method name to be tested
        hinc : float
            Height increment in km
        msg : str
            Expected output message

        """
        qd_method = getattr(self.apex_out, method_name)

        with pytest.raises(apexpy.ApexHeightError) as aerr:
            qd_method(0, 15, self.apex_out.refh + hinc)

        # Inspect the exception message, not the ExceptionInfo repr
        assert msg in str(aerr.value)
        return
884
885
886
class TestApexMLTMethods(object):
    """Test the Apex Magnetic Local Time (MLT) methods."""
    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.in_time = dt.datetime(2000, 2, 3, 4, 5, 6)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.in_time

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    def test_convert_to_mlt(self, in_coord):
        """Test the conversions to MLT using Apex convert.

        Parameters
        ----------
        in_coord : str
            Input coordinate system

        """
        # Get the magnetic longitude passed to `mlon2mlt`: geographic input
        # is converted first, apex and QD input longitudes are used as-is
        if in_coord == "geo":
            mlon = self.apex_out.geo2apex(60, 15, 100)[1]
        else:
            mlon = 15

        # Get the output MLT values
        convert_mlt = self.apex_out.convert(60, 15, in_coord, 'mlt',
                                            height=100, ssheight=2e5,
                                            datetime=self.in_time)[1]
        method_mlt = self.apex_out.mlon2mlt(mlon, self.in_time, ssheight=2e5)

        # Test the outputs
        np.testing.assert_allclose(convert_mlt, method_mlt)
        return

    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_mlt_to_lon(self, out_coord):
        """Test the conversions from MLT using Apex convert.

        Parameters
        ----------
        out_coord : str
            Output coordinate system

        """
        # Get the output longitudes
        convert_out = self.apex_out.convert(60, 15, 'mlt', out_coord,
                                            height=100, ssheight=2e5,
                                            datetime=self.in_time,
                                            precision=1e-2)
        mlon = self.apex_out.mlt2mlon(15, self.in_time, ssheight=2e5)

        if out_coord == "geo":
            # Drop the last value returned by `apex2geo` so only the
            # latitude and longitude are compared
            method_out = self.apex_out.apex2geo(60, mlon, 100,
                                                precision=1e-2)[:-1]
        elif out_coord == "qd":
            method_out = self.apex_out.apex2qd(60, mlon, 100)
        else:
            method_out = (60, mlon)

        # Evaluate the outputs
        np.testing.assert_allclose(convert_out, method_out)
        return

    def test_convert_geo2mlt_nodate(self):
        """Test convert from geo to MLT raises ValueError with no datetime."""
        with pytest.raises(ValueError):
            self.apex_out.convert(60, 15, 'geo', 'mlt')
        return

    @pytest.mark.parametrize("mlon_kwargs,test_mlt",
                             [({}, 23.019629923502603),
                              ({"ssheight": 100000}, 23.026712036132814)])
    def test_mlon2mlt_scalar_inputs(self, mlon_kwargs, test_mlt):
        """Test mlon2mlt with scalar inputs.

        Parameters
        ----------
        mlon_kwargs : dict
            Input kwargs
        test_mlt : float
            Output MLT in hours

        """
        mlt = self.apex_out.mlon2mlt(0, self.in_time, **mlon_kwargs)

        np.testing.assert_allclose(mlt, test_mlt)
        assert np.asarray(mlt).shape == ()
        return

    @pytest.mark.parametrize("mlt_kwargs,test_mlon",
                             [({}, 14.705535888671875),
                              ({"ssheight": 100000}, 14.599319458007812)])
    def test_mlt2mlon_scalar_inputs(self, mlt_kwargs, test_mlon):
        """Test mlt2mlon with scalar inputs.

        Parameters
        ----------
        mlt_kwargs : dict
            Input kwargs
        test_mlon : float
            Output longitude in degrees E

        """
        mlon = self.apex_out.mlt2mlon(0, self.in_time, **mlt_kwargs)

        np.testing.assert_allclose(mlon, test_mlon)
        assert np.asarray(mlon).shape == ()
        return

    @pytest.mark.parametrize("mlon,test_mlt",
                             [([0, 180], [23.019261, 11.019261]),
                              (np.array([0, 180]), [23.019261, 11.019261]),
                              (np.array([[0], [180]]),
                               np.array([[23.019261], [11.019261]])),
                              ([[0, 180], [0, 180]], [[23.019261, 11.019261],
                                                      [23.019261, 11.019261]]),
                              (range(0, 361, 30),
                               [23.01963, 1.01963, 3.01963, 5.01963, 7.01963,
                                9.01963, 11.01963, 13.01963, 15.01963, 17.01963,
                                19.01963, 21.01963, 23.01963])])
    def test_mlon2mlt_array(self, mlon, test_mlt):
        """Test mlon2mlt with array inputs.

        Parameters
        ----------
        mlon : array-like
            Input longitudes in degrees E
        test_mlt : array-like
            Expected output MLT in hours, with the same shape as `mlon`

        """
        mlt = self.apex_out.mlon2mlt(mlon, self.in_time)

        assert mlt.shape == np.asarray(test_mlt).shape
        np.testing.assert_allclose(mlt, test_mlt, rtol=1e-4)
        return

    @pytest.mark.parametrize("mlt,test_mlon",
                             [([0, 12], [14.705551, 194.705551]),
                              (np.array([0, 12]), [14.705551, 194.705551]),
                              (np.array([[0], [12]]),
                               np.array([[14.705551], [194.705551]])),
                              ([[0, 12], [0, 12]], [[14.705551, 194.705551],
                                                    [14.705551, 194.705551]]),
                              (range(0, 25, 2),
                               [14.705551, 44.705551, 74.705551, 104.705551,
                                134.705551, 164.705551, 194.705551, 224.705551,
                                254.705551, 284.705551, 314.705551, 344.705551,
                                14.705551])])
    def test_mlt2mlon_array(self, mlt, test_mlon):
        """Test mlt2mlon with array inputs.

        Parameters
        ----------
        mlt : array-like
            Input MLT in hours
        test_mlon : array-like
            Expected output longitude in degrees E, with the same shape as
            `mlt`

        """
        mlon = self.apex_out.mlt2mlon(mlt, self.in_time)

        assert mlon.shape == np.asarray(test_mlon).shape
        np.testing.assert_allclose(mlon, test_mlon, rtol=1e-4)
        return

    @pytest.mark.parametrize("method_name", ["mlon2mlt", "mlt2mlon"])
    def test_mlon2mlt_diffdates(self, method_name):
        """Test that MLT/magnetic longitude conversions vary with UT.

        Parameters
        ----------
        method_name : str
            Name of Apex class method to be tested

        """
        apex_method = getattr(self.apex_out, method_name)
        out1 = apex_method(0, self.in_time)
        out2 = apex_method(0, self.in_time + dt.timedelta(hours=1))

        assert out1 != out2
        return

    @pytest.mark.parametrize("mlt_offset", [1.0, 10.0])
    def test_mlon2mlt_offset(self, mlt_offset):
        """Test the time wrapping logic for the MLT.

        Parameters
        ----------
        mlt_offset : float
            MLT offset in hours

        """
        mlt1 = self.apex_out.mlon2mlt(0.0, self.in_time)
        mlt2 = self.apex_out.mlon2mlt(-15.0 * mlt_offset,
                                      self.in_time) + mlt_offset

        np.testing.assert_allclose(mlt1, mlt2)
        return

    @pytest.mark.parametrize("mlon_offset", [15.0, 150.0])
    def test_mlt2mlon_offset(self, mlon_offset):
        """Test the time wrapping logic for the magnetic longitude.

        Parameters
        ----------
        mlon_offset : float
            Magnetic longitude offset in degrees (converted to an MLT offset
            in hours by dividing by 15)

        """
        mlon1 = self.apex_out.mlt2mlon(0, self.in_time)
        mlon2 = self.apex_out.mlt2mlon(mlon_offset / 15.0,
                                       self.in_time) - mlon_offset

        np.testing.assert_allclose(mlon1, mlon2)
        return

    @pytest.mark.parametrize("order", [["mlt", "mlon"], ["mlon", "mlt"]])
    @pytest.mark.parametrize("start_val", [0, 6, 12, 18, 22])
    def test_convert_and_return(self, order, start_val):
        """Test the conversion to magnetic longitude or MLT and back again.

        Parameters
        ----------
        order : list
            List of strings specifying the order to run functions
        start_val : int or float
            Input value

        """
        # E.g. ["mlt", "mlon"] -> `mlt2mlon` followed by `mlon2mlt`
        first_method = getattr(self.apex_out, "2".join(order))
        second_method = getattr(self.apex_out, "2".join(order[::-1]))

        middle_val = first_method(start_val, self.in_time)
        end_val = second_method(middle_val, self.in_time)

        np.testing.assert_allclose(start_val, end_val)
        return
1129
1130
1131
class TestApexMapMethods(object):
    """Test the Apex height mapping methods."""
    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 10000],
                               [31.841466903686523, 17.916635513305664,
                                1.7075473124350538e-6]),
                              ([30, 170, 100, 500, False, 1e-2],
                               [25.727270126342773, 169.60546875,
                                0.00017573432705830783]),
                              ([60, 15, 100, 10000, True],
                               [-25.424888610839844, 27.310426712036133,
                                1.2074182222931995e-6]),
                              ([30, 170, 100, 500, True, 1e-2],
                               [-13.76642894744873, 164.24259948730469,
                                0.00056820799363777041])])
    def test_map_to_height(self, in_args, test_mapped):
        """Test the map_to_height function.

        Parameters
        ----------
        in_args : list
            List of input arguments
        test_mapped : list
            List of expected outputs

        """
        mapped = self.apex_out.map_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5, atol=1e-5)
        return

    def test_map_to_height_same_height(self):
        """Test the map_to_height function when mapping to same height."""
        mapped = self.apex_out.map_to_height(60, 15, 100, 100, conjugate=False,
                                             precision=1e-10)
        np.testing.assert_allclose(mapped, (60.0, 15.000003814697266, 0.0),
                                   rtol=1e-5, atol=1e-5)
        return

    @pytest.mark.parametrize('arr_shape', [(2,), (2, 2), (1, 4)])
    @pytest.mark.parametrize('ivec', range(0, 4))
    def test_map_to_height_array_location(self, arr_shape, ivec):
        """Test map_to_height with array input.

        Parameters
        ----------
        arr_shape : tuple
            Shape of the vectorized input, and therefore of each output
        ivec : int
            Index of the input argument that is vectorized

        """
        # Set the base input and output values
        in_args = [60, 15, 100, 100]
        test_mapped = [60, 15.00000381, 0.0]

        # Update inputs for one vectorized value
        in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Calculate and test function
        mapped = self.apex_out.map_to_height(*in_args)
        for i, test_val in enumerate(test_mapped):
            assert mapped[i].shape == arr_shape
            np.testing.assert_allclose(mapped[i], test_val, rtol=1e-5,
                                       atol=1e-5)
        return

    @pytest.mark.parametrize("method_name,in_args",
                             [("map_to_height", [0, 15, 100, 10000]),
                              ("map_E_to_height",
                               [0, 15, 100, 10000, [1, 2, 3]]),
                              ("map_V_to_height",
                               [0, 15, 100, 10000, [1, 2, 3]])])
    def test_mapping_height_raises_ApexHeightError(self, method_name, in_args):
        """Test the height mapping methods raise ApexHeightError.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method to test
        in_args : list
            List of input arguments

        """
        apex_method = getattr(self.apex_out, method_name)

        with pytest.raises(apexpy.ApexHeightError) as aerr:
            apex_method(*in_args)

        assert aerr.match("is > apex height")
        return

    @pytest.mark.parametrize("method_name",
                             ["map_E_to_height", "map_V_to_height"])
    @pytest.mark.parametrize("ev_input", [[1, 2, 3, 4, 5],
                                          [[1, 2], [3, 4], [5, 6], [7, 8]]])
    def test_mapping_EV_bad_shape(self, method_name, ev_input):
        """Test height mapping of E/V with badly shaped input raises Error.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method to test
        ev_input : list
            E/V input arguments

        """
        apex_method = getattr(self.apex_out, method_name)
        in_args = [60, 15, 100, 500, ev_input]
        with pytest.raises(ValueError) as verr:
            apex_method(*in_args)

        assert "must be (3, N) or (3,) ndarray" in str(verr.value)
        return

    def test_mapping_EV_bad_flag(self):
        """Test _map_EV_to_height raises error for bad data type flag."""
        with pytest.raises(ValueError) as verr:
            self.apex_out._map_EV_to_height(60, 15, 100, 500, [1, 2, 3], "P")

        assert "unknown electric field/drift flag" in str(verr.value)
        return

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 500, [1, 2, 3]],
                               [0.71152183, 2.35624876, 0.57260784]),
                              ([60, 15, 100, 500, [2, 3, 4]],
                               [1.56028502, 3.43916636, 0.78235384]),
                              ([60, 15, 100, 1000, [1, 2, 3]],
                               [0.67796492, 2.08982134, 0.55860785]),
                              ([60, 15, 200, 500, [1, 2, 3]],
                               [0.72377397, 2.42737471, 0.59083726]),
                              ([60, 30, 100, 500, [1, 2, 3]],
                               [0.68626344, 2.37530133, 0.60060124]),
                              ([70, 15, 100, 500, [1, 2, 3]],
                               [0.72760378, 2.18082305, 0.29141979])])
    def test_map_E_to_height_scalar_location(self, in_args, test_mapped):
        """Test mapping of E-field to a specified height.

        Parameters
        ----------
        in_args : list
            List of input arguments
        test_mapped : list
            List of expected outputs

        """
        mapped = self.apex_out.map_E_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return

    @pytest.mark.parametrize('ev_flag, test_mapped',
                             [('E', [0.71152183, 2.35624876, 0.57260784]),
                              ('V', [0.81971957, 2.84512495, 0.69545001])])
    @pytest.mark.parametrize('arr_shape', [(2,), (5,)])
    @pytest.mark.parametrize('ivec', range(0, 5))
    def test_map_EV_to_height_array_location(self, ev_flag, test_mapped,
                                             arr_shape, ivec):
        """Test mapping of E-field/drift to a specified height with arrays.

        Parameters
        ----------
        ev_flag : str
            Character flag specifying whether to run 'E' or 'V' methods
        test_mapped : list
            List of expected outputs
        arr_shape : tuple
            Shape of the vectorized location input and of each output
            component
        ivec : int
            Index of the location argument to vectorize; 4 leaves every
            location input scalar so only the E/V data are arrays

        """
        # Set the base input and output values; the E/V data have shape
        # (3, *arr_shape) with each column equal to [1, 2, 3].
        # `np.prod` is used because the `np.product` alias was removed in
        # NumPy 2.0.
        eshape = list(arr_shape)
        eshape.insert(0, 3)
        edata = np.array([[1, 2, 3]] * int(np.prod(arr_shape))).transpose()
        in_args = [60, 15, 100, 500, edata.reshape(tuple(eshape))]

        # Update inputs for one vectorized value if this is a location input
        if ivec < 4:
            in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Get the mapped output
        apex_method = getattr(self.apex_out,
                              "map_{:s}_to_height".format(ev_flag))
        mapped = apex_method(*in_args)

        # Test the results
        for i, test_val in enumerate(test_mapped):
            assert mapped[i].shape == arr_shape
            np.testing.assert_allclose(mapped[i], test_val, rtol=1e-5)
        return

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 500, [1, 2, 3]],
                               [0.81971957, 2.84512495, 0.69545001]),
                              ([60, 15, 100, 500, [2, 3, 4]],
                               [1.83027746, 4.14346436, 0.94764179]),
                              ([60, 15, 100, 1000, [1, 2, 3]],
                               [0.92457698, 3.14997661, 0.85135187]),
                              ([60, 15, 200, 500, [1, 2, 3]],
                               [0.80388262, 2.79321504, 0.68285158]),
                              ([60, 30, 100, 500, [1, 2, 3]],
                               [0.76141245, 2.87884673, 0.73655941]),
                              ([70, 15, 100, 500, [1, 2, 3]],
                               [0.84681866, 2.5925821, 0.34792655])])
    def test_map_V_to_height_scalar_location(self, in_args, test_mapped):
        """Test mapping of velocity to a specified height.

        Parameters
        ----------
        in_args : list
            List of input arguments
        test_mapped : list
            List of expected outputs

        """
        mapped = self.apex_out.map_V_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return
1358
1359
1360
class TestApexBasevectorMethods(object):
    """Test the Apex height base vector methods."""
    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.lat = 60
        self.lon = 15
        self.height = 100
        self.test_basevec = None

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.test_basevec, self.lat, self.lon, self.height

    def get_comparison_results(self, bv_coord, coords, precision):
        """Get the base vector results using the hidden function for comparison.

        Sets `self.test_basevec` to the expected base vectors.

        Parameters
        ----------
        bv_coord : str
            Basevector coordinate scheme, expects one of 'apex', 'qd',
            or 'bvectors_apex'
        coords : str
            Expects one of 'geo', 'apex', or 'qd'
        precision : float
            Float specifying precision

        """
        if coords == "geo":
            glat = self.lat
            glon = self.lon
        else:
            apex_method = getattr(self.apex_out, "{:s}2geo".format(coords))
            glat, glon, _ = apex_method(self.lat, self.lon, self.height,
                                        precision=precision)

        if bv_coord == 'qd':
            self.test_basevec = self.apex_out._basevec(glat, glon, self.height)
        elif bv_coord == 'apex':
            # Use the same height as the QD branch and the tested methods
            (_, _, _, _, f1, f2, _, d1, d2, d3, _, e1, e2,
             e3) = self.apex_out._geo2apexall(glat, glon, self.height)
            self.test_basevec = (f1, f2, d1, d2, d3, e1, e2, e3)
        else:
            # These are set results that need to be updated with IGRF
            if coords == "geo":
                self.test_basevec = (
                    np.array([4.42368795e-05, 4.42368795e-05]),
                    np.array([[0.01047826, 0.01047826],
                              [0.33089194, 0.33089194],
                              [-1.04941, -1.04941]]),
                    np.array([5.3564698e-05, 5.3564698e-05]),
                    np.array([[0.00865356, 0.00865356],
                              [0.27327004, 0.27327004],
                              [-0.8666646, -0.8666646]]))
            elif coords == "apex":
                self.test_basevec = (
                    np.array([4.48672735e-05, 4.48672735e-05]),
                    np.array([[-0.12510721, -0.12510721],
                              [0.28945938, 0.28945938],
                              [-1.1505738, -1.1505738]]),
                    np.array([6.38577444e-05, 6.38577444e-05]),
                    np.array([[-0.08790194, -0.08790194],
                              [0.2033779, 0.2033779],
                              [-0.808408, -0.808408]]))
            else:
                self.test_basevec = (
                    np.array([4.46348578e-05, 4.46348578e-05]),
                    np.array([[-0.12642345, -0.12642345],
                              [0.29695055, 0.29695055],
                              [-1.1517885, -1.1517885]]),
                    np.array([6.38626285e-05, 6.38626285e-05]),
                    np.array([[-0.08835986, -0.08835986],
                              [0.20754464, 0.20754464],
                              [-0.8050078, -0.8050078]]))

        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    @pytest.mark.parametrize("coords,precision",
                             [("geo", 1e-10), ("apex", 1.0e-2), ("qd", 1.0e-2)])
    def test_basevectors_scalar(self, bv_coord, coords, precision):
        """Test the base vector calculations with scalars.

        Parameters
        ----------
        bv_coord : str
            Name of the input coordinate system
        coords : str
            Name of the output coordinate system
        precision : float
            Level of run precision requested

        """
        # Get the base vectors
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(self.lat, self.lon, self.height, coords=coords,
                              precision=precision)
        self.get_comparison_results(bv_coord, coords, precision)
        if bv_coord == "apex":
            # Not able to compare indices 2, 3, 4, and 5
            basevec = list(basevec)
            del basevec[2:6]

        # Test the results
        for i, vec in enumerate(basevec):
            np.testing.assert_allclose(vec, self.test_basevec[i])
        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    def test_basevectors_scalar_shape(self, bv_coord):
        """Test the shape of the scalar output.

        Parameters
        ----------
        bv_coord : str
            Name of the input coordinate system

        """
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(self.lat, self.lon, self.height)

        # The first two vectors are 2-element, the rest 3-element
        for i, vec in enumerate(basevec):
            if i < 2:
                assert vec.shape == (2,)
            else:
                assert vec.shape == (3,)
        return

    @pytest.mark.parametrize('arr_shape', [(2,), (5,)])
    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    @pytest.mark.parametrize("ivec", range(3))
    def test_basevectors_array(self, arr_shape, bv_coord, ivec):
        """Test the output shape for array inputs.

        Parameters
        ----------
        arr_shape : tuple
            Shape of the vectorized input
        bv_coord : str
            Name of the input coordinate system
        ivec : int
            Index of the input argument (lat, lon, height) that is vectorized

        """
        # Define the input arguments
        in_args = [self.lat, self.lon, self.height]
        in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Get the basevectors
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(*in_args, coords='geo', precision=1e-10)
        self.get_comparison_results(bv_coord, "geo", 1e-10)
        if bv_coord == "apex":
            # Not able to compare indices 2, 3, 4, and 5
            basevec = list(basevec)
            del basevec[2:6]

        # Evaluate the shape and the values
        for i, vec in enumerate(basevec):
            test_shape = list(arr_shape)
            test_shape.insert(0, 2 if i < 2 else 3)
            assert vec.shape == tuple(test_shape)
            assert np.all(self.test_basevec[i][0] == vec[0])
            assert np.all(self.test_basevec[i][1] == vec[1])
        return

    @pytest.mark.parametrize("coords", ["geo", "apex", "qd"])
    def test_bvectors_apex(self, coords):
        """Test the bvectors_apex method.

        Parameters
        ----------
        coords : str
            Name of the coordinate system

        """
        in_args = [[self.lat, self.lat], [self.lon, self.lon],
                   [self.height, self.height]]
        self.get_comparison_results("bvectors_apex", coords, 1e-10)

        basevec = self.apex_out.bvectors_apex(*in_args, coords=coords,
                                              precision=1e-10)
        for i, vec in enumerate(basevec):
            np.testing.assert_array_almost_equal(vec, self.test_basevec[i],
                                                 decimal=5)
        return

    def test_basevectors_apex_extra_values(self):
        """Test specific values in the apex base vector output."""
        # Set the testing arrays
        self.test_basevec = [np.array([0.092637, -0.245951, 0.938848]),
                             np.array([0.939012, 0.073416, -0.07342]),
                             np.array([0.055389, 1.004155, 0.257594]),
                             np.array([0, 0, 1.065135])]

        # Get the desired output
        basevec = self.apex_out.basevectors_apex(0, 15, 100, coords='geo')

        # Test the values not covered by `test_basevectors_scalar`
        for itest, ibase in enumerate(range(2, 6)):
            np.testing.assert_allclose(basevec[ibase],
                                       self.test_basevec[itest], rtol=1e-4)
        return

    @pytest.mark.parametrize("lat", range(0, 90, 10))
    @pytest.mark.parametrize("lon", range(0, 360, 15))
    def test_basevectors_apex_delta(self, lat, lon):
        """Test that vectors are calculated correctly.

        Checks that the f/g and d/e vector sets satisfy the Kronecker delta
        relations asserted below.

        Parameters
        ----------
        lat : int or float
            Latitude in degrees N
        lon : int or float
            Longitude in degrees E

        """
        # Get the apex base vectors and sort them for easy testing
        (f1, f2, f3, g1, g2, g3, d1, d2, d3, e1, e2,
         e3) = self.apex_out.basevectors_apex(lat, lon, 500)
        # f1 and f2 are 2-element, so pad them to 3 elements
        fvec = [np.append(f1, 0), np.append(f2, 0), f3]
        gvec = [g1, g2, g3]
        dvec = [d1, d2, d3]
        evec = [e1, e2, e3]

        for idelta in range(3):
            for jdelta in range(3):
                delta = 1 if idelta == jdelta else 0
                np.testing.assert_allclose(
                    np.sum(fvec[idelta] * gvec[jdelta]), delta, rtol=0,
                    atol=1e-5)
                np.testing.assert_allclose(
                    np.sum(dvec[idelta] * evec[jdelta]), delta, rtol=0,
                    atol=1e-5)
        return

    def test_basevectors_apex_invalid_scalar(self):
        """Test warning and fill values for base vectors with bad inputs."""
        self.apex_out = apexpy.Apex(date=2000, refh=10000)
        invalid = np.full(shape=(3,), fill_value=np.nan)

        # Get the output and the warnings
        with warnings.catch_warnings(record=True) as warn_rec:
            basevec = self.apex_out.basevectors_apex(0, 0, 0)

        # The first two vectors stay valid, the remainder are NaN-filled
        for i, bvec in enumerate(basevec):
            if i < 2:
                assert not np.allclose(bvec, invalid[:2])
            else:
                np.testing.assert_allclose(bvec, invalid)

        # Fail with a clear message instead of an IndexError if nothing warned
        assert len(warn_rec) > 0, "expected a UserWarning to be raised"
        assert issubclass(warn_rec[-1].category, UserWarning)
        assert 'set to NaN where' in str(warn_rec[-1].message)
        return
1616
1617
1618
class TestApexGetMethods(object):
    """Test the Apex `get` methods."""
    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out

    @pytest.mark.parametrize("alat, aheight",
                             [(10, 507.409702543805),
                              (60, 20313.026999999987),
                              ([10, 60],
                               [507.409702543805, 20313.026999999987]),
                              ([[10], [60]],
                               [[507.409702543805], [20313.026999999987]])])
    def test_get_apex(self, alat, aheight):
        """Test the apex height retrieval results.

        Parameters
        ----------
        alat : int, float, or array-like
            Apex latitude(s) in degrees N
        aheight : float or array-like
            Expected apex height(s) in km, shaped like `alat`

        """
        alt = self.apex_out.get_apex(alat)
        np.testing.assert_allclose(alt, aheight)
        return

    @pytest.mark.parametrize("glat,glon,height,test_bmag",
                             [([80], [100], [300], 5.100682377815247e-05),
                              ([80, 80], [100], [300],
                               [5.100682377815247e-05, 5.100682377815247e-05]),
                              ([[80], [80]], [100], [300],
                               [[5.100682377815247e-05],
                                [5.100682377815247e-05]]),
                              (range(50, 90, 8), range(0, 360, 80), [300] * 5,
                               np.array([4.18657154e-05, 5.11118114e-05,
                                         4.91969854e-05, 5.10519207e-05,
                                         4.90054816e-05])),
                              (90.0, 0, 1000, 3.7834718823432923e-05)])
    def test_get_babs(self, glat, glon, height, test_bmag):
        """Test the method to get the magnitude of the magnetic field.

        Parameters
        ----------
        glat : int, float, or array-like
            Latitude(s) in degrees N
        glon : int, float, or array-like
            Longitude(s) in degrees E
        height : int, float, or array-like
            Height(s) in km
        test_bmag : float or array-like
            Expected B field magnitude(s)

        """
        bmag = self.apex_out.get_babs(glat, glon, height)
        # NOTE(review): atol=1e-5 is a large fraction of the ~5e-5 expected
        # magnitudes, so this comparison is loose; consider a relative
        # tolerance instead.
        np.testing.assert_allclose(bmag, test_bmag, rtol=0, atol=1e-5)
        return

    @pytest.mark.parametrize("bad_lat", [(91), (-91)])
    def test_get_apex_with_invalid_lat(self, bad_lat):
        """Test `get_apex` raises ValueError for latitudes outside [-90, 90].

        Parameters
        ----------
        bad_lat : int or float
            Bad input latitude in degrees N

        """
        with pytest.raises(ValueError) as verr:
            self.apex_out.get_apex(bad_lat)

        assert "must be in [-90, 90]" in str(verr.value)
        return

    @pytest.mark.parametrize("bad_lat", [(91), (-91)])
    def test_get_babs_with_invalid_lat(self, bad_lat):
        """Test `get_babs` raises ValueError for latitudes outside [-90, 90].

        Parameters
        ----------
        bad_lat : int or float
            Bad input latitude in degrees N

        """
        with pytest.raises(ValueError) as verr:
            self.apex_out.get_babs(bad_lat, 15, 100)

        assert "must be in [-90, 90]" in str(verr.value)
        return

    @pytest.mark.parametrize("bound_lat", [(90), (-90)])
    def test_get_at_lat_boundary(self, bound_lat):
        """Test `get_apex` at the latitude boundary, with allowed excess.

        Parameters
        ----------
        bound_lat : int or float
            Boundary input latitude in degrees N

        """
        # Get a latitude just beyond the limit
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = self.apex_out.get_apex(bound_lat)
        excess_out = self.apex_out.get_apex(excess_lat)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    @pytest.mark.parametrize("apex_height", [-100, 0, 300, 10000])
    def test_get_height_at_equator(self, apex_height):
        """Test that `get_height` returns apex height at equator.

        Parameters
        ----------
        apex_height : float
            Apex height

        """
        assert apex_height == self.apex_out.get_height(0.0, apex_height)
        return

    @pytest.mark.parametrize("lat, height", [
        (-90, -6371.009), (-80, -6088.438503309167), (-70, -5274.8091854339655),
        (-60, -4028.256749999999), (-50, -2499.1338178752017),
        (-40, -871.8751821247979), (-30, 657.2477500000014),
        (-20, 1903.8001854339655), (-10, 2717.4295033091657), (0, 3000.0),
        (10, 2717.4295033091657), (20, 1903.8001854339655),
        (30, 657.2477500000014), (40, -871.8751821247979),
        (50, -2499.1338178752017), (60, -4028.256749999999),
        (70, -5274.8091854339655), (80, -6088.438503309167)])
    def test_get_height_along_fieldline(self, lat, height):
        """Test that `get_height` returns expected height of field line.

        Parameters
        ----------
        lat : float
            Input latitude
        height : float
            Output field-line height for line with apex of 3000 km

        """
        fheight = self.apex_out.get_height(lat, 3000.0)
        assert abs(height - fheight) < 1.0e-7, \
            "bad height calculation: {:.7f} != {:.7f}".format(height, fheight)
        return
1775
1776
1777
class TestApexMethodExtrapolateIGRF(object):
1778
    """Test the Apex methods on a year when IGRF must be extrapolated.
1779
1780
    Notes
1781
    -----
1782
    Extrapolation should be done using a year within 5 years of the latest IGRF
1783
    model epoch.
1784
1785
    """
1786
1787
    def setup_method(self):
        """Build the Apex object and scalar inputs shared by every test.

        The Apex object uses a 2025 epoch and a reference height of 300.
        The inputs are one latitude/longitude/altitude triple and one time.

        """
        self.apex_out = apexpy.Apex(date=2025, refh=300)

        # Scalar geographic inputs: latitude, longitude, altitude
        self.in_lat, self.in_lon, self.in_alt = 60, 15, 100
        self.in_time = dt.datetime(2024, 2, 3, 4, 5, 6)
        return
1795
1796
    def teardown_method(self):
        """Clean up after each test.

        Removes every attribute created in `setup_method`, including
        `in_time`, so no state is carried over between tests.

        """
        del self.apex_out, self.in_lat, self.in_lon, self.in_alt, self.in_time
        return
1800
1801 View Code Duplication
    @pytest.mark.parametrize("method_name, out_comp",
                             [("geo2apex",
                               (56.25343704223633, 92.04932403564453)),
                              ("apex2geo",
                               (53.84184265136719, -66.93045806884766,
                                3.6222547805664362e-06)),
                              ("geo2qd",
                               (56.82968521118164, 92.04932403564453)),
                              ("apex2qd", (60.498401178276744, 15.0)),
                              ("qd2apex", (59.49138097045895, 15.0))])
    def test_method_scalar_input(self, method_name, out_comp):
        """Compare a scalar-input conversion against stored reference values.

        Parameters
        ----------
        method_name : str
            Name of the Apex conversion method under test
        out_comp : tuple of floats
            Reference output values for the scalar inputs

        """
        converter = getattr(self.apex_out, method_name)
        converted = converter(self.in_lat, self.in_lon, self.in_alt)

        np.testing.assert_allclose(converted, out_comp, rtol=1e-5, atol=1e-5)

        # Scalar inputs must produce scalar (0-d) outputs
        for value in converted:
            value_shape = np.asarray(value).shape
            assert value_shape == (), "output is not a scalar"
        return
1834
1835
    def test_convert_to_mlt(self):
1836
        """Test conversion from mlon to mlt with scalars."""
1837
1838
        # Get user output
1839
        user_out = self.apex_out.mlon2mlt(self.in_lon, self.in_time)
1840
1841
        # Set comparison values
1842
        out_comp = 23.955474853515625
1843
1844
        # Evaluate user output
1845
        np.testing.assert_allclose(user_out, out_comp, rtol=1e-5, atol=1e-5)
1846
        return
1847