| Total Complexity | 2 |
| Total Lines | 24 |
| Duplicated Lines | 0 % |
| Changes | 0 | ||
| 1 | # coding=utf-8 |
||
| 2 | |||
| 3 | """ |
||
| 4 | Tests for deepreg/model/optimizer.py |
||
| 5 | pytest style |
||
| 6 | """ |
||
| 7 | import tensorflow as tf |
||
| 8 | |||
| 9 | import deepreg.model.optimizer as optimizer |
||
| 10 | |||
| 11 | |||
class TestBuildOptimizer:
    """Checks that build_optimizer returns the Keras optimizer named in its config."""

    def test_build_optimizer_adam(self):
        """Build an Adam optimizer from a config that also sets a learning rate."""
        built = optimizer.build_optimizer({"name": "Adam", "learning_rate": 1.0e-5})
        assert isinstance(built, tf.keras.optimizers.Adam)

    def test_build_optimizer_sgd(self):
        """Build an SGD optimizer from a config that gives only the name."""
        built = optimizer.build_optimizer({"name": "SGD"})
        assert isinstance(built, tf.keras.optimizers.SGD)
||
| 24 |