Completed
Push — main ( 09bd28...fcb71d )
by Angeline
01:11 queued 01:03
created

test_Apex.TestApexGetMethods.teardown_method()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
# -*- coding: utf-8 -*-
2
"""Test the apexpy.Apex class
3
4
Notes
5
-----
6
Whenever function outputs are tested against hard-coded numbers, the test
7
results (numbers) were obtained by running the code that is tested.  Therefore,
8
these tests below only check that nothing changes when refactoring, etc., and
9
not if the results are actually correct.
10
11
These results are expected to change when IGRF is updated.
12
13
"""
14
15
import datetime as dt
16
import numpy as np
17
import os
18
import pytest
19
import shutil
20
import warnings
21
22
import apexpy
23
24
25
class TestApexInit(object):
    """Test class for the Apex class object."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = None
        # Naive UTC time (same value `utcnow` gave); `dt.datetime.utcnow` is
        # deprecated as of Python 3.12
        self.test_date = dt.datetime.now(tz=dt.timezone.utc).replace(
            tzinfo=None)
        self.test_refh = 0
        self.bad_file = 'foo/path/to/datafile.blah'
        self.temp_file = None

    def teardown_method(self):
        """Clean up after each test, including any temporary file copy."""
        # Remove temporary copies here so they are cleaned up even when the
        # test that created them fails before the end of its body
        if self.temp_file is not None and os.path.isfile(self.temp_file):
            os.remove(self.temp_file)

        del self.apex_out, self.test_date, self.test_refh
        del self.bad_file, self.temp_file

    def eval_date(self):
        """Evaluate the times in self.test_date and self.apex_out.

        Date and datetime test values are converted to decimal years before
        being compared against `self.apex_out.year`; numeric test values are
        compared directly.
        """
        # `dt.datetime` is a subclass of `dt.date`, so one check covers both
        if isinstance(self.test_date, dt.date):
            self.test_date = apexpy.helpers.toYearFraction(self.test_date)

        # Assert the times are the same on the order of tens of seconds.
        # Necessary to evaluate the current UTC
        np.testing.assert_almost_equal(self.test_date, self.apex_out.year, 6)
        return

    def eval_refh(self):
        """Evaluate the reference height in self.refh and self.apex_out."""
        eval_str = "".join(["expected reference height [",
                            "{:}] not equal to Apex ".format(self.test_refh),
                            "reference height ",
                            "[{:}]".format(self.apex_out.refh)])
        assert self.test_refh == self.apex_out.refh, eval_str
        return

    def test_init_defaults(self):
        """Test Apex class default initialization."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        return

    def test_init_today(self):
        """Test Apex class initialization with today's date."""
        self.apex_out = apexpy.Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("in_date",
                             [2015, 2015.5, dt.date(2015, 1, 1),
                              dt.datetime(2015, 6, 1, 18, 23, 45)])
    def test_init_date(self, in_date):
        """Test Apex class with date initialization.

        Parameters
        ----------
        in_date : int, float, dt.date, or dt.datetime
            Input date in a variety of formats

        """
        self.test_date = in_date
        self.apex_out = apexpy.Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("new_date", [2015, 2015.5])
    def test_set_epoch(self, new_date):
        """Test successful setting of Apex epoch after initialization.

        Parameters
        ----------
        new_date : int or float
            New date for the Apex class

        """
        # Evaluate the default initialization
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Update the epoch; the reference copy is rebuilt from the repr string
        # of a locally created object, so `eval` only sees trusted input
        ref_apex = eval(repr(self.apex_out))
        self.apex_out.set_epoch(new_date)
        assert ref_apex != self.apex_out
        self.test_date = new_date
        self.eval_date()
        return

    @pytest.mark.parametrize("in_refh", [0.0, 300.0, 30000.0, -1.0])
    def test_init_refh(self, in_refh):
        """Test Apex class with reference height initialization.

        Parameters
        ----------
        in_refh : float
            Input reference height in km

        """
        self.test_refh = in_refh
        self.apex_out = apexpy.Apex(refh=self.test_refh)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("new_refh", [0.0, 300.0, 30000.0, -1.0])
    def test_set_refh(self, new_refh):
        """Test the method used to set the reference height after the init.

        Parameters
        ----------
        new_refh : float
            Reference height in km

        """
        # Verify the defaults are set
        self.apex_out = apexpy.Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()

        # Update to a new reference height and test
        ref_apex = eval(repr(self.apex_out))
        self.apex_out.set_refh(new_refh)

        if self.test_refh == new_refh:
            assert ref_apex == self.apex_out
        else:
            assert ref_apex != self.apex_out
            self.test_refh = new_refh
        self.eval_refh()
        return

    def copy_file(self, original, max_attempts=100):
        """Make a temporary copy of an input file in the current directory.

        The copy is named 'temp' plus the extension of `original`, and its
        name is stored in `self.temp_file` so `teardown_method` can remove it.

        Parameters
        ----------
        original : str
            Path of the file to copy
        max_attempts : int
            Maximum number of copy attempts (default=100)

        Returns
        -------
        temp_file : str
            Name of the temporary copy

        Raises
        ------
        OSError
            If the file could not be copied within `max_attempts` tries

        """
        _, ext = os.path.splitext(original)
        self.temp_file = 'temp' + ext

        last_err = None
        for _ in range(max_attempts):
            try:
                shutil.copy(original, self.temp_file)
                return self.temp_file
            except OSError as err:
                # Copies may fail transiently (e.g., a locked file); retry
                last_err = err

        # Fail loudly instead of handing back a path to a missing file
        raise OSError("unable to copy {:} to {:} in {:} attempts".format(
            original, self.temp_file, max_attempts)) from last_err

    def test_default_datafile(self):
        """Test that the class initializes with the default datafile."""
        self.apex_out = apexpy.Apex()
        assert os.path.isfile(self.apex_out.datafile)
        return

    def test_custom_datafile(self):
        """Test that the class initializes with a good datafile input."""

        # Get the original datafile name
        apex_out_orig = apexpy.Apex()
        original_file = apex_out_orig.datafile
        del apex_out_orig

        # Create copy of datafile; it is removed in `teardown_method`
        custom_file = self.copy_file(original_file)

        apex_out = apexpy.Apex(datafile=custom_file)
        assert apex_out.datafile == custom_file
        return

    def test_init_with_bad_datafile(self):
        """Test raises IOError with non-existent datafile input."""
        with pytest.raises(IOError) as oerr:
            apexpy.Apex(datafile=self.bad_file)
        assert str(oerr.value).startswith('Data file does not exist')
        return

    def test_default_fortranlib(self):
        """Test that the class initializes with the default fortranlib."""
        apex_out = apexpy.Apex()
        assert os.path.isfile(apex_out.fortranlib)
        return

    def test_custom_fortranlib(self):
        """Test that the class initializes with a good fortranlib input."""

        # Get the original fortranlib name
        apex_out_orig = apexpy.Apex()
        original_lib = apex_out_orig.fortranlib
        del apex_out_orig

        # Create copy of fortranlib; it is removed in `teardown_method`
        custom_lib = self.copy_file(original_lib)

        apex_out = apexpy.Apex(fortranlib=custom_lib)
        assert apex_out.fortranlib == custom_lib
        return

    def test_init_with_bad_fortranlib(self):
        """Test raises IOError with non-existent fortranlib input."""
        with pytest.raises(IOError) as oerr:
            apexpy.Apex(fortranlib=self.bad_file)
        assert str(oerr.value).startswith('Fortran library does not exist')
        return

    def test_igrf_fn(self):
        """Test the default igrf_fn."""
        apex_out = apexpy.Apex()
        assert os.path.isfile(apex_out.igrf_fn)
        return

    def test_repr_eval(self):
        """Test the Apex.__repr__ results."""
        # Initialize the apex object
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Get and test the repr string
        out_str = self.apex_out.__repr__()
        assert out_str.startswith("apexpy.Apex(")

        # Test the ability to re-create the apex object from the repr string
        new_apex = eval(out_str)
        assert new_apex == self.apex_out
        return

    def test_ne_other_class(self):
        """Test Apex class inequality to a different class."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        assert self.apex_out != self.test_date
        return

    def test_ne_missing_attr(self):
        """Test Apex class inequality when attributes are missing from one."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        ref_apex = eval(repr(self.apex_out))
        del ref_apex.RE

        assert ref_apex != self.apex_out
        assert self.apex_out != ref_apex
        return

    def test_eq_missing_attr(self):
        """Test Apex class equality when attributes are missing from both."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        ref_apex = eval(repr(self.apex_out))
        del ref_apex.RE, self.apex_out.RE

        assert ref_apex == self.apex_out
        return

    def test_str_eval(self):
        """Test the Apex.__str__ results."""
        # Initialize the apex object
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Get and test the printed string
        out_str = self.apex_out.__str__()
        assert out_str.find("Decimal year") > 0
        return
298
299
300
class TestApexMethod(object):
    """Test the Apex methods."""
    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.in_lat = 60
        self.in_lon = 15
        self.in_alt = 100

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.in_lat, self.in_lon, self.in_alt

    def get_input_args(self, method_name, precision=0.0):
        """Set the input arguments for the different Apex methods.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method
        precision : float
            Value for the precision (default=0.0)

        Returns
        -------
        in_args : list
            List of the appropriate input arguments

        """
        in_args = [self.in_lat, self.in_lon, self.in_alt]

        # Add precision, if needed
        if method_name in ["_qd2geo", "apxq2g", "apex2geo", "qd2geo",
                           "_apex2geo"]:
            in_args.append(precision)

        # Add a reference height, if needed
        if method_name in ["apxg2all"]:
            in_args.append(300)

        # Add a vector flag, if needed
        if method_name in ["apxg2all", "apxg2q"]:
            in_args.append(1)

        return in_args

    def test_apex_conversion_today(self):
        """Test Apex class conversion with today's date."""
        # Naive UTC time (same value `utcnow` gave); `dt.datetime.utcnow` is
        # deprecated as of Python 3.12
        today = dt.datetime.now(tz=dt.timezone.utc).replace(tzinfo=None)
        self.apex_out = apexpy.Apex(date=today, refh=300)
        assert not np.isnan(self.apex_out.geo2apex(self.in_lat, self.in_lon,
                                                   self.in_alt)).any()
        return

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                              ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                              ("_qd2geo", "apxq2g", slice(None)),
                              ("_basevec", "apxg2q", slice(2, 4, 1))])
    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon", [-179, -90, 0, 90, 180])
    def test_fortran_scalar_input(self, apex_method, fortran_method, fslice,
                                  lat, lon):
        """Tests Apex/fortran interface consistency for scalars.

        Parameters
        ----------
        apex_method : str
            Name of the Apex class method to test
        fortran_method : str
            Name of the Fortran function to test
        fslice : slice
            Slice used to select the appropriate Fortran outputs
        lat : int or float
            Latitude in degrees N
        lon : int or float
            Longitude in degrees E

        """
        # Set the input coordinates
        self.in_lat = lat
        self.in_lon = lon

        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Get the appropriate input arguments
        apex_args = self.get_input_args(apex_method)
        fortran_args = self.get_input_args(fortran_method)

        # Evaluate the equivalent function calls
        np.testing.assert_allclose(apex_func(*apex_args),
                                   fortran_func(*fortran_args)[fslice])
        return

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                              ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                              ("_qd2geo", "apxq2g", slice(None)),
                              ("_basevec", "apxg2q", slice(2, 4, 1))])
    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon1,lon2", [(180, 180), (-180, -180),
                                           (180, -180), (-180, 180),
                                           (-345, 15), (375, 15)])
    def test_fortran_longitude_rollover(self, apex_method, fortran_method,
                                        fslice, lat, lon1, lon2):
        """Tests Apex/fortran interface consistency for longitude rollover.

        Parameters
        ----------
        apex_method : str
            Name of the Apex class method to test
        fortran_method : str
            Name of the Fortran function to test
        fslice : slice
            Slice used to select the appropriate Fortran outputs
        lat : int or float
            Latitude in degrees N
        lon1 : int or float
            Longitude in degrees E
        lon2 : int or float
            Equivalent longitude in degrees E

        """
        # Set the fixed input coordinate
        self.in_lat = lat

        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Get the appropriate input arguments
        self.in_lon = lon1
        apex_args = self.get_input_args(apex_method)

        self.in_lon = lon2
        fortran_args = self.get_input_args(fortran_method)

        # Evaluate the equivalent function calls
        np.testing.assert_allclose(apex_func(*apex_args),
                                   fortran_func(*fortran_args)[fslice])
        return

    @pytest.mark.parametrize("arr_shape", [(2, 2), (4,), (1, 4)])
    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                              ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                              ("_qd2geo", "apxq2g", slice(None)),
                              ("_basevec", "apxg2q", slice(2, 4, 1))])
    def test_fortran_array_input(self, arr_shape, apex_method, fortran_method,
                                 fslice):
        """Tests Apex/fortran interface consistency for array input.

        Parameters
        ----------
        arr_shape : tuple
            Expected output shape
        apex_method : str
            Name of the Apex class method to test
        fortran_method : str
            Name of the Fortran function to test
        fslice : slice
            Slice used to select the appropriate Fortran outputs

        """
        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Set up the input arrays
        ref_lat = np.array([0, 30, 60, 90])
        ref_alt = np.array([100, 200, 300, 400])
        self.in_lat = ref_lat.reshape(arr_shape)
        self.in_alt = ref_alt.reshape(arr_shape)
        apex_args = self.get_input_args(apex_method)

        # Get the Apex class results
        aret = apex_func(*apex_args)

        # Get the fortran function results
        flats = list()
        flons = list()

        for i, lat in enumerate(ref_lat):
            self.in_lat = lat
            self.in_alt = ref_alt[i]
            fortran_args = self.get_input_args(fortran_method)
            fret = fortran_func(*fortran_args)[fslice]
            flats.append(fret[0])
            flons.append(fret[1])

        flats = np.array(flats)
        flons = np.array(flons)

        # Evaluate results
        try:
            # This returned value is array of floats
            np.testing.assert_allclose(aret[0].astype(float),
                                       flats.reshape(arr_shape).astype(float))
            np.testing.assert_allclose(aret[1].astype(float),
                                       flons.reshape(arr_shape).astype(float))
        except ValueError:
            # This returned value is array of arrays
            alats = aret[0].reshape((4,))
            alons = aret[1].reshape((4,))
            for i, flat in enumerate(flats):
                np.testing.assert_array_almost_equal(alats[i], flat, 2)
                np.testing.assert_array_almost_equal(alons[i], flons[i], 2)

        return

    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon", [-179, -90, 0, 90, 180])
    def test_geo2apexall_scalar(self, lat, lon):
        """Test Apex/fortran geo2apexall interface consistency for scalars.

        Parameters
        ----------
        lat : int or float
            Latitude in degrees N
        lon : int or float
            Longitude in degrees E

        """
        # Get the Apex and Fortran results
        aret = self.apex_out._geo2apexall(lat, lon, self.in_alt)
        fret = apexpy.fortranapex.apxg2all(lat, lon, self.in_alt, 300, 1)

        # Evaluate each element in the results
        for aval, fval in zip(aret, fret):
            np.testing.assert_allclose(aval, fval)
        return

    @pytest.mark.parametrize("arr_shape", [(2, 2), (4,), (1, 4)])
    def test_geo2apexall_array(self, arr_shape):
        """Test Apex/fortran geo2apexall interface consistency for arrays.

        Parameters
        ----------
        arr_shape : tuple
            Expected output shape

        """
        # Set the input
        self.in_lat = np.array([0, 30, 60, 90])
        self.in_alt = np.array([100, 200, 300, 400])

        # Get the Apex class results
        aret = self.apex_out._geo2apexall(self.in_lat.reshape(arr_shape),
                                          self.in_lon,
                                          self.in_alt.reshape(arr_shape))

        # For each lat/alt pair, get the Fortran results
        fret = list()
        for i, lat in enumerate(self.in_lat):
            fret.append(apexpy.fortranapex.apxg2all(lat, self.in_lon,
                                                    self.in_alt[i], 300, 1))

        # Cycle through all returned values
        for i, ret in enumerate(aret):
            try:
                # This returned value is array of floats
                fret_test = np.array([fret[0][i], fret[1][i], fret[2][i],
                                      fret[3][i]]).reshape(arr_shape)
                np.testing.assert_allclose(ret.astype(float),
                                           fret_test.astype(float))
            except ValueError:
                # This returned value is array of arrays
                ret = ret.reshape((4,))
                for j, single_fret in enumerate(fret):
                    np.testing.assert_allclose(ret[j], single_fret[i])
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_consistency(self, in_coord, out_coord):
        """Test the self-consistency of the Apex convert method.

        Parameters
        ----------
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Define the method name
        method_name = "2".join([in_coord, out_coord])

        # Get the method and method inputs
        convert_kwargs = {'height': self.in_alt, 'precision': 0.0}
        apex_args = self.get_input_args(method_name)
        apex_method = getattr(self.apex_out, method_name)

        # Define the slice needed to get equivalent output from the named method
        mslice = slice(0, -1, 1) if out_coord == "geo" else slice(None)

        # Get output using convert and named method
        convert_out = self.apex_out.convert(self.in_lat, self.in_lon, in_coord,
                                            out_coord, **convert_kwargs)
        method_out = apex_method(*apex_args)[mslice]

        # Compare both outputs, should be identical
        np.testing.assert_allclose(convert_out, method_out)
        return

    @pytest.mark.parametrize("bound_lat", [90, -90])
    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_at_lat_boundary(self, bound_lat, in_coord, out_coord):
        """Test the conversion at the latitude boundary, with allowed excess.

        Parameters
        ----------
        bound_lat : int or float
            Boundary latitude in degrees N
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system

        """
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = self.apex_out.convert(bound_lat, 0, in_coord, out_coord)
        excess_out = self.apex_out.convert(excess_lat, 0, in_coord, out_coord)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    def test_convert_qd2apex_at_equator(self):
        """Test the quasi-dipole to apex conversion at the magnetic equator."""
        eq_out = self.apex_out.convert(lat=0.0, lon=0, source='qd', dest='apex',
                                       height=320.0)
        close_out = self.apex_out.convert(lat=0.001, lon=0, source='qd',
                                          dest='apex', height=320.0)
        np.testing.assert_allclose(eq_out, close_out, atol=1e-4)
        return

    @pytest.mark.parametrize("src", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("dest", ["geo", "apex", "qd"])
    def test_convert_withnan(self, src, dest):
        """Test Apex.convert success with NaN input.

        Parameters
        ----------
        src : str
            Input coordinate system
        dest : str
            Output coordinate system

        """
        if src == dest:
            pytest.skip("Test not needed for same src and dest coordinates")

        num_nans = 5
        in_loc = np.arange(0, 10, dtype=float)
        in_loc[:num_nans] = np.nan

        out_loc = self.apex_out.convert(in_loc, in_loc, src, dest, height=320)

        for out in out_loc:
            assert np.all(np.isnan(out[:num_nans])), "NaN output expected"
            assert np.all(np.isfinite(out[num_nans:])), "Finite output expected"

        return

    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_convert_invalid_lat(self, bad_lat):
        """Test convert raises ValueError for invalid latitudes.

        Parameters
        ----------
        bad_lat : int or float
            Latitude outside the supported range in degrees N

        """

        with pytest.raises(ValueError) as verr:
            self.apex_out.convert(bad_lat, 0, 'geo', 'geo')

        assert str(verr.value).find("must be in [-90, 90]") > 0
        return

    @pytest.mark.parametrize("coords", [("foobar", "geo"), ("geo", "foobar"),
                                        ("geo", "mlt")])
    def test_convert_invalid_transformation(self, coords):
        """Test raises ValueError for bad coordinates.

        Parameters
        ----------
        coords : tuple
            Tuple specifying the input and output coordinate systems

        """
        if "mlt" in coords:
            estr = "datetime must be given for MLT calculations"
        else:
            estr = "Unknown coordinate transformation"

        with pytest.raises(ValueError) as verr:
            self.apex_out.convert(0, 0, *coords)

        # Inspect the exception message itself, not the ExceptionInfo wrapper
        assert estr in str(verr.value)
        return

    @pytest.mark.parametrize("method_name, out_comp",
                             [("geo2apex",
                               (55.94841766357422, 94.10684204101562)),
                              ("apex2geo",
                               (51.476322174072266, -66.22817993164062,
                                5.727287771151168e-06)),
                              ("geo2qd",
                               (56.531288146972656, 94.10684204101562)),
                              ("apex2qd", (60.498401178276744, 15.0)),
                              ("qd2apex", (59.49138097045895, 15.0))])
    def test_method_scalar_input(self, method_name, out_comp):
        """Test the user method against set values with scalars.

        Parameters
        ----------
        method_name : str
            Apex class method to be tested
        out_comp : tuple of floats
            Expected output values

        """
        # Get the desired methods
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(self.in_lat, self.in_lon, self.in_alt)

        # Evaluate the user output
        np.testing.assert_allclose(user_out, out_comp, rtol=1e-5, atol=1e-5)

        for out_val in user_out:
            assert np.asarray(out_val).shape == (), "output is not a scalar"
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("method_args, out_shape",
                             [([[60, 60], 15, 100], (2,)),
                              ([60, [15, 15], 100], (2,)),
                              ([60, 15, [100, 100]], (2,)),
                              ([[50, 60], [15, 16], [100, 200]], (2,))])
    def test_method_broadcast_input(self, in_coord, out_coord, method_args,
                                    out_shape):
        """Test the user method with inputs that require some broadcasting.

        Parameters
        ----------
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system
        method_args : list
            List of input arguments
        out_shape : tuple
            Expected shape of output values

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(*method_args)

        # Evaluate the user output
        for out_val in user_out:
            assert hasattr(out_val, 'shape'), "output coordinate isn't np.array"
            assert out_val.shape == out_shape
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_method_invalid_lat(self, in_coord, out_coord, bad_lat):
        """Test the user methods raise ValueError for invalid latitudes.

        Parameters
        ----------
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system
        bad_lat : int
            Latitude in degrees N that is out of bounds

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        with pytest.raises(ValueError) as verr:
            user_method(bad_lat, 15, 100)

        assert str(verr.value).find("must be in [-90, 90]") > 0
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("bound_lat", [90, -90])
    def test_method_at_lat_boundary(self, in_coord, out_coord, bound_lat):
        """Test user methods at the latitude boundary, with allowed excess.

        Parameters
        ----------
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system
        bound_lat : int
            Latitude in degrees N that is at the limits of the boundary

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        # Get a latitude just beyond the limit
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = user_method(bound_lat, 0, 100)
        excess_out = user_method(excess_lat, 0, 100)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    def test_geo2apex_undefined_warning(self):
        """Test geo2apex warning and fill values for an undefined location."""

        # Update the apex object
        self.apex_out = apexpy.Apex(date=2000, refh=10000)

        # Get the output and the warnings
        with warnings.catch_warnings(record=True) as warn_rec:
            user_lat, user_lon = self.apex_out.geo2apex(0, 0, 0)

        assert np.isnan(user_lat)
        assert np.isfinite(user_lon)
        assert len(warn_rec) == 1
        assert issubclass(warn_rec[-1].category, UserWarning)
        assert 'latitude set to NaN where' in str(warn_rec[-1].message)
        return

    @pytest.mark.parametrize("method_name", ["apex2qd", "qd2apex"])
    @pytest.mark.parametrize("delta_h", [1.0e-6, -1.0e-6])
    def test_quasidipole_apexheight_close(self, method_name, delta_h):
        """Test quasi-dipole success with a height close to the reference.

        Parameters
        ----------
        method_name : str
            Apex class method name to be tested
        delta_h : float
            Offset from the reference height in km

        """
        qd_method = getattr(self.apex_out, method_name)
        in_args = [0, 15, self.apex_out.refh + delta_h]
        out_coords = qd_method(*in_args)

        for i, out_val in enumerate(out_coords):
            np.testing.assert_almost_equal(out_val, in_args[i], decimal=3)
        return

    @pytest.mark.parametrize("method_name, hinc, msg",
                             [("apex2qd", 1.0, "is > apex height"),
                              ("qd2apex", -1.0, "is < reference height")])
    def test_quasidipole_raises_apexheight(self, method_name, hinc, msg):
        """Quasi-dipole raises ApexHeightError for an out-of-range height.

        `apex2qd` is tested with a height above the apex height, and
        `qd2apex` with a height below the reference height.

        Parameters
        ----------
        method_name : str
            Apex class method name to be tested
        hinc : float
            Height increment in km
        msg : str
            Expected output message

        """
        qd_method = getattr(self.apex_out, method_name)

        with pytest.raises(apexpy.ApexHeightError) as aerr:
            qd_method(0, 15, self.apex_out.refh + hinc)

        # Inspect the exception message itself, not the ExceptionInfo wrapper
        assert msg in str(aerr.value)
        return
