Total Complexity | 2 |
Total Lines | 24 |
Duplicated Lines | 0 % |
Changes | 0 |
1 | # coding=utf-8 |
||
2 | |||
3 | """ |
||
4 | Tests for deepreg/model/optimizer.py |
||
5 | pytest style |
||
6 | """ |
||
7 | import tensorflow as tf |
||
8 | |||
9 | import deepreg.model.optimizer as optimizer |
||
10 | |||
11 | |||
class TestBuildOptimizer:
    """Check the optimizer type returned by ``optimizer.build_optimizer``."""

    def test_build_optimizer_adam(self):
        """An "Adam" config with a learning rate yields a Keras Adam instance."""
        built = optimizer.build_optimizer({"name": "Adam", "learning_rate": 1.0e-5})
        assert isinstance(built, tf.keras.optimizers.Adam)

    def test_build_optimizer_sgd(self):
        """A config that only names "SGD" yields a Keras SGD instance."""
        sgd_config = {"name": "SGD"}
        built = optimizer.build_optimizer(sgd_config)
        assert isinstance(built, tf.keras.optimizers.SGD)
||
24 |