Completed
Push — develop ( 7071db...61304a )
by
unknown
15s queued 13s
created

test_Apex.TestApexInit.test_igrf_fn()   A

Complexity

Conditions 1

Size

Total Lines 5
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 4
nop 1
dl 0
loc 5
rs 10
c 0
b 0
f 0
1
# -*- coding: utf-8 -*-
2
"""Test the apexpy.Apex class
3
4
Notes
5
-----
6
Whenever function outputs are tested against hard-coded numbers, the test
7
results (numbers) were obtained by running the code that is tested.  Therefore,
8
these tests below only check that nothing changes when refactoring, etc., and
9
not if the results are actually correct.
10
11
These results are expected to change when IGRF is updated.
12
13
"""
14
15
import datetime as dt
16
import numpy as np
17
import os
18
import pytest
19
import shutil
20
import warnings
21
22
import apexpy
23
24
25
class TestApexInit(object):
    """Test class for the Apex class object."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = None
        # Naive UTC 'now', equivalent to the deprecated `utcnow` (Py 3.12+)
        self.test_date = dt.datetime.now(dt.timezone.utc).replace(tzinfo=None)
        self.test_refh = 0
        self.bad_file = 'foo/path/to/datafile.blah'
        self.temp_file = None

    def teardown_method(self):
        """Clean up after each test, removing any temporary file copy."""
        temp_file = self.temp_file

        # Drop the test's reference to the Apex object before removing files
        del self.apex_out, self.test_date, self.test_refh, self.bad_file, \
            self.temp_file

        # Best-effort removal: a file that cannot be removed (e.g., still
        # locked by the OS) should not turn a passing test into an error
        if temp_file is not None and os.path.isfile(temp_file):
            try:
                os.remove(temp_file)
            except OSError as err:
                warnings.warn("unable to remove temporary file {:}: {:}".format(
                    temp_file, err))

    def eval_date(self):
        """Evaluate the times in self.test_date and self.apex_out."""
        # dt.datetime is a subclass of dt.date, so this covers both types
        if isinstance(self.test_date, dt.date):
            self.test_date = apexpy.helpers.toYearFraction(self.test_date)

        # Assert the times are the same on the order of tens of seconds.
        # Necessary to evaluate the current UTC
        np.testing.assert_almost_equal(self.test_date, self.apex_out.year, 6)
        return

    def eval_refh(self):
        """Evaluate the reference height in self.refh and self.apex_out."""
        eval_str = "".join(["expected reference height [",
                            "{:}] not equal to Apex ".format(self.test_refh),
                            "reference height ",
                            "[{:}]".format(self.apex_out.refh)])
        assert self.test_refh == self.apex_out.refh, eval_str
        return

    def test_init_defaults(self):
        """Test Apex class default initialization."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        return

    def test_init_today(self):
        """Test Apex class initialization with today's date."""
        self.apex_out = apexpy.Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("in_date",
                             [2015, 2015.5, dt.date(2015, 1, 1),
                              dt.datetime(2015, 6, 1, 18, 23, 45)])
    def test_init_date(self, in_date):
        """Test Apex class with date initialization.

        Parameters
        ----------
        in_date : int, float, dt.date, or dt.datetime
            Input date in a variety of formats

        """
        self.test_date = in_date
        self.apex_out = apexpy.Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("new_date", [2015, 2015.5])
    def test_set_epoch(self, new_date):
        """Test successful setting of Apex epoch after initialization.

        Parameters
        ----------
        new_date : int or float
            New date for the Apex class

        """
        # Evaluate the default initialization
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Update the epoch
        ref_apex = eval(repr(self.apex_out))
        self.apex_out.set_epoch(new_date)
        assert ref_apex != self.apex_out
        self.test_date = new_date
        self.eval_date()
        return

    @pytest.mark.parametrize("in_refh", [0.0, 300.0, 30000.0, -1.0])
    def test_init_refh(self, in_refh):
        """Test Apex class with reference height initialization.

        Parameters
        ----------
        in_refh : float
            Input reference height in km

        """
        self.test_refh = in_refh
        self.apex_out = apexpy.Apex(refh=self.test_refh)
        self.eval_date()
        self.eval_refh()
        return

    @pytest.mark.parametrize("new_refh", [0.0, 300.0, 30000.0, -1.0])
    def test_set_refh(self, new_refh):
        """Test the method used to set the reference height after the init.

        Parameters
        ----------
        new_refh : float
            Reference height in km

        """
        # Verify the defaults are set
        self.apex_out = apexpy.Apex(date=self.test_date)
        self.eval_date()
        self.eval_refh()

        # Update to a new reference height and test
        ref_apex = eval(repr(self.apex_out))
        self.apex_out.set_refh(new_refh)

        if self.test_refh == new_refh:
            assert ref_apex == self.apex_out
        else:
            assert ref_apex != self.apex_out
            self.test_refh = new_refh
        self.eval_refh()
        return

    def copy_file(self, original, max_attempts=100):
        """Make a temporary copy of input file.

        The copy is written to 'temp' plus the original extension in the
        current directory, and its name is stored in `self.temp_file` so
        `teardown_method` can remove it.

        Parameters
        ----------
        original : str
            Path of the file to copy
        max_attempts : int
            Number of copy attempts before giving up (default=100)

        Raises
        ------
        IOError
            If the file could not be copied within `max_attempts` tries

        """
        _, ext = os.path.splitext(original)
        self.temp_file = 'temp' + ext

        # Retry on OS errors (e.g., transient file locks); anything else is
        # a genuine bug and propagates immediately
        last_err = None
        for _ in range(max_attempts):
            try:
                shutil.copy(original, self.temp_file)
            except OSError as err:
                last_err = err
            else:
                return

        raise IOError("unable to copy {:} to {:} after {:} attempts".format(
            original, self.temp_file, max_attempts)) from last_err

    def test_default_datafile(self):
        """Test that the class initializes with the default datafile."""
        self.apex_out = apexpy.Apex()
        assert os.path.isfile(self.apex_out.datafile)
        return

    def test_custom_datafile(self):
        """Test that the class initializes with a good datafile input."""

        # Get the original datafile name
        apex_out_orig = apexpy.Apex()
        original_file = apex_out_orig.datafile
        del apex_out_orig

        # Create copy of datafile
        self.copy_file(original_file)

        self.apex_out = apexpy.Apex(datafile=self.temp_file)
        assert self.apex_out.datafile == self.temp_file

        return

    def test_init_with_bad_datafile(self):
        """Test raises IOError with non-existent datafile input."""
        with pytest.raises(IOError) as oerr:
            apexpy.Apex(datafile=self.bad_file)
        assert str(oerr.value).startswith('Data file does not exist')
        return

    def test_default_fortranlib(self):
        """Test that the class initializes with the default fortranlib."""
        self.apex_out = apexpy.Apex()
        assert os.path.isfile(self.apex_out.fortranlib)
        return

    def test_custom_fortranlib(self):
        """Test that the class initializes with a good fortranlib input."""

        # Get the original fortranlib name
        apex_out_orig = apexpy.Apex()
        original_lib = apex_out_orig.fortranlib
        del apex_out_orig

        # Create copy of fortranlib
        self.copy_file(original_lib)

        self.apex_out = apexpy.Apex(fortranlib=self.temp_file)
        assert self.apex_out.fortranlib == self.temp_file

        return

    def test_init_with_bad_fortranlib(self):
        """Test raises IOError with non-existent fortranlib input."""
        with pytest.raises(IOError) as oerr:
            apexpy.Apex(fortranlib=self.bad_file)
        assert str(oerr.value).startswith('Fortran library does not exist')
        return

    def test_igrf_fn(self):
        """Test the default igrf_fn."""
        self.apex_out = apexpy.Apex()
        assert os.path.isfile(self.apex_out.igrf_fn)
        return

    def test_repr_eval(self):
        """Test the Apex.__repr__ results."""
        # Initialize the apex object
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Get and test the repr string
        out_str = repr(self.apex_out)
        assert out_str.startswith("apexpy.Apex(")

        # Test the ability to re-create the apex object from the repr string
        new_apex = eval(out_str)
        assert new_apex == self.apex_out
        return

    def test_ne_other_class(self):
        """Test Apex class inequality to a different class."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        assert self.apex_out != self.test_date
        return

    def test_ne_missing_attr(self):
        """Test Apex class inequality when attributes are missing from one."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        ref_apex = eval(repr(self.apex_out))
        del ref_apex.RE

        assert ref_apex != self.apex_out
        assert self.apex_out != ref_apex
        return

    def test_eq_missing_attr(self):
        """Test Apex class equality when attributes are missing from both."""
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()
        ref_apex = eval(repr(self.apex_out))
        del ref_apex.RE, self.apex_out.RE

        assert ref_apex == self.apex_out
        return

    def test_str_eval(self):
        """Test the Apex.__str__ results."""
        # Initialize the apex object
        self.apex_out = apexpy.Apex()
        self.eval_date()
        self.eval_refh()

        # Get and test the printed string
        out_str = str(self.apex_out)
        assert "Decimal year" in out_str
        return