907
908
909
class TestApexMLTMethods(object):
    """Test the Apex Magnetic Local Time (MLT) methods."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.in_time = dt.datetime(2000, 2, 3, 4, 5, 6)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.in_time

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    def test_convert_to_mlt(self, in_coord):
        """Test the conversions to MLT using Apex convert.

        Parameters
        ----------
        in_coord : str
            Input coordinate system

        """
        # Get the magnetic longitude.  Geographic inputs are converted to
        # apex coordinates first; for apex and quasi-dipole inputs the input
        # longitude is used directly.
        if in_coord == "geo":
            mlon = self.apex_out.geo2apex(60, 15, 100)[1]
        else:
            mlon = 15

        # Get the output MLT values
        convert_mlt = self.apex_out.convert(60, 15, in_coord, 'mlt',
                                            height=100, ssheight=2e5,
                                            datetime=self.in_time)[1]
        method_mlt = self.apex_out.mlon2mlt(mlon, self.in_time, ssheight=2e5)

        # Test the outputs
        np.testing.assert_allclose(convert_mlt, method_mlt)
        return

    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_mlt_to_lon(self, out_coord):
        """Test the conversions from MLT using Apex convert.

        Parameters
        ----------
        out_coord : str
            Output coordinate system

        """
        # Get the output longitudes
        convert_out = self.apex_out.convert(60, 15, 'mlt', out_coord,
                                            height=100, ssheight=2e5,
                                            datetime=self.in_time,
                                            precision=1e-2)
        mlon = self.apex_out.mlt2mlon(15, self.in_time, ssheight=2e5)

        if out_coord == "geo":
            # Drop the trailing element of the apex2geo output before
            # comparing against the convert output
            method_out = self.apex_out.apex2geo(60, mlon, 100,
                                                precision=1e-2)[:-1]
        elif out_coord == "qd":
            method_out = self.apex_out.apex2qd(60, mlon, 100)
        else:
            method_out = (60, mlon)

        # Evaluate the outputs
        np.testing.assert_allclose(convert_out, method_out)
        return

    def test_convert_geo2mlt_nodate(self):
        """Test convert from geo to MLT raises ValueError with no datetime."""
        with pytest.raises(ValueError):
            self.apex_out.convert(60, 15, 'geo', 'mlt')
        return

    @pytest.mark.parametrize("mlon_kwargs,test_mlt",
                             [({}, 23.019629923502603),
                              ({"ssheight": 100000}, 23.026712036132814)])
    def test_mlon2mlt_scalar_inputs(self, mlon_kwargs, test_mlt):
        """Test mlon2mlt with scalar inputs.

        Parameters
        ----------
        mlon_kwargs : dict
            Input kwargs
        test_mlt : float
            Output MLT in hours

        """
        mlt = self.apex_out.mlon2mlt(0, self.in_time, **mlon_kwargs)

        np.testing.assert_allclose(mlt, test_mlt)
        assert np.asarray(mlt).shape == ()
        return

    @pytest.mark.parametrize("mlt_kwargs,test_mlon",
                             [({}, 14.705535888671875),
                              ({"ssheight": 100000}, 14.599319458007812)])
    def test_mlt2mlon_scalar_inputs(self, mlt_kwargs, test_mlon):
        """Test mlt2mlon with scalar inputs.

        Parameters
        ----------
        mlt_kwargs : dict
            Input kwargs
        test_mlon : float
            Output longitude in degrees E

        """
        mlon = self.apex_out.mlt2mlon(0, self.in_time, **mlt_kwargs)

        np.testing.assert_allclose(mlon, test_mlon)
        assert np.asarray(mlon).shape == ()
        return

    @pytest.mark.parametrize("mlon,test_mlt",
                             [([0, 180], [23.019261, 11.019261]),
                              (np.array([0, 180]), [23.019261, 11.019261]),
                              (np.array([[0], [180]]),
                               np.array([[23.019261], [11.019261]])),
                              ([[0, 180], [0, 180]], [[23.019261, 11.019261],
                                                      [23.019261, 11.019261]]),
                              (range(0, 361, 30),
                               [23.01963, 1.01963, 3.01963, 5.01963, 7.01963,
                                9.01963, 11.01963, 13.01963, 15.01963, 17.01963,
                                19.01963, 21.01963, 23.01963])])
    def test_mlon2mlt_array(self, mlon, test_mlt):
        """Test mlon2mlt with array inputs.

        Parameters
        ----------
        mlon : array-like
            Input longitudes in degrees E
        test_mlt : array-like
            Expected output MLT in hours, with the same shape as `mlon`

        """
        mlt = self.apex_out.mlon2mlt(mlon, self.in_time)

        assert mlt.shape == np.asarray(test_mlt).shape
        np.testing.assert_allclose(mlt, test_mlt, rtol=1e-4)
        return

    @pytest.mark.parametrize("mlt,test_mlon",
                             [([0, 12], [14.705551, 194.705551]),
                              (np.array([0, 12]), [14.705551, 194.705551]),
                              (np.array([[0], [12]]),
                               np.array([[14.705551], [194.705551]])),
                              ([[0, 12], [0, 12]], [[14.705551, 194.705551],
                                                    [14.705551, 194.705551]]),
                              (range(0, 25, 2),
                               [14.705551, 44.705551, 74.705551, 104.705551,
                                134.705551, 164.705551, 194.705551, 224.705551,
                                254.705551, 284.705551, 314.705551, 344.705551,
                                14.705551])])
    def test_mlt2mlon_array(self, mlt, test_mlon):
        """Test mlt2mlon with array inputs.

        Parameters
        ----------
        mlt : array-like
            Input MLT in hours
        test_mlon : array-like
            Expected output longitude in degrees E, with the same shape as
            `mlt`

        """
        mlon = self.apex_out.mlt2mlon(mlt, self.in_time)

        assert mlon.shape == np.asarray(test_mlon).shape
        np.testing.assert_allclose(mlon, test_mlon, rtol=1e-4)
        return

    @pytest.mark.parametrize("method_name", ["mlon2mlt", "mlt2mlon"])
    def test_mlon2mlt_diffdates(self, method_name):
        """Test that the MLT/longitude conversions vary with universal time.

        Parameters
        ----------
        method_name : str
            Name of Apex class method to be tested

        """
        apex_method = getattr(self.apex_out, method_name)
        out_first = apex_method(0, self.in_time)
        out_later = apex_method(0, self.in_time + dt.timedelta(hours=1))

        assert out_first != out_later
        return

    @pytest.mark.parametrize("mlt_offset", [1.0, 10.0])
    def test_mlon2mlt_offset(self, mlt_offset):
        """Test the time wrapping logic for the MLT.

        Parameters
        ----------
        mlt_offset : float
            MLT offset in hours

        """
        # A westward shift of 15 degrees per hour should lower the MLT by
        # exactly `mlt_offset` hours
        mlt1 = self.apex_out.mlon2mlt(0.0, self.in_time)
        mlt2 = self.apex_out.mlon2mlt(-15.0 * mlt_offset,
                                      self.in_time) + mlt_offset

        np.testing.assert_allclose(mlt1, mlt2)
        return

    @pytest.mark.parametrize("mlon_offset", [15.0, 150.0])
    def test_mlt2mlon_offset(self, mlon_offset):
        """Test the time wrapping logic for the magnetic longitude.

        Parameters
        ----------
        mlon_offset : float
            Magnetic longitude offset in degrees

        """
        # An MLT shift of `mlon_offset / 15` hours should raise the magnetic
        # longitude by exactly `mlon_offset` degrees
        mlon1 = self.apex_out.mlt2mlon(0, self.in_time)
        mlon2 = self.apex_out.mlt2mlon(mlon_offset / 15.0,
                                       self.in_time) - mlon_offset

        np.testing.assert_allclose(mlon1, mlon2)
        return

    @pytest.mark.parametrize("order", [["mlt", "mlon"], ["mlon", "mlt"]])
    @pytest.mark.parametrize("start_val", [0, 6, 12, 18, 22])
    def test_convert_and_return(self, order, start_val):
        """Test the conversion to magnetic longitude or MLT and back again.

        Parameters
        ----------
        order : list
            List of strings specifying the order to run functions
        start_val : int or float
            Input value

        """
        first_method = getattr(self.apex_out, "2".join(order))
        second_method = getattr(self.apex_out, "2".join(order[::-1]))

        middle_val = first_method(start_val, self.in_time)
        end_val = second_method(middle_val, self.in_time)

        np.testing.assert_allclose(start_val, end_val)
        return
1152
1153
1154
class TestApexMapMethods(object):
    """Test the Apex height mapping methods."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 10000],
                               [31.841466903686523, 17.916635513305664,
                                1.7075473124350538e-6]),
                              ([30, 170, 100, 500, False, 1e-2],
                               [25.727270126342773, 169.60546875,
                                0.00017573432705830783]),
                              ([60, 15, 100, 10000, True],
                               [-25.424888610839844, 27.310426712036133,
                                1.2074182222931995e-6]),
                              ([30, 170, 100, 500, True, 1e-2],
                               [-13.76642894744873, 164.24259948730469,
                                0.00056820799363777041])])
    def test_map_to_height(self, in_args, test_mapped):
        """Test the map_to_height function.

        Parameters
        ----------
        in_args : list
            List of positional input arguments
        test_mapped : list
            List of expected outputs

        """
        mapped = self.apex_out.map_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5, atol=1e-5)
        return

    def test_map_to_height_same_height(self):
        """Test the map_to_height function when mapping to same height."""
        mapped = self.apex_out.map_to_height(60, 15, 100, 100, conjugate=False,
                                             precision=1e-10)
        np.testing.assert_allclose(mapped, (60.0, 15.000003814697266, 0.0),
                                   rtol=1e-5, atol=1e-5)
        return

    @pytest.mark.parametrize('arr_shape', [(2,), (2, 2), (1, 4)])
    @pytest.mark.parametrize('ivec', range(0, 4))
    def test_map_to_height_array_location(self, arr_shape, ivec):
        """Test map_to_height with array input.

        Parameters
        ----------
        arr_shape : tuple
            Expected array shape
        ivec : int
            Input argument index for vectorized input

        """
        # Set the base input and output values
        in_args = [60, 15, 100, 100]
        test_mapped = [60, 15.00000381, 0.0]

        # Update inputs for one vectorized value
        in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Calculate and test function
        mapped = self.apex_out.map_to_height(*in_args)
        for i, test_val in enumerate(test_mapped):
            assert mapped[i].shape == arr_shape
            np.testing.assert_allclose(mapped[i], test_val, rtol=1e-5,
                                       atol=1e-5)
        return

    @pytest.mark.parametrize("method_name,in_args",
                             [("map_to_height", [0, 15, 100, 10000]),
                              ("map_E_to_height",
                               [0, 15, 100, 10000, [1, 2, 3]]),
                              ("map_V_to_height",
                               [0, 15, 100, 10000, [1, 2, 3]])])
    def test_mapping_height_raises_ApexHeightError(self, method_name, in_args):
        """Test height mapping methods raise ApexHeightError.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method to test
        in_args : list
            List of input arguments

        """
        apex_method = getattr(self.apex_out, method_name)

        with pytest.raises(apexpy.ApexHeightError) as aerr:
            apex_method(*in_args)

        assert aerr.match("is > apex height")
        return

    @pytest.mark.parametrize("method_name",
                             ["map_E_to_height", "map_V_to_height"])
    @pytest.mark.parametrize("ev_input", [([1, 2, 3, 4, 5]),
                                          ([[1, 2], [3, 4], [5, 6], [7, 8]])])
    def test_mapping_EV_bad_shape(self, method_name, ev_input):
        """Test height mapping of E/V with badly shaped input raises Error.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method to test
        ev_input : list
            E/V input arguments

        """
        apex_method = getattr(self.apex_out, method_name)
        in_args = [60, 15, 100, 500, ev_input]
        with pytest.raises(ValueError) as verr:
            apex_method(*in_args)

        assert str(verr.value).find("must be (3, N) or (3,) ndarray") >= 0
        return

    def test_mapping_EV_bad_flag(self):
        """Test _map_EV_to_height raises error for bad data type flag."""
        with pytest.raises(ValueError) as verr:
            self.apex_out._map_EV_to_height(60, 15, 100, 500, [1, 2, 3], "P")

        assert str(verr.value).find("unknown electric field/drift flag") >= 0
        return

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 500, [1, 2, 3]],
                               [0.71152183, 2.35624876, 0.57260784]),
                              ([60, 15, 100, 500, [2, 3, 4]],
                               [1.56028502, 3.43916636, 0.78235384]),
                              ([60, 15, 100, 1000, [1, 2, 3]],
                               [0.67796492, 2.08982134, 0.55860785]),
                              ([60, 15, 200, 500, [1, 2, 3]],
                               [0.72377397, 2.42737471, 0.59083726]),
                              ([60, 30, 100, 500, [1, 2, 3]],
                               [0.68626344, 2.37530133, 0.60060124]),
                              ([70, 15, 100, 500, [1, 2, 3]],
                               [0.72760378, 2.18082305, 0.29141979])])
    def test_map_E_to_height_scalar_location(self, in_args, test_mapped):
        """Test mapping of E-field to a specified height.

        Parameters
        ----------
        in_args : list
            List of input arguments
        test_mapped : list
            List of expected outputs

        """
        mapped = self.apex_out.map_E_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return

    @pytest.mark.parametrize('ev_flag, test_mapped',
                             [('E', [0.71152183, 2.35624876, 0.57260784]),
                              ('V', [0.81971957, 2.84512495, 0.69545001])])
    @pytest.mark.parametrize('arr_shape', [(2,), (5,)])
    @pytest.mark.parametrize('ivec', range(0, 5))
    def test_map_EV_to_height_array_location(self, ev_flag, test_mapped,
                                             arr_shape, ivec):
        """Test mapping of E-field/drift to a specified height with arrays.

        Parameters
        ----------
        ev_flag : str
            Character flag specifying whether to run 'E' or 'V' methods
        test_mapped : list
            List of expected outputs
        arr_shape : tuple
            Shape of the location inputs and of each expected output
        ivec : int
            Index of the location input argument to vectorize; 4 leaves all
            location inputs scalar, so only the E/V data is an array

        """
        # Build the E/V data as a (3, *arr_shape) array of repeated [1, 2, 3].
        # `np.prod` replaces `np.product`, which was removed in NumPy 2.0.
        eshape = (3,) + tuple(arr_shape)
        edata = np.array([[1, 2, 3]] * int(np.prod(arr_shape))).transpose()
        in_args = [60, 15, 100, 500, edata.reshape(eshape)]

        # Update inputs for one vectorized value if this is a location input
        if ivec < 4:
            in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Get the mapped output
        apex_method = getattr(self.apex_out,
                              "map_{:s}_to_height".format(ev_flag))
        mapped = apex_method(*in_args)

        # Test the results
        for i, test_val in enumerate(test_mapped):
            assert mapped[i].shape == arr_shape
            np.testing.assert_allclose(mapped[i], test_val, rtol=1e-5)
        return

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 500, [1, 2, 3]],
                               [0.81971957, 2.84512495, 0.69545001]),
                              ([60, 15, 100, 500, [2, 3, 4]],
                               [1.83027746, 4.14346436, 0.94764179]),
                              ([60, 15, 100, 1000, [1, 2, 3]],
                               [0.92457698, 3.14997661, 0.85135187]),
                              ([60, 15, 200, 500, [1, 2, 3]],
                               [0.80388262, 2.79321504, 0.68285158]),
                              ([60, 30, 100, 500, [1, 2, 3]],
                               [0.76141245, 2.87884673, 0.73655941]),
                              ([70, 15, 100, 500, [1, 2, 3]],
                               [0.84681866, 2.5925821, 0.34792655])])
    def test_map_V_to_height_scalar_location(self, in_args, test_mapped):
        """Test mapping of velocity to a specified height.

        Parameters
        ----------
        in_args : list
            List of input arguments
        test_mapped : list
            List of expected outputs

        """
        mapped = self.apex_out.map_V_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return
