| Total Complexity | 5 |
| Total Lines | 28 |
| Duplicated Lines | 0 % |
| Changes | 1 | ||
| Bugs | 0 | Features | 0 |
| 1 | #!/usr/bin/env python |
||
class Runtime(object):
    """
    Holder for deepy's run-time state.

    The only state tracked here is whether training mode is on. It is kept
    in two places: a plain Python value for cheap host-side checks, and a
    Theano shared integer that symbolic graphs can branch on through
    ``iftrain``.
    """

    def __init__(self):
        # Graph-side flag read by iftrain: 0 = not training, 1 = training.
        self._training_flag = theano.shared(0, name="training_flag")
        # Host-side mirror of the mode; lets switch_training skip no-op updates.
        self._is_training = False

    @neural_computation
    def iftrain(self, then_branch, else_branch):
        """
        Select between two symbolic values according to the training flag.

        :param then_branch: value chosen while the training flag is set.
        :param else_branch: value chosen while the training flag is cleared.
        :return: the result of ``ifelse`` conditioned on the shared flag,
            named "iftrain".
        """
        condition = self._training_flag
        return ifelse(condition, then_branch, else_branch, name="iftrain")

    def switch_training(self, flag):
        """
        Turn training mode on or off.

        Returns immediately when ``flag`` compares equal to the current mode.
        Otherwise it records ``flag`` and stores 1 (truthy flag) or 0 (falsy
        flag) in the shared training flag.

        :param flag: truthy to enable training mode, falsy to disable it.
        """
        if self._is_training == flag:
            return
        self._is_training = flag
        self._training_flag.set_value(1 if flag else 0)


# Module-level Runtime instance created at import time; presumably the shared
# runtime other deepy modules import — confirm against its importers.
runtime = Runtime()