297
298
class TestApexMethod(object):
    """Test the Apex methods."""
    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.in_lat = 60
        self.in_lon = 15
        self.in_alt = 100

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.in_lat, self.in_lon, self.in_alt

    def get_input_args(self, method_name, precision=0.0):
        """Set the input arguments for the different Apex methods.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method
        precision : float
            Value for the precision (default=0.0)

        Returns
        -------
        in_args : list
            List of the appropriate input arguments

        """
        in_args = [self.in_lat, self.in_lon, self.in_alt]

        # Add precision, if needed
        if method_name in ["_qd2geo", "apxq2g", "apex2geo", "qd2geo",
                           "_apex2geo"]:
            in_args.append(precision)

        # Add a reference height, if needed
        if method_name in ["apxg2all"]:
            in_args.append(300)

        # Add a vector flag, if needed
        if method_name in ["apxg2all", "apxg2q"]:
            in_args.append(1)

        return in_args

    def test_apex_conversion_today(self):
        """Test Apex class conversion with today's date."""
        # Naive UTC 'now', equivalent to the deprecated `utcnow` (Py 3.12+)
        now = dt.datetime.now(dt.timezone.utc).replace(tzinfo=None)
        self.apex_out = apexpy.Apex(date=now, refh=300)
        assert not np.isnan(self.apex_out.geo2apex(self.in_lat, self.in_lon,
                                                   self.in_alt)).any()
        return

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                              ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                              ("_qd2geo", "apxq2g", slice(None)),
                              ("_basevec", "apxg2q", slice(2, 4, 1))])
    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon", [-179, -90, 0, 90, 180])
    def test_fortran_scalar_input(self, apex_method, fortran_method, fslice,
                                  lat, lon):
        """Tests Apex/fortran interface consistency for scalars.

        Parameters
        ----------
        apex_method : str
            Name of the Apex class method to test
        fortran_method : str
            Name of the Fortran function to test
        fslice : slice
            Slice used to select the appropriate Fortran outputs
        lat : int or float
            Latitude in degrees N
        lon : int or float
            Longitude in degrees E

        """
        # Set the input coordinates
        self.in_lat = lat
        self.in_lon = lon

        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Get the appropriate input arguments
        apex_args = self.get_input_args(apex_method)
        fortran_args = self.get_input_args(fortran_method)

        # Evaluate the equivalent function calls
        np.testing.assert_allclose(apex_func(*apex_args),
                                   fortran_func(*fortran_args)[fslice])
        return

    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                              ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                              ("_qd2geo", "apxq2g", slice(None)),
                              ("_basevec", "apxg2q", slice(2, 4, 1))])
    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon1,lon2", [(180, 180), (-180, -180),
                                           (180, -180), (-180, 180),
                                           (-345, 15), (375, 15)])
    def test_fortran_longitude_rollover(self, apex_method, fortran_method,
                                        fslice, lat, lon1, lon2):
        """Tests Apex/fortran interface consistency for longitude rollover.

        Parameters
        ----------
        apex_method : str
            Name of the Apex class method to test
        fortran_method : str
            Name of the Fortran function to test
        fslice : slice
            Slice used to select the appropriate Fortran outputs
        lat : int or float
            Latitude in degrees N
        lon1 : int or float
            Longitude in degrees E
        lon2 : int or float
            Equivalent longitude in degrees E

        """
        # Set the fixed input coordinate
        self.in_lat = lat

        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Get the appropriate input arguments
        self.in_lon = lon1
        apex_args = self.get_input_args(apex_method)

        self.in_lon = lon2
        fortran_args = self.get_input_args(fortran_method)

        # Evaluate the equivalent function calls
        np.testing.assert_allclose(apex_func(*apex_args),
                                   fortran_func(*fortran_args)[fslice])
        return

    @pytest.mark.parametrize("arr_shape", [(2, 2), (4,), (1, 4)])
    @pytest.mark.parametrize("apex_method,fortran_method,fslice",
                             [("_geo2qd", "apxg2q", slice(0, 2, 1)),
                              ("_geo2apex", "apxg2all", slice(2, 4, 1)),
                              ("_qd2geo", "apxq2g", slice(None)),
                              ("_basevec", "apxg2q", slice(2, 4, 1))])
    def test_fortran_array_input(self, arr_shape, apex_method, fortran_method,
                                 fslice):
        """Tests Apex/fortran interface consistency for array input.

        Parameters
        ----------
        arr_shape : tuple
            Expected output shape
        apex_method : str
            Name of the Apex class method to test
        fortran_method : str
            Name of the Fortran function to test
        fslice : slice
            Slice used to select the appropriate Fortran outputs

        """
        # Get the Apex class method and the fortran function call
        apex_func = getattr(self.apex_out, apex_method)
        fortran_func = getattr(apexpy.fortranapex, fortran_method)

        # Set up the input arrays
        ref_lat = np.array([0, 30, 60, 90])
        ref_alt = np.array([100, 200, 300, 400])
        self.in_lat = ref_lat.reshape(arr_shape)
        self.in_alt = ref_alt.reshape(arr_shape)
        apex_args = self.get_input_args(apex_method)

        # Get the Apex class results
        aret = apex_func(*apex_args)

        # Get the fortran function results
        flats = list()
        flons = list()

        for i, lat in enumerate(ref_lat):
            self.in_lat = lat
            self.in_alt = ref_alt[i]
            fortran_args = self.get_input_args(fortran_method)
            fret = fortran_func(*fortran_args)[fslice]
            flats.append(fret[0])
            flons.append(fret[1])

        flats = np.array(flats)
        flons = np.array(flons)

        # Evaluate results
        try:
            # This returned value is array of floats
            np.testing.assert_allclose(aret[0].astype(float),
                                       flats.reshape(arr_shape).astype(float))
            np.testing.assert_allclose(aret[1].astype(float),
                                       flons.reshape(arr_shape).astype(float))
        except ValueError:
            # This returned value is array of arrays
            alats = aret[0].reshape((4,))
            alons = aret[1].reshape((4,))
            for i, flat in enumerate(flats):
                np.testing.assert_array_almost_equal(alats[i], flat, 2)
                np.testing.assert_array_almost_equal(alons[i], flons[i], 2)

        return

    @pytest.mark.parametrize("lat", [0, 30, 60, 89])
    @pytest.mark.parametrize("lon", [-179, -90, 0, 90, 180])
    def test_geo2apexall_scalar(self, lat, lon):
        """Test Apex/fortran geo2apexall interface consistency for scalars.

        Parameters
        ----------
        lat : int or float
            Latitude in degrees N
        lon : int or float
            Longitude in degrees E

        """
        # Get the Apex and Fortran results
        aret = self.apex_out._geo2apexall(lat, lon, self.in_alt)
        fret = apexpy.fortranapex.apxg2all(lat, lon, self.in_alt, 300, 1)

        # Evaluate each element in the results
        for aval, fval in zip(aret, fret):
            np.testing.assert_allclose(aval, fval)
        return

    @pytest.mark.parametrize("arr_shape", [(2, 2), (4,), (1, 4)])
    def test_geo2apexall_array(self, arr_shape):
        """Test Apex/fortran geo2apexall interface consistency for arrays.

        Parameters
        ----------
        arr_shape : tuple
            Expected output shape

        """
        # Set the input
        self.in_lat = np.array([0, 30, 60, 90])
        self.in_alt = np.array([100, 200, 300, 400])

        # Get the Apex class results
        aret = self.apex_out._geo2apexall(self.in_lat.reshape(arr_shape),
                                          self.in_lon,
                                          self.in_alt.reshape(arr_shape))

        # For each lat/alt pair, get the Fortran results
        fret = list()
        for i, lat in enumerate(self.in_lat):
            fret.append(apexpy.fortranapex.apxg2all(lat, self.in_lon,
                                                    self.in_alt[i], 300, 1))

        # Cycle through all returned values
        for i, ret in enumerate(aret):
            try:
                # This returned value is array of floats
                fret_test = np.array([fret[0][i], fret[1][i], fret[2][i],
                                      fret[3][i]]).reshape(arr_shape)
                np.testing.assert_allclose(ret.astype(float),
                                           fret_test.astype(float))
            except ValueError:
                # This returned value is array of arrays
                ret = ret.reshape((4,))
                for j, single_fret in enumerate(fret):
                    np.testing.assert_allclose(ret[j], single_fret[i])
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_consistency(self, in_coord, out_coord):
        """Test the self-consistency of the Apex convert method.

        Parameters
        ----------
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Define the method name
        method_name = "2".join([in_coord, out_coord])

        # Get the method and method inputs
        convert_kwargs = {'height': self.in_alt, 'precision': 0.0}
        apex_args = self.get_input_args(method_name)
        apex_method = getattr(self.apex_out, method_name)

        # Define the slice needed to get equivalent output from the named method
        mslice = slice(0, -1, 1) if out_coord == "geo" else slice(None)

        # Get output using convert and named method
        convert_out = self.apex_out.convert(self.in_lat, self.in_lon, in_coord,
                                            out_coord, **convert_kwargs)
        method_out = apex_method(*apex_args)[mslice]

        # Compare both outputs, should be identical
        np.testing.assert_allclose(convert_out, method_out)
        return

    @pytest.mark.parametrize("bound_lat", [90, -90])
    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_at_lat_boundary(self, bound_lat, in_coord, out_coord):
        """Test the conversion at the latitude boundary, with allowed excess.

        Parameters
        ----------
        bound_lat : int or float
            Boundary latitude in degrees N
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system

        """
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = self.apex_out.convert(bound_lat, 0, in_coord, out_coord)
        excess_out = self.apex_out.convert(excess_lat, 0, in_coord, out_coord)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    def test_convert_qd2apex_at_equator(self):
        """Test the quasi-dipole to apex conversion at the magnetic equator."""
        eq_out = self.apex_out.convert(lat=0.0, lon=0, source='qd', dest='apex',
                                       height=320.0)
        close_out = self.apex_out.convert(lat=0.001, lon=0, source='qd',
                                          dest='apex', height=320.0)
        np.testing.assert_allclose(eq_out, close_out, atol=1e-4)
        return

    @pytest.mark.parametrize("src", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("dest", ["geo", "apex", "qd"])
    def test_convert_withnan(self, src, dest):
        """Test Apex.convert success with NaN input.

        Parameters
        ----------
        src : str
            Input coordinate system
        dest : str
            Output coordinate system

        """
        if src == dest:
            pytest.skip("Test not needed for same src and dest coordinates")

        num_nans = 5
        in_loc = np.arange(0, 10, dtype=float)
        in_loc[:num_nans] = np.nan

        out_loc = self.apex_out.convert(in_loc, in_loc, src, dest, height=320)

        for out in out_loc:
            assert np.all(np.isnan(out[:num_nans])), "NaN output expected"
            assert np.all(np.isfinite(out[num_nans:])), "Finite output expected"

        return

    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_convert_invalid_lat(self, bad_lat):
        """Test convert raises ValueError for invalid latitudes.

        Parameters
        ----------
        bad_lat : int or float
            Latitude outside the supported range in degrees N

        """

        with pytest.raises(ValueError) as verr:
            self.apex_out.convert(bad_lat, 0, 'geo', 'geo')

        assert "must be in [-90, 90]" in str(verr.value)
        return

    @pytest.mark.parametrize("coords", [("foobar", "geo"), ("geo", "foobar"),
                                        ("geo", "mlt")])
    def test_convert_invalid_transformation(self, coords):
        """Test raises ValueError for bad coordinates.

        Parameters
        ----------
        coords : tuple
            Tuple specifying the input and output coordinate systems

        """
        if "mlt" in coords:
            estr = "datetime must be given for MLT calculations"
        else:
            estr = "Unknown coordinate transformation"

        with pytest.raises(ValueError) as verr:
            self.apex_out.convert(0, 0, *coords)

        # Check the exception message itself, not the ExceptionInfo repr
        assert estr in str(verr.value)
        return

    @pytest.mark.parametrize("method_name, out_comp",
                             [("geo2apex",
                               (55.94841766357422, 94.10684204101562)),
                              ("apex2geo",
                               (51.476322174072266, -66.22817993164062,
                                5.727287771151168e-06)),
                              ("geo2qd",
                               (56.531288146972656, 94.10684204101562)),
                              ("apex2qd", (60.498401178276744, 15.0)),
                              ("qd2apex", (59.49138097045895, 15.0))])
    def test_method_scalar_input(self, method_name, out_comp):
        """Test the user method against set values with scalars.

        Parameters
        ----------
        method_name : str
            Apex class method to be tested
        out_comp : tuple of floats
            Expected output values

        """
        # Get the desired methods
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(self.in_lat, self.in_lon, self.in_alt)

        # Evaluate the user output
        np.testing.assert_allclose(user_out, out_comp, rtol=1e-5, atol=1e-5)

        for out_val in user_out:
            assert np.asarray(out_val).shape == (), "output is not a scalar"
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("method_args, out_shape",
                             [([[60, 60], 15, 100], (2,)),
                              ([60, [15, 15], 100], (2,)),
                              ([60, 15, [100, 100]], (2,)),
                              ([[50, 60], [15, 16], [100, 200]], (2,))])
    def test_method_broadcast_input(self, in_coord, out_coord, method_args,
                                    out_shape):
        """Test the user method with inputs that require some broadcasting.

        Parameters
        ----------
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system
        method_args : list
            List of input arguments
        out_shape : tuple
            Expected shape of output values

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(*method_args)

        # Evaluate the user output
        for out_val in user_out:
            assert hasattr(out_val, 'shape'), "output coordinate isn't np.array"
            assert out_val.shape == out_shape
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_method_invalid_lat(self, in_coord, out_coord, bad_lat):
        """Test user methods raise ValueError for invalid latitudes.

        Parameters
        ----------
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system
        bad_lat : int
            Latitude in degrees N that is out of bounds

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        with pytest.raises(ValueError) as verr:
            user_method(bad_lat, 15, 100)

        assert "must be in [-90, 90]" in str(verr.value)
        return

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    @pytest.mark.parametrize("bound_lat", [90, -90])
    def test_method_at_lat_boundary(self, in_coord, out_coord, bound_lat):
        """Test user methods at the latitude boundary, with allowed excess.

        Parameters
        ----------
        in_coord : str
            Input coordinate system
        out_coord : str
            Output coordinate system
        bound_lat : int
            Latitude in degrees N that is at the limits of the boundary

        """
        if in_coord == out_coord:
            pytest.skip("Test not needed for same src and dest coordinates")

        # Get the desired methods
        method_name = "2".join([in_coord, out_coord])
        user_method = getattr(self.apex_out, method_name)

        # Get a latitude just beyond the limit
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = user_method(bound_lat, 0, 100)
        excess_out = user_method(excess_lat, 0, 100)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    def test_geo2apex_undefined_warning(self):
        """Test geo2apex warning and fill values for an undefined location."""

        # Update the apex object
        self.apex_out = apexpy.Apex(date=2000, refh=10000)

        # Get the output and the warnings; force recording so the warning is
        # not hidden by filters configured outside this test
        with warnings.catch_warnings(record=True) as warn_rec:
            warnings.simplefilter("always")
            user_lat, user_lon = self.apex_out.geo2apex(0, 0, 0)

        assert np.isnan(user_lat)
        assert np.isfinite(user_lon)
        assert len(warn_rec) == 1
        assert issubclass(warn_rec[-1].category, UserWarning)
        assert 'latitude set to NaN where' in str(warn_rec[-1].message)
        return

    @pytest.mark.parametrize("method_name", ["apex2qd", "qd2apex"])
    @pytest.mark.parametrize("delta_h", [1.0e-6, -1.0e-6])
    def test_quasidipole_apexheight_close(self, method_name, delta_h):
        """Test quasi-dipole success with a height close to the reference.

        Parameters
        ----------
        method_name : str
            Apex class method name to be tested
        delta_h : float
            tolerance for height in km

        """
        qd_method = getattr(self.apex_out, method_name)
        in_args = [0, 15, self.apex_out.refh + delta_h]
        out_coords = qd_method(*in_args)

        for i, out_val in enumerate(out_coords):
            np.testing.assert_almost_equal(out_val, in_args[i], decimal=3)
        return

    @pytest.mark.parametrize("method_name, hinc, msg",
                             [("apex2qd", 1.0, "is > apex height"),
                              ("qd2apex", -1.0, "is < reference height")])
    def test_quasidipole_raises_apexheight(self, method_name, hinc, msg):
        """Quasi-dipole raises ApexHeightError when height is out of bounds.

        Parameters
        ----------
        method_name : str
            Apex class method name to be tested
        hinc : float
            Height increment in km
        msg : str
            Expected output message

        """
        qd_method = getattr(self.apex_out, method_name)

        with pytest.raises(apexpy.ApexHeightError) as aerr:
            qd_method(0, 15, self.apex_out.refh + hinc)

        # Check the exception message itself, not the ExceptionInfo repr
        assert msg in str(aerr.value)
        return
905
906
907
class TestApexMLTMethods(object):
    """Test the Apex Magnetic Local Time (MLT) methods."""
    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.in_time = dt.datetime(2000, 2, 3, 4, 5, 6)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.in_time

    @pytest.mark.parametrize("in_coord", ["geo", "apex", "qd"])
    def test_convert_to_mlt(self, in_coord):
        """Test the conversions to MLT using Apex convert.

        Parameters
        ----------
        in_coord : str
            Input coordinate system

        """
        # Get the magnetic longitude passed to `mlon2mlt`: geographic input
        # is converted to apex longitude first, while apex and QD input
        # longitudes are used as given
        if in_coord == "geo":
            mlon = self.apex_out.geo2apex(60, 15, 100)[1]
        else:
            mlon = 15

        # Get the output MLT values
        convert_mlt = self.apex_out.convert(60, 15, in_coord, 'mlt',
                                            height=100, ssheight=2e5,
                                            datetime=self.in_time)[1]
        method_mlt = self.apex_out.mlon2mlt(mlon, self.in_time, ssheight=2e5)

        # Test the outputs
        np.testing.assert_allclose(convert_mlt, method_mlt)
        return

    @pytest.mark.parametrize("out_coord", ["geo", "apex", "qd"])
    def test_convert_mlt_to_lon(self, out_coord):
        """Test the conversions from MLT using Apex convert.

        Parameters
        ----------
        out_coord : str
            Output coordinate system

        """
        # Get the output longitudes
        convert_out = self.apex_out.convert(60, 15, 'mlt', out_coord,
                                            height=100, ssheight=2e5,
                                            datetime=self.in_time,
                                            precision=1e-2)
        mlon = self.apex_out.mlt2mlon(15, self.in_time, ssheight=2e5)

        # Build the expected output from the individual conversion methods
        if out_coord == "geo":
            method_out = self.apex_out.apex2geo(60, mlon, 100,
                                                precision=1e-2)[:-1]
        elif out_coord == "qd":
            method_out = self.apex_out.apex2qd(60, mlon, 100)
        else:
            method_out = (60, mlon)

        # Evaluate the outputs
        np.testing.assert_allclose(convert_out, method_out)
        return

    def test_convert_geo2mlt_nodate(self):
        """Test convert from geo to MLT raises ValueError with no datetime."""
        with pytest.raises(ValueError):
            self.apex_out.convert(60, 15, 'geo', 'mlt')
        return

    @pytest.mark.parametrize("mlon_kwargs,test_mlt",
                             [({}, 23.019629923502603),
                              ({"ssheight": 100000}, 23.026712036132814)])
    def test_mlon2mlt_scalar_inputs(self, mlon_kwargs, test_mlt):
        """Test mlon2mlt with scalar inputs.

        Parameters
        ----------
        mlon_kwargs : dict
            Input kwargs
        test_mlt : float
            Expected output MLT in hours

        """
        mlt = self.apex_out.mlon2mlt(0, self.in_time, **mlon_kwargs)

        np.testing.assert_allclose(mlt, test_mlt)
        assert np.asarray(mlt).shape == ()
        return

    @pytest.mark.parametrize("mlt_kwargs,test_mlon",
                             [({}, 14.705535888671875),
                              ({"ssheight": 100000}, 14.599319458007812)])
    def test_mlt2mlon_scalar_inputs(self, mlt_kwargs, test_mlon):
        """Test mlt2mlon with scalar inputs.

        Parameters
        ----------
        mlt_kwargs : dict
            Input kwargs
        test_mlon : float
            Expected output longitude in degrees E

        """
        mlon = self.apex_out.mlt2mlon(0, self.in_time, **mlt_kwargs)

        np.testing.assert_allclose(mlon, test_mlon)
        assert np.asarray(mlon).shape == ()
        return

    @pytest.mark.parametrize("mlon,test_mlt",
                             [([0, 180], [23.019261, 11.019261]),
                              (np.array([0, 180]), [23.019261, 11.019261]),
                              (np.array([[0], [180]]),
                               np.array([[23.019261], [11.019261]])),
                              ([[0, 180], [0, 180]], [[23.019261, 11.019261],
                                                      [23.019261, 11.019261]]),
                              (range(0, 361, 30),
                               [23.01963, 1.01963, 3.01963, 5.01963, 7.01963,
                                9.01963, 11.01963, 13.01963, 15.01963, 17.01963,
                                19.01963, 21.01963, 23.01963])])
    def test_mlon2mlt_array(self, mlon, test_mlt):
        """Test mlon2mlt with array inputs.

        Parameters
        ----------
        mlon : array-like
            Input longitudes in degrees E
        test_mlt : array-like
            Expected output MLT in hours, with the same shape as `mlon`

        """
        mlt = self.apex_out.mlon2mlt(mlon, self.in_time)

        assert mlt.shape == np.asarray(test_mlt).shape
        np.testing.assert_allclose(mlt, test_mlt, rtol=1e-4)
        return

    @pytest.mark.parametrize("mlt,test_mlon",
                             [([0, 12], [14.705551, 194.705551]),
                              (np.array([0, 12]), [14.705551, 194.705551]),
                              (np.array([[0], [12]]),
                               np.array([[14.705551], [194.705551]])),
                              ([[0, 12], [0, 12]], [[14.705551, 194.705551],
                                                    [14.705551, 194.705551]]),
                              (range(0, 25, 2),
                               [14.705551, 44.705551, 74.705551, 104.705551,
                                134.705551, 164.705551, 194.705551, 224.705551,
                                254.705551, 284.705551, 314.705551, 344.705551,
                                14.705551])])
    def test_mlt2mlon_array(self, mlt, test_mlon):
        """Test mlt2mlon with array inputs.

        Parameters
        ----------
        mlt : array-like
            Input MLT in hours
        test_mlon : array-like
            Expected output longitude in degrees E, with the same shape as
            `mlt`

        """
        mlon = self.apex_out.mlt2mlon(mlt, self.in_time)

        assert mlon.shape == np.asarray(test_mlon).shape
        np.testing.assert_allclose(mlon, test_mlon, rtol=1e-4)
        return

    @pytest.mark.parametrize("method_name", ["mlon2mlt", "mlt2mlon"])
    def test_mlon2mlt_diffdates(self, method_name):
        """Test that MLT/magnetic longitude conversions vary with UT.

        Parameters
        ----------
        method_name : str
            Name of Apex class method to be tested

        """
        apex_method = getattr(self.apex_out, method_name)
        out_first = apex_method(0, self.in_time)
        out_later = apex_method(0, self.in_time + dt.timedelta(hours=1))

        assert out_first != out_later
        return

    @pytest.mark.parametrize("mlt_offset", [1.0, 10.0])
    def test_mlon2mlt_offset(self, mlt_offset):
        """Test the time wrapping logic for the MLT.

        Parameters
        ----------
        mlt_offset : float
            MLT offset in hours

        """
        # A longitude shift of -15 deg per hour must shift MLT by -1 hour
        mlt1 = self.apex_out.mlon2mlt(0.0, self.in_time)
        mlt2 = self.apex_out.mlon2mlt(-15.0 * mlt_offset,
                                      self.in_time) + mlt_offset

        np.testing.assert_allclose(mlt1, mlt2)
        return

    @pytest.mark.parametrize("mlon_offset", [15.0, 150.0])
    def test_mlt2mlon_offset(self, mlon_offset):
        """Test the time wrapping logic for the magnetic longitude.

        Parameters
        ----------
        mlon_offset : float
            Magnetic longitude offset in degrees

        """
        # An MLT shift of 1 hour per 15 deg must shift longitude by 15 deg
        mlon1 = self.apex_out.mlt2mlon(0, self.in_time)
        mlon2 = self.apex_out.mlt2mlon(mlon_offset / 15.0,
                                       self.in_time) - mlon_offset

        np.testing.assert_allclose(mlon1, mlon2)
        return

    @pytest.mark.parametrize("order", [["mlt", "mlon"], ["mlon", "mlt"]])
    @pytest.mark.parametrize("start_val", [0, 6, 12, 18, 22])
    def test_convert_and_return(self, order, start_val):
        """Test the conversion to magnetic longitude or MLT and back again.

        Parameters
        ----------
        order : list
            List of strings specifying the order to run functions
        start_val : int or float
            Input value

        """
        # e.g. ["mlt", "mlon"] -> mlt2mlon followed by mlon2mlt
        first_method = getattr(self.apex_out, "2".join(order))
        second_method = getattr(self.apex_out, "2".join(order[::-1]))

        middle_val = first_method(start_val, self.in_time)
        end_val = second_method(middle_val, self.in_time)

        np.testing.assert_allclose(start_val, end_val)
        return
1150
1151
1152
class TestApexMapMethods(object):
    """Test the Apex height mapping methods."""
    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 10000],
                               [31.841466903686523, 17.916635513305664,
                                1.7075473124350538e-6]),
                              ([30, 170, 100, 500, False, 1e-2],
                               [25.727270126342773, 169.60546875,
                                0.00017573432705830783]),
                              ([60, 15, 100, 10000, True],
                               [-25.424888610839844, 27.310426712036133,
                                1.2074182222931995e-6]),
                              ([30, 170, 100, 500, True, 1e-2],
                               [-13.76642894744873, 164.24259948730469,
                                0.00056820799363777041])])
    def test_map_to_height(self, in_args, test_mapped):
        """Test the map_to_height function.

        Parameters
        ----------
        in_args : list
            List of input arguments
        test_mapped : list
            List of expected outputs

        """
        mapped = self.apex_out.map_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5, atol=1e-5)
        return

    def test_map_to_height_same_height(self):
        """Test the map_to_height function when mapping to same height."""
        mapped = self.apex_out.map_to_height(60, 15, 100, 100, conjugate=False,
                                             precision=1e-10)
        np.testing.assert_allclose(mapped, (60.0, 15.000003814697266, 0.0),
                                   rtol=1e-5, atol=1e-5)
        return

    @pytest.mark.parametrize('arr_shape', [(2,), (2, 2), (1, 4)])
    @pytest.mark.parametrize('ivec', range(0, 4))
    def test_map_to_height_array_location(self, arr_shape, ivec):
        """Test map_to_height with array input.

        Parameters
        ----------
        arr_shape : tuple
            Expected array shape
        ivec : int
            Input argument index for vectorized input

        """
        # Set the base input and output values
        in_args = [60, 15, 100, 100]
        test_mapped = [60, 15.00000381, 0.0]

        # Update inputs for one vectorized value
        in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Calculate and test function
        mapped = self.apex_out.map_to_height(*in_args)
        for i, test_val in enumerate(test_mapped):
            assert mapped[i].shape == arr_shape
            np.testing.assert_allclose(mapped[i], test_val, rtol=1e-5,
                                       atol=1e-5)
        return

    @pytest.mark.parametrize("method_name,in_args",
                             [("map_to_height", [0, 15, 100, 10000]),
                              ("map_E_to_height",
                               [0, 15, 100, 10000, [1, 2, 3]]),
                              ("map_V_to_height",
                               [0, 15, 100, 10000, [1, 2, 3]])])
    def test_mapping_height_raises_ApexHeightError(self, method_name, in_args):
        """Test the height mapping methods raise ApexHeightError.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method to test
        in_args : list
            List of input arguments

        """
        apex_method = getattr(self.apex_out, method_name)

        with pytest.raises(apexpy.ApexHeightError) as aerr:
            apex_method(*in_args)

        assert aerr.match("is > apex height")
        return

    @pytest.mark.parametrize("method_name",
                             ["map_E_to_height", "map_V_to_height"])
    @pytest.mark.parametrize("ev_input", [[1, 2, 3, 4, 5],
                                          [[1, 2], [3, 4], [5, 6], [7, 8]]])
    def test_mapping_EV_bad_shape(self, method_name, ev_input):
        """Test height mapping of E/V with badly shaped input raises Error.

        Parameters
        ----------
        method_name : str
            Name of the Apex class method to test
        ev_input : list
            E/V input arguments

        """
        apex_method = getattr(self.apex_out, method_name)
        in_args = [60, 15, 100, 500, ev_input]
        with pytest.raises(ValueError) as verr:
            apex_method(*in_args)

        # Plain substring check: the message contains regex metacharacters
        assert "must be (3, N) or (3,) ndarray" in str(verr.value)
        return

    def test_mapping_EV_bad_flag(self):
        """Test _map_EV_to_height raises error for bad data type flag."""
        with pytest.raises(ValueError) as verr:
            self.apex_out._map_EV_to_height(60, 15, 100, 500, [1, 2, 3], "P")

        assert "unknown electric field/drift flag" in str(verr.value)
        return

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 500, [1, 2, 3]],
                               [0.71152183, 2.35624876, 0.57260784]),
                              ([60, 15, 100, 500, [2, 3, 4]],
                               [1.56028502, 3.43916636, 0.78235384]),
                              ([60, 15, 100, 1000, [1, 2, 3]],
                               [0.67796492, 2.08982134, 0.55860785]),
                              ([60, 15, 200, 500, [1, 2, 3]],
                               [0.72377397, 2.42737471, 0.59083726]),
                              ([60, 30, 100, 500, [1, 2, 3]],
                               [0.68626344, 2.37530133, 0.60060124]),
                              ([70, 15, 100, 500, [1, 2, 3]],
                               [0.72760378, 2.18082305, 0.29141979])])
    def test_map_E_to_height_scalar_location(self, in_args, test_mapped):
        """Test mapping of E-field to a specified height.

        Parameters
        ----------
        in_args : list
            List of input arguments
        test_mapped : list
            List of expected outputs

        """
        mapped = self.apex_out.map_E_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return

    @pytest.mark.parametrize('ev_flag, test_mapped',
                             [('E', [0.71152183, 2.35624876, 0.57260784]),
                              ('V', [0.81971957, 2.84512495, 0.69545001])])
    @pytest.mark.parametrize('arr_shape', [(2,), (5,)])
    @pytest.mark.parametrize('ivec', range(0, 5))
    def test_map_EV_to_height_array_location(self, ev_flag, test_mapped,
                                             arr_shape, ivec):
        """Test mapping of E-field/drift to a specified height with arrays.

        Parameters
        ----------
        ev_flag : str
            Character flag specifying whether to run 'E' or 'V' methods
        test_mapped : list
            List of expected outputs
        arr_shape : tuple
            Shape of the location inputs and of each output component
        ivec : int
            Index of the location input argument to vectorize; 4 (the E/V
            argument) leaves all location inputs scalar

        """
        # Build E/V input of shape (3, *arr_shape) whose three components are
        # 1, 2, and 3 at every location.  `np.prod` replaces `np.product`,
        # which was removed in NumPy 2.0.
        n_vals = int(np.prod(arr_shape))
        edata = np.array([[1, 2, 3]] * n_vals).transpose()
        in_args = [60, 15, 100, 500, edata.reshape((3,) + tuple(arr_shape))]

        # Update inputs for one vectorized value if this is a location input
        if ivec < 4:
            in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Get the mapped output
        apex_method = getattr(self.apex_out,
                              "map_{:s}_to_height".format(ev_flag))
        mapped = apex_method(*in_args)

        # Test the results
        for i, test_val in enumerate(test_mapped):
            assert mapped[i].shape == arr_shape
            np.testing.assert_allclose(mapped[i], test_val, rtol=1e-5)
        return

    @pytest.mark.parametrize("in_args,test_mapped",
                             [([60, 15, 100, 500, [1, 2, 3]],
                               [0.81971957, 2.84512495, 0.69545001]),
                              ([60, 15, 100, 500, [2, 3, 4]],
                               [1.83027746, 4.14346436, 0.94764179]),
                              ([60, 15, 100, 1000, [1, 2, 3]],
                               [0.92457698, 3.14997661, 0.85135187]),
                              ([60, 15, 200, 500, [1, 2, 3]],
                               [0.80388262, 2.79321504, 0.68285158]),
                              ([60, 30, 100, 500, [1, 2, 3]],
                               [0.76141245, 2.87884673, 0.73655941]),
                              ([70, 15, 100, 500, [1, 2, 3]],
                               [0.84681866, 2.5925821, 0.34792655])])
    def test_map_V_to_height_scalar_location(self, in_args, test_mapped):
        """Test mapping of velocity to a specified height.

        Parameters
        ----------
        in_args : list
            List of input arguments
        test_mapped : list
            List of expected outputs

        """
        mapped = self.apex_out.map_V_to_height(*in_args)
        np.testing.assert_allclose(mapped, test_mapped, rtol=1e-5)
        return
1379
1380
1381
class TestApexBasevectorMethods(object):
    """Test the Apex height base vector methods."""
    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)
        self.lat = 60
        self.lon = 15
        self.height = 100
        self.test_basevec = None

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out, self.test_basevec, self.lat, self.lon, self.height

    def get_comparison_results(self, bv_coord, coords, precision):
        """Set `self.test_basevec` to the base vectors used for comparison.

        Parameters
        ----------
        bv_coord : str
            Basevector coordinate scheme, expects one of 'apex', 'qd',
            or 'bvectors_apex'
        coords : str
            Input coordinate system, expects one of 'geo', 'apex', or 'qd'
        precision : float
            Float specifying precision

        Notes
        -----
        For 'qd' and 'apex' the values come from the hidden Apex methods; for
        'bvectors_apex' they are hard-coded and must be updated with IGRF.

        """
        # Convert the test location to geographic coordinates if needed
        if coords == "geo":
            glat = self.lat
            glon = self.lon
        else:
            apex_method = getattr(self.apex_out, "{:s}2geo".format(coords))
            glat, glon, _ = apex_method(self.lat, self.lon, self.height,
                                        precision=precision)

        if bv_coord == 'qd':
            self.test_basevec = self.apex_out._basevec(glat, glon, self.height)
        elif bv_coord == 'apex':
            (_, _, _, _, f1, f2, _, d1, d2, d3, _, e1, e2,
             e3) = self.apex_out._geo2apexall(glat, glon, self.height)
            self.test_basevec = (f1, f2, d1, d2, d3, e1, e2, e3)
        else:
            # These are set results that need to be updated with IGRF
            if coords == "geo":
                self.test_basevec = (
                    np.array([4.42368795e-05, 4.42368795e-05]),
                    np.array([[0.01047826, 0.01047826],
                              [0.33089194, 0.33089194],
                              [-1.04941, -1.04941]]),
                    np.array([5.3564698e-05, 5.3564698e-05]),
                    np.array([[0.00865356, 0.00865356],
                              [0.27327004, 0.27327004],
                              [-0.8666646, -0.8666646]]))
            elif coords == "apex":
                self.test_basevec = (
                    np.array([4.48672735e-05, 4.48672735e-05]),
                    np.array([[-0.12510721, -0.12510721],
                              [0.28945938, 0.28945938],
                              [-1.1505738, -1.1505738]]),
                    np.array([6.38577444e-05, 6.38577444e-05]),
                    np.array([[-0.08790194, -0.08790194],
                              [0.2033779, 0.2033779],
                              [-0.808408, -0.808408]]))
            else:
                self.test_basevec = (
                    np.array([4.46348578e-05, 4.46348578e-05]),
                    np.array([[-0.12642345, -0.12642345],
                              [0.29695055, 0.29695055],
                              [-1.1517885, -1.1517885]]),
                    np.array([6.38626285e-05, 6.38626285e-05]),
                    np.array([[-0.08835986, -0.08835986],
                              [0.20754464, 0.20754464],
                              [-0.8050078, -0.8050078]]))

        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    @pytest.mark.parametrize("coords,precision",
                             [("geo", 1e-10), ("apex", 1.0e-2), ("qd", 1.0e-2)])
    def test_basevectors_scalar(self, bv_coord, coords, precision):
        """Test the base vector calculations with scalars.

        Parameters
        ----------
        bv_coord : str
            Base vector scheme to test, 'qd' or 'apex'
        coords : str
            Name of the input coordinate system
        precision : float
            Level of run precision requested

        """
        # Get the base vectors
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(self.lat, self.lon, self.height, coords=coords,
                              precision=precision)
        self.get_comparison_results(bv_coord, coords, precision)
        if bv_coord == "apex":
            # Drop f3, g1, g2, and g3 (indices 2-5), which have no
            # counterpart in the comparison values
            basevec = list(basevec)
            del basevec[2:6]

        # Test the results
        for i, vec in enumerate(basevec):
            np.testing.assert_allclose(vec, self.test_basevec[i])
        return

    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    def test_basevectors_scalar_shape(self, bv_coord):
        """Test the shape of the scalar output.

        Parameters
        ----------
        bv_coord : str
            Base vector scheme to test, 'qd' or 'apex'

        """
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(self.lat, self.lon, self.height)

        # The first two vectors have 2 components, the rest have 3
        for i, vec in enumerate(basevec):
            if i < 2:
                assert vec.shape == (2,)
            else:
                assert vec.shape == (3,)
        return

    @pytest.mark.parametrize('arr_shape', [(2,), (5,)])
    @pytest.mark.parametrize("bv_coord", ["qd", "apex"])
    @pytest.mark.parametrize("ivec", range(3))
    def test_basevectors_array(self, arr_shape, bv_coord, ivec):
        """Test the output shape for array inputs.

        Parameters
        ----------
        arr_shape : tuple
            Shape of the vectorized input
        bv_coord : str
            Base vector scheme to test, 'qd' or 'apex'
        ivec : int
            Index of the location input (lat, lon, height) to vectorize

        """
        # Define the input arguments
        in_args = [self.lat, self.lon, self.height]
        in_args[ivec] = np.full(shape=arr_shape, fill_value=in_args[ivec])

        # Get the basevectors
        base_method = getattr(self.apex_out,
                              "basevectors_{:s}".format(bv_coord))
        basevec = base_method(*in_args, coords='geo', precision=1e-10)
        self.get_comparison_results(bv_coord, "geo", 1e-10)
        if bv_coord == "apex":
            # Drop f3, g1, g2, and g3 (indices 2-5), which have no
            # counterpart in the comparison values
            basevec = list(basevec)
            del basevec[2:6]

        # Evaluate the shape and the values
        for i, vec in enumerate(basevec):
            test_shape = list(arr_shape)
            test_shape.insert(0, 2 if i < 2 else 3)
            assert vec.shape == tuple(test_shape)
            assert np.all(self.test_basevec[i][0] == vec[0])
            assert np.all(self.test_basevec[i][1] == vec[1])
        return

    @pytest.mark.parametrize("coords", ["geo", "apex", "qd"])
    def test_bvectors_apex(self, coords):
        """Test the bvectors_apex method.

        Parameters
        ----------
        coords : str
            Name of the coordinate system

        """
        in_args = [[self.lat, self.lat], [self.lon, self.lon],
                   [self.height, self.height]]
        self.get_comparison_results("bvectors_apex", coords, 1e-10)

        basevec = self.apex_out.bvectors_apex(*in_args, coords=coords,
                                              precision=1e-10)
        for i, vec in enumerate(basevec):
            np.testing.assert_array_almost_equal(vec, self.test_basevec[i],
                                                 decimal=5)
        return

    def test_basevectors_apex_extra_values(self):
        """Test specific values in the apex base vector output."""
        # Set the testing arrays
        self.test_basevec = [np.array([0.092637, -0.245951, 0.938848]),
                             np.array([0.939012, 0.073416, -0.07342]),
                             np.array([0.055389, 1.004155, 0.257594]),
                             np.array([0, 0, 1.065135])]

        # Get the desired output
        basevec = self.apex_out.basevectors_apex(0, 15, 100, coords='geo')

        # Test the values not covered by `test_basevectors_scalar`
        for itest, ibase in enumerate(range(2, 6)):
            np.testing.assert_allclose(basevec[ibase],
                                       self.test_basevec[itest], rtol=1e-4)
        return

    @pytest.mark.parametrize("lat", range(0, 90, 10))
    @pytest.mark.parametrize("lon", range(0, 360, 15))
    def test_basevectors_apex_delta(self, lat, lon):
        """Test that vectors are calculated correctly.

        The f and g vectors, and the d and e vectors, must satisfy
        f_i . g_j = delta_ij and d_i . e_j = delta_ij.

        Parameters
        ----------
        lat : int or float
            Latitude in degrees N
        lon : int or float
            Longitude in degrees E

        """
        # Get the apex base vectors and sort them for easy testing
        (f1, f2, f3, g1, g2, g3, d1, d2, d3, e1, e2,
         e3) = self.apex_out.basevectors_apex(lat, lon, 500)
        # f1 and f2 only have two components; pad them to 3-D with a zero
        fvec = [np.append(f1, 0), np.append(f2, 0), f3]
        gvec = [g1, g2, g3]
        dvec = [d1, d2, d3]
        evec = [e1, e2, e3]

        for idelta in range(3):
            for jdelta in range(3):
                delta = 1 if idelta == jdelta else 0
                np.testing.assert_allclose(
                    np.sum(fvec[idelta] * gvec[jdelta]), delta, rtol=0,
                    atol=1e-5)
                np.testing.assert_allclose(
                    np.sum(dvec[idelta] * evec[jdelta]), delta, rtol=0,
                    atol=1e-5)
        return

    def test_basevectors_apex_invalid_scalar(self):
        """Test warning and fill values for base vectors with bad inputs."""
        self.apex_out = apexpy.Apex(date=2000, refh=10000)
        invalid = np.full(shape=(3,), fill_value=np.nan)

        # Get the output and the warnings
        with warnings.catch_warnings(record=True) as warn_rec:
            basevec = self.apex_out.basevectors_apex(0, 0, 0)

        for i, bvec in enumerate(basevec):
            if i < 2:
                # `equal_nan=True` is required: without it an all-NaN `bvec`
                # never compares close to `invalid`, so the assertion would
                # pass unconditionally
                assert not np.allclose(bvec, invalid[:2], equal_nan=True)
            else:
                np.testing.assert_allclose(bvec, invalid)

        assert len(warn_rec) > 0, "expected a UserWarning to be issued"
        assert issubclass(warn_rec[-1].category, UserWarning)
        assert 'set to NaN where' in str(warn_rec[-1].message)
        return
1637
1638
1639
class TestApexGetMethods(object):
    """Test the Apex `get` methods."""

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2000, refh=300)

    def teardown_method(self):
        """Clean up after each test."""
        del self.apex_out

    @pytest.mark.parametrize("alat, aheight",
                             [(10, 507.409702543805),
                              (60, 20313.026999999987),
                              ([10, 60],
                               [507.409702543805, 20313.026999999987]),
                              ([[10], [60]],
                               [[507.409702543805], [20313.026999999987]])])
    def test_get_apex(self, alat, aheight):
        """Test the apex height retrieval results.

        Parameters
        ----------
        alat : int, float, or array-like
            Apex latitude(s) in degrees N
        aheight : float or array-like
            Expected apex height(s) in km, shaped like `alat`

        """
        alt = self.apex_out.get_apex(alat)
        np.testing.assert_allclose(alt, aheight)
        return

    @pytest.mark.parametrize("glat,glon,height,test_bmag",
                             [([80], [100], [300], 5.100682377815247e-05),
                              ([80, 80], [100], [300],
                               [5.100682377815247e-05, 5.100682377815247e-05]),
                              ([[80], [80]], [100], [300],
                               [[5.100682377815247e-05],
                                [5.100682377815247e-05]]),
                              (range(50, 90, 8), range(0, 360, 80), [300] * 5,
                               np.array([4.18657154e-05, 5.11118114e-05,
                                         4.91969854e-05, 5.10519207e-05,
                                         4.90054816e-05])),
                              (90.0, 0, 1000, 3.7834718823432923e-05)])
    def test_get_babs(self, glat, glon, height, test_bmag):
        """Test the method to get the magnitude of the magnetic field.

        Parameters
        ----------
        glat : float or array-like
            Latitude(s) in degrees N
        glon : float or array-like
            Longitude(s) in degrees E
        height : float or array-like
            Height(s) in km
        test_bmag : float or array-like
            Expected B field magnitude(s), broadcastable against the output

        """
        bmag = self.apex_out.get_babs(glat, glon, height)

        # The expected magnitudes are of order 5e-5 (presumably Tesla), so an
        # absolute tolerance of similar size would accept errors of ~20%;
        # compare relative to the expected values instead.
        np.testing.assert_allclose(bmag, test_bmag, rtol=1.0e-5, atol=0)
        return

    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_get_apex_with_invalid_lat(self, bad_lat):
        """Test get methods raise ValueError for invalid latitudes.

        Parameters
        ----------
        bad_lat : int or float
            Bad input latitude in degrees N

        """

        with pytest.raises(ValueError) as verr:
            self.apex_out.get_apex(bad_lat)

        # Use a containment check: `str.find` returns 0 for a match at the
        # start of the message, which a `> 0` comparison would reject
        assert "must be in [-90, 90]" in str(verr.value)
        return

    @pytest.mark.parametrize("bad_lat", [91, -91])
    def test_get_babs_with_invalid_lat(self, bad_lat):
        """Test get methods raise ValueError for invalid latitudes.

        Parameters
        ----------
        bad_lat : int or float
            Bad input latitude in degrees N

        """

        with pytest.raises(ValueError) as verr:
            self.apex_out.get_babs(bad_lat, 15, 100)

        # Use a containment check: `str.find` returns 0 for a match at the
        # start of the message, which a `> 0` comparison would reject
        assert "must be in [-90, 90]" in str(verr.value)
        return

    @pytest.mark.parametrize("bound_lat", [90, -90])
    def test_get_at_lat_boundary(self, bound_lat):
        """Test get methods at the latitude boundary, with allowed excess.

        Parameters
        ----------
        bound_lat : int or float
            Boundary input latitude in degrees N

        """
        # Get a latitude just beyond the limit
        excess_lat = np.sign(bound_lat) * (abs(bound_lat) + 1.0e-5)

        # Get the two outputs, slight tolerance outside of boundary allowed
        bound_out = self.apex_out.get_apex(bound_lat)
        excess_out = self.apex_out.get_apex(excess_lat)

        # Test the outputs
        np.testing.assert_allclose(excess_out, bound_out, rtol=0, atol=1e-8)
        return

    @pytest.mark.parametrize("apex_height", [-100, 0, 300, 10000])
    def test_get_height_at_equator(self, apex_height):
        """Test that `get_height` returns apex height at equator.

        Parameters
        ----------
        apex_height : float
            Apex height in km

        """
        # Compare with a small tolerance rather than exact float equality, in
        # line with the field-line height test below
        fheight = self.apex_out.get_height(0.0, apex_height)
        assert abs(apex_height - fheight) < 1.0e-7, \
            "bad height calculation: {:.7f} != {:.7f}".format(apex_height,
                                                              fheight)
        return

    @pytest.mark.parametrize("lat, height", [
        (-90, -6371.009), (-80, -6088.438503309167), (-70, -5274.8091854339655),
        (-60, -4028.256749999999), (-50, -2499.1338178752017),
        (-40, -871.8751821247979), (-30, 657.2477500000014),
        (-20, 1903.8001854339655), (-10, 2717.4295033091657), (0, 3000.0),
        (10, 2717.4295033091657), (20, 1903.8001854339655),
        (30, 657.2477500000014), (40, -871.8751821247979),
        (50, -2499.1338178752017), (60, -4028.256749999999),
        (70, -5274.8091854339655), (80, -6088.438503309167)])
    def test_get_height_along_fieldline(self, lat, height):
        """Test that `get_height` returns expected height of field line.

        Parameters
        ----------
        lat : float
            Input latitude in degrees N
        height : float
            Output field-line height in km for line with apex of 3000 km

        """

        fheight = self.apex_out.get_height(lat, 3000.0)
        assert abs(height - fheight) < 1.0e-7, \
            "bad height calculation: {:.7f} != {:.7f}".format(height, fheight)
        return
1796
1797
1798
class TestApexMethodExtrapolateIGRF(object):
    """Test the Apex methods on a year when IGRF must be extrapolated.

    Notes
    -----
    Extrapolation should be done using a year within 5 years of the latest IGRF
    model epoch.

    """

    def setup_method(self):
        """Initialize all tests."""
        self.apex_out = apexpy.Apex(date=2025, refh=300)
        self.in_lat = 60
        self.in_lon = 15
        self.in_alt = 100
        self.in_time = dt.datetime(2024, 2, 3, 4, 5, 6)
        return

    def teardown_method(self):
        """Clean up after each test."""
        # Remove every attribute created in `setup_method`, including the time
        del self.apex_out, self.in_lat, self.in_lon, self.in_alt, \
            self.in_time
        return

    @pytest.mark.parametrize("method_name, out_comp",
                             [("geo2apex",
                               (56.25343704223633, 92.04932403564453)),
                              ("apex2geo",
                               (53.84184265136719, -66.93045806884766,
                                3.6222547805664362e-06)),
                              ("geo2qd",
                               (56.82968521118164, 92.04932403564453)),
                              ("apex2qd", (60.498401178276744, 15.0)),
                              ("qd2apex", (59.49138097045895, 15.0))])
    def test_method_scalar_input(self, method_name, out_comp):
        """Test the user method against set values with scalars.

        Parameters
        ----------
        method_name : str
            Apex class method to be tested
        out_comp : tuple of floats
            Expected output values

        """
        # Get the desired method
        user_method = getattr(self.apex_out, method_name)

        # Get the user output
        user_out = user_method(self.in_lat, self.in_lon, self.in_alt)

        # Evaluate the user output
        np.testing.assert_allclose(user_out, out_comp, rtol=1e-5, atol=1e-5)

        # Scalar inputs must produce scalar outputs
        for out_val in user_out:
            assert np.asarray(out_val).shape == (), "output is not a scalar"
        return

    def test_convert_to_mlt(self):
        """Test conversion from mlon to mlt with scalars."""

        # Get user output, treating the input longitude as a magnetic longitude
        user_out = self.apex_out.mlon2mlt(self.in_lon, self.in_time)

        # Set comparison values
        out_comp = 23.955474853515625

        # Evaluate user output
        np.testing.assert_allclose(user_out, out_comp, rtol=1e-5, atol=1e-5)
        return
1868