1381
1382
1383
class TestApexBasevectorMethods(object):
    """Test the Apex height base vector methods."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.lat = 60
        self.lon = 15
        self.height = 100
        self.test_basevec = None

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.test_basevec, self.lat, self.lon, self.height

    def get_comparison_results(self, bv_coord, coords, precision):
        """Set expected base vectors from hidden functions for comparison.

        Parameters
        ----------
        bv_coord : str
            Basevector coordinate scheme, expects one of 'apex', 'qd',
            or 'bvectors_apex'
        coords : str
            Expects one of 'geo', 'apex', or 'qd'
        precision : float
            Float specifying precision

        Notes
        -----
        Results are stored in `self.test_basevec`.

        """
        # Express the test location in geographic coordinates
        if coords == "geo":
            glat = self.lat
            glon = self.lon
        else:
            apex_method = getattr(self.apex_out, "{:s}2geo".format(coords))
            glat, glon, _ = apex_method(self.lat, self.lon, self.height,
                                        precision=precision)

        if bv_coord == 'qd':
            self.test_basevec = self.apex_out._basevec(glat, glon, self.height)
        elif bv_coord == 'apex':
            # Use the same test height as every other comparison
            (_, _, _, _, f1, f2, _, d1, d2, d3, _, e1, e2,
             e3) = self.apex_out._geo2apexall(glat, glon, self.height)
            self.test_basevec = (f1, f2, d1, d2, d3, e1, e2, e3)
        else:
            # These are set results that need to be updated with IGRF
            if coords == "geo":
                self.test_basevec = (
                    np.array([4.42368795e-05, 4.42368795e-05]),
                    np.array([[0.01047826, 0.01047826],
                              [0.33089194, 0.33089194],
                              [-1.04941, -1.04941]]),
                    np.array([5.3564698e-05, 5.3564698e-05]),
                    np.array([[0.00865356, 0.00865356],
                              [0.27327004, 0.27327004],
                              [-0.8666646, -0.8666646]]))
            elif coords == "apex":
                self.test_basevec = (
                    np.array([4.48672735e-05, 4.48672735e-05]),
                    np.array([[-0.12510721, -0.12510721],
                              [0.28945938, 0.28945938],
                              [-1.1505738, -1.1505738]]),
                    np.array([6.38577444e-05, 6.38577444e-05]),
                    np.array([[-0.08790194, -0.08790194],
                              [0.2033779, 0.2033779],
                              [-0.808408, -0.808408]]))
            else:
                self.test_basevec = (
                    np.array([4.46348578e-05, 4.46348578e-05]),
                    np.array([[-0.12642345, -0.12642345],
                              [0.29695055, 0.29695055],
                              [-1.1517885, -1.1517885]]),
                    np.array([6.38626285e-05, 6.38626285e-05]),
                    np.array([[-0.08835986, -0.08835986],
                              [0.20754464, 0.20754464],
                              [-0.8050078, -0.8050078]]))

        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    @pytest.mark.parametrize("coords,precision",
                             [("geo", 1e-10), ("apex", 1.0e-2), ("qd", 1.0e-2)])
    def test_basevectors_scalar(self, bv_coord, coords, precision):
        """Test the base vector calculations with scalars.

        Parameters
        ----------
        bv_coord : str
            Name of the input coordinate system
        coords : str
            Name of the output coordinate system
        precision : float
            Level of run precision requested

        """
        # Get the base vectors
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(self.lat, self.lon, self.height, coords=coords,
                              precision=precision)
        self.get_comparison_results(bv_coord, coords, precision)
        if bv_coord == "apex":
            # The comparison values lack outputs 2-5 (f3, g1, g2, g3)
            basevec = list(basevec)
            del basevec[2:6]

        # Test the results
        for i, vec in enumerate(basevec):
            np.testing.assert_allclose(vec, self.test_basevec[i])
        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    def test_basevectors_scalar_shape(self, bv_coord):
        """Test the shape of the scalar output.

        Parameters
        ----------
        bv_coord : str
            Name of the input coordinate system

        """
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(self.lat, self.lon, self.height)

        # The first two vectors have two components, the rest have three
        for i, vec in enumerate(basevec):
            if i < 2:
                assert vec.shape == (2,)
            else:
                assert vec.shape == (3,)
        return

    @pytest.mark.parametrize('arr_shape', [(2,), (5,)])
    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    @pytest.mark.parametrize("ivec", range(3))
    def test_basevectors_array(self, arr_shape, bv_coord, ivec):
        """Test the output shape for array inputs.

        Parameters
        ----------
        arr_shape : tuple
            Shape of the vectorized input
        bv_coord : str
            Name of the input coordinate system
        ivec : int
            Index of the input argument (lat, lon, height) to vectorize

        """
        # Define the input arguments
        in_args = [self.lat, self.lon, self.height]
        in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Get the basevectors
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(*in_args, coords='geo', precision=1e-10)
        self.get_comparison_results(bv_coord, "geo", 1e-10)
        if bv_coord == "apex":
            # The comparison values lack outputs 2-5 (f3, g1, g2, g3)
            basevec = list(basevec)
            del basevec[2:6]

        # Evaluate the shape and the values
        for i, vec in enumerate(basevec):
            test_shape = (2 if i < 2 else 3,) + tuple(arr_shape)
            assert vec.shape == test_shape
            assert np.all(self.test_basevec[i][0] == vec[0])
            assert np.all(self.test_basevec[i][1] == vec[1])
        return

    @pytest.mark.parametrize("coords", ["geo", "apex", "qd"])
    def test_bvectors_apex(self, coords):
        """Test the bvectors_apex method.

        Parameters
        ----------
        coords : str
            Name of the coordinate system

        """
        in_args = [[self.lat, self.lat], [self.lon, self.lon],
                   [self.height, self.height]]
        self.get_comparison_results("bvectors_apex", coords, 1e-10)

        basevec = self.apex_out.bvectors_apex(*in_args, coords=coords,
                                              precision=1e-10)
        for i, vec in enumerate(basevec):
            np.testing.assert_array_almost_equal(vec, self.test_basevec[i],
                                                 decimal=5)
        return

    def test_basevectors_apex_extra_values(self):
        """Test specific values in the apex base vector output."""
        # Set the testing arrays
        self.test_basevec = [np.array([0.092637, -0.245951, 0.938848]),
                             np.array([0.939012, 0.073416, -0.07342]),
                             np.array([0.055389, 1.004155, 0.257594]),
                             np.array([0, 0, 1.065135])]

        # Get the desired output
        basevec = self.apex_out.basevectors_apex(0, 15, 100, coords='geo')

        # Test the values not covered by `test_basevectors_scalar`
        for itest, ibase in enumerate(range(2, 6)):
            np.testing.assert_allclose(basevec[ibase],
                                       self.test_basevec[itest], rtol=1e-4)
        return

    @pytest.mark.parametrize("lat", range(0, 90, 10))
    @pytest.mark.parametrize("lon", range(0, 360, 15))
    def test_basevectors_apex_delta(self, lat, lon):
        """Test that vectors are calculated correctly.

        Parameters
        ----------
        lat : int or float
            Latitude in degrees N
        lon : int or float
            Longitude in degrees E

        """
        # Get the apex base vectors and sort them for easy testing
        (f1, f2, f3, g1, g2, g3, d1, d2, d3, e1, e2,
         e3) = self.apex_out.basevectors_apex(lat, lon, 500)
        # f1 and f2 have two components; pad them to three
        fvec = [np.append(f1, 0), np.append(f2, 0), f3]
        gvec = [g1, g2, g3]
        dvec = [d1, d2, d3]
        evec = [e1, e2, e3]

        # Each pair of vector sets must satisfy a Kronecker-delta relation
        for idelta in range(3):
            for jdelta in range(3):
                delta = 1 if idelta == jdelta else 0
                np.testing.assert_allclose(
                    np.sum(fvec[idelta] * gvec[jdelta]), delta, rtol=0,
                    atol=1e-5)
                np.testing.assert_allclose(
                    np.sum(dvec[idelta] * evec[jdelta]), delta, rtol=0,
                    atol=1e-5)
        return

    def test_basevectors_apex_invalid_scalar(self):
        """Test warning and fill values for base vectors with bad inputs."""
        self.apex_out = apexpy.Apex(date=2000, refh=10000)
        invalid = np.full(shape=(3,), fill_value=np.nan)

        # Get the output and the warnings.  The 'always' filter ensures the
        # warning is recorded even if it was already emitted earlier.
        with warnings.catch_warnings(record=True) as warn_rec:
            warnings.simplefilter("always")
            basevec = self.apex_out.basevectors_apex(0, 0, 0)

        for i, bvec in enumerate(basevec):
            if i < 2:
                # The first two vectors are not replaced by the NaN fill.
                # (`np.allclose` against NaN is always False, so check NaNs
                # directly.)
                assert not np.isnan(bvec).all()
            else:
                np.testing.assert_allclose(bvec, invalid)

        # Select the expected warning rather than relying on it being last
        nan_warnings = [wrec for wrec in warn_rec
                        if 'set to NaN where' in str(wrec.message)]
        assert len(nan_warnings) > 0
        assert issubclass(nan_warnings[-1].category, UserWarning)
        return
1639
1640
1641
class TestApexGetMethods(object):
    """Test the Apex `get` methods."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out

    @pytest.mark.parametrize("alat, aheight",
                             [(10, 507.409702543805),
                              (60, 20313.026999999987),
                              ([10, 60],
                               [507.409702543805, 20313.026999999987]),
                              ([[10], [60]],
                               [[507.409702543805], [20313.026999999987]])])
    def test_get_apex(self, alat, aheight):
        """Test the apex height retrieval results.

        Parameters
        ----------
        alat : int, float, or array-like
            Apex latitude(s) in degrees N
        aheight : float or array-like
            Expected apex height(s) in km, with the same shape as `alat`

        """
        alt = self.apex_out.get_apex(alat)
        np.testing.assert_allclose(alt, aheight)
        return

    @pytest.mark.parametrize("glat,glon,height,test_bmag",
                             [([80], [100], [300], 5.100682377815247e-05),
                              ([80, 80], [100], [300],
                               [5.100682377815247e-05, 5.100682377815247e-05]),
                              ([[80], [80]], [100], [300],
                               [[5.100682377815247e-05],
                                [5.100682377815247e-05]]),
                              (range(50, 90, 8), range(0, 360, 80), [300] * 5,
                               np.array([4.18657154e-05, 5.11118114e-05,
                                         4.91969854e-05, 5.10519207e-05,
                                         4.90054816e-05])),
                              (90.0, 0, 1000, 3.7834718823432923e-05)])
    def test_get_babs(self, glat, glon, height, test_bmag):
        """Test the method to get the magnitude of the magnetic field.

        Parameters
        ----------
        glat : int, float, or array-like
            Latitude(s) in degrees N
        glon : int, float, or array-like
            Longitude(s) in degrees E
        height : int, float, or array-like
            Height(s) in km
        test_bmag : float or array-like
            Expected B field magnitude(s)

        """
        bmag = self.apex_out.get_babs(glat, glon, height)

        # The expected magnitudes are of order 1e-5 (presumably Tesla; TODO
        # confirm units). An absolute tolerance of 1e-5 would accept errors
        # of about 20%, so compare with a relative tolerance instead.
        np.testing.assert_allclose(bmag, test_bmag, rtol=1e-5, atol=0)
        return

    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_get_apex_with_invalid_lat(self, bad_lat):
        """Test `get_apex` raises ValueError for invalid latitudes.

        Parameters
        ----------
        bad_lat : int or float
            Bad input latitude in degrees N

        """

        with pytest.raises(ValueError) as verr:
            self.apex_out.get_apex(bad_lat)

        # Use `in` rather than `find(...) > 0`, which would wrongly reject a
        # message that starts with the expected text
        assert "must be in [-90, 90]" in str(verr.value)
        return

    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_get_babs_with_invalid_lat(self, bad_lat):
        """Test `get_babs` raises ValueError for invalid latitudes.

        Parameters
        ----------
        bad_lat : int or float
            Bad input latitude in degrees N

        """

        with pytest.raises(ValueError) as verr:
            self.apex_out.get_babs(bad_lat, 15, 100)

        # Use `in` rather than `find(...) > 0`, which would wrongly reject a
        # message that starts with the expected text
        assert "must be in [-90, 90]" in str(verr.value)
        return

    @pytest.mark.parametrize("bound_lat", [90, -90])
    def test_get_at_lat_boundary(self, bound_lat):
        """Test get methods at the latitude boundary, with allowed excess.

        Parameters
        ----------
        bound_lat : int or float
            Boundary input latitude in degrees N

        """
        # Get a latitude just beyond the limit
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = self.apex_out.get_apex(bound_lat)
        excess_out = self.apex_out.get_apex(excess_lat)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    @pytest.mark.parametrize("apex_height", [-100, 0, 300, 10000])
    def test_get_height_at_equator(self, apex_height):
        """Test that `get_height` returns apex height at equator.

        Parameters
        ----------
        apex_height : float
            Apex height

        """

        assert apex_height == self.apex_out.get_height(0.0, apex_height)
        return

    @pytest.mark.parametrize("lat, height", [
        (-90, -6371.009), (-80, -6088.438503309167), (-70, -5274.8091854339655),
        (-60, -4028.256749999999), (-50, -2499.1338178752017),
        (-40, -871.8751821247979), (-30, 657.2477500000014),
        (-20, 1903.8001854339655), (-10, 2717.4295033091657), (0, 3000.0),
        (10, 2717.4295033091657), (20, 1903.8001854339655),
        (30, 657.2477500000014), (40, -871.8751821247979),
        (50, -2499.1338178752017), (60, -4028.256749999999),
        (70, -5274.8091854339655), (80, -6088.438503309167)])
    def test_get_height_along_fieldline(self, lat, height):
        """Test that `get_height` returns expected height of field line.

        Parameters
        ----------
        lat : float
            Input latitude
        height : float
            Output field-line height for line with apex of 3000 km

        """

        fheight = self.apex_out.get_height(lat, 3000.0)
        assert abs(height - fheight) < 1.0e-7, \
            "bad height calculation: {:.7f} != {:.7f}".format(height, fheight)
        return
