Total Complexity | 5 |
Total Lines | 28 |
Duplicated Lines | 0 % |
Changes | 1 | ||
Bugs | 0 | Features | 0 |
1 | #!/usr/bin/env python |
||
class Runtime(object):
    """
    Manage runtime variables in deepy.

    Training mode is tracked in two places: a theano shared variable named
    "training_flag" (holding 1 or 0), which `iftrain` hands to `ifelse`, and
    the plain Python attribute `_is_training`, which `switch_training`
    compares against to skip redundant updates.
    """

    def __init__(self):
        # Training mode starts switched off in both representations.
        self._training_flag = theano.shared(0, name="training_flag")
        self._is_training = False

    @neural_computation
    def iftrain(self, then_branch, else_branch):
        """
        Choose between two branches according to the training flag.

        :param then_branch: branch passed to `ifelse` for the training case.
        :param else_branch: branch passed to `ifelse` for the other case.
        :return: the value produced by `ifelse`, named "iftrain".
        """
        selected = ifelse(self._training_flag, then_branch, else_branch, name="iftrain")
        return selected

    def switch_training(self, flag):
        """
        Switch training mode.

        Does nothing when the stored mode already equals `flag`; otherwise
        records `flag` and writes 1 (truthy) or 0 (falsy) to the shared
        training flag.

        :param flag: switch on training mode when flag is True.
        """
        if self._is_training == flag:
            return
        self._is_training = flag
        self._training_flag.set_value(1 if flag else 0)
||
37 | |||
# Module-level Runtime instance, created when this module is imported.
runtime = Runtime()