Completed
Push — master ( 4fc32f...027495 )
by Andy
01:22
created

TestRegularizedCannonModel.test_init()   A

Complexity

Conditions 1

Size

Total Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
dl 0
loc 2
rs 10
c 1
b 0
f 0
cc 1
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
"""
5
Unit tests for the Regularized Cannon model class and associated functions.
6
"""
7
8
import numpy as np
9
import unittest
10
from AnniesLasso import regularized, utils
11
12
13
class TestRegularizedCannonModel(unittest.TestCase):
    """Tests for ``regularized.RegularizedCannonModel`` built on faux data."""

    # Seed for the faux-data generator, so that every run (and therefore every
    # failure) is reproducible.
    random_seed = 42

    def setUp(self):
        """Build faux training labels, fluxes and flux uncertainties.

        Sets ``valid_training_labels`` (a 1-D record array with one float
        field per label ``A``..``E``), ``valid_fluxes`` and
        ``valid_flux_uncertainties`` (both ``(N_stars, N_pixels)`` arrays).
        """
        rng = np.random.RandomState(self.random_seed)

        # Initialise some faux data and labels.
        labels = "ABCDE"
        N_labels = len(labels)
        # Keep at least a handful of stars so that fitting a label vector
        # such as "A + B + C" is never underdetermined.
        N_stars = rng.randint(10, 500)
        N_pixels = rng.randint(1, 10000)
        shape = (N_stars, N_pixels)

        # One column of uniform values per label. np.rec.array() on a 2-D
        # float array with a multi-field dtype would *view* it as records of
        # shape (N_stars, 1), adding a spurious axis; fromarrays() yields the
        # intended 1-D (N_stars,) record array.
        label_values = rng.uniform(size=(N_stars, N_labels))
        self.valid_training_labels = np.rec.fromarrays(
            list(label_values.T),
            dtype=[(label, '<f8') for label in labels])

        self.valid_fluxes = rng.uniform(size=shape)
        self.valid_flux_uncertainties = rng.uniform(size=shape)

    def get_model(self):
        """Return a fresh model built from the faux data created in setUp."""
        return regularized.RegularizedCannonModel(
            self.valid_training_labels, self.valid_fluxes,
            self.valid_flux_uncertainties)

    def test_init(self):
        """The model can be constructed from valid inputs."""
        self.assertIsNotNone(self.get_model())

    def test_remind_myself_to_write_unit_tests_for_these_functions(self):
        """Exercise label-vector assignment, regularization validation and
        training in one pass.

        TODO: split into focused unit tests for each behaviour.
        """
        m = self.get_model()
        m.label_vector = "A + B + C"
        self.assertIsNotNone(m.label_vector)

        # Cannot train without regularization term.
        with self.assertRaises(TypeError):
            m.train()

        # Regularization must be non-negative and finite (zero is accepted
        # further below).
        for each in (-1, np.nan, +np.inf, -np.inf):
            with self.assertRaises(ValueError):
                m.regularization = each

        # Regularization must be a float or match the dispersion size. The
        # list is one element longer than the dispersion, so it can never
        # match by chance (a fixed two-element list would match whenever the
        # model happens to have exactly two pixels).
        with self.assertRaises(ValueError):
            m.regularization = [0.] * (np.size(m.dispersion) + 1)
        m.regularization = np.zeros_like(m.dispersion)
        m.train()
57