1798
1799
1800
class TestApexMethodExtrapolateIGRF(object):
1801
    """Test the Apex methods on a year when IGRF must be extrapolated.
1802
1803
    Notes
1804
    -----
1805
    Extrapolation should be done using a year within 5 years of the latest IGRF
1806
    model epoch.
1807
1808
    """
1809
1810
    def setup_method(self):
        """Build the extrapolation-year Apex object and shared inputs."""
        # Apex object for a year on which IGRF must be extrapolated
        self.apex_out = apexpy.Apex(date=2025, refh=300)

        # Shared scalar inputs: latitude, longitude, and altitude
        self.in_lat, self.in_lon, self.in_alt = 60, 15, 100

        # Time used by the magnetic local time conversion test
        self.in_time = dt.datetime(2024, 2, 3, 4, 5, 6)
        return
1818
1819
    def teardown_method(self):
        """Clean up after each test.

        Removes every attribute created by `setup_method`, including the
        `in_time` value used by the MLT conversion test.
        """
        del self.apex_out, self.in_lat, self.in_lon, self.in_alt, self.in_time
        return
1823
1824 View Code Duplication
    @pytest.mark.parametrize("method_name, out_comp",
                             [("geo2apex",
                               (56.25343704223633, 92.04932403564453)),
                              ("apex2geo",
                               (53.84184265136719, -66.93045806884766,
                                3.6222547805664362e-06)),
                              ("geo2qd",
                               (56.82968521118164, 92.04932403564453)),
                              ("apex2qd", (60.498401178276744, 15.0)),
                              ("qd2apex", (59.49138097045895, 15.0))])
    def test_method_scalar_input(self, method_name, out_comp):
        """Check a scalar-input conversion against stored reference values.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method under test
        out_comp : tuple of floats
            Reference output values for that method

        """
        # Look up the conversion by name and apply it to the shared inputs
        user_out = getattr(self.apex_out, method_name)(
            self.in_lat, self.in_lon, self.in_alt)

        # Compare against the stored reference values
        np.testing.assert_allclose(user_out, out_comp, rtol=1e-5, atol=1e-5)

        # Each returned component must be a zero-dimensional (scalar) value
        assert all(np.asarray(component).shape == ()
                   for component in user_out), "output is not a scalar"
        return
1857
1858
    def test_convert_to_mlt(self):
        """Check the scalar mlon-to-MLT conversion against a stored value."""
        # Reference MLT for the shared longitude and time
        expected_mlt = 23.955474853515625

        mlt_out = self.apex_out.mlon2mlt(self.in_lon, self.in_time)
        np.testing.assert_allclose(mlt_out, expected_mlt, rtol=1e-5, atol=1e-5)
        return
1870