import datetime

import pytest
from sacred import Experiment
from sacred.metrics_logger import ScalarMetricLogEntry, linearize_metrics


@pytest.fixture()
def ex():
    return Experiment("Test experiment")


def test_log_scalar_metric_with_run(ex):
    START = 10
    END = 100
    STEP_SIZE = 5
    messages = {}

    @ex.main
    def main_function(_run):
        # First, make sure the queue is empty:
        assert len(ex.current_run._metrics.get_last_metrics()) == 0
        for i in range(START, END, STEP_SIZE):
            val = i * i
            _run.log_scalar("training.loss", val, i)
        # Calling get_last_metrics clears the metrics logger's internal
        # queue. If we didn't call it here, it would be called during the
        # Sacred heartbeat event after the run finishes, and the data we
        # want to test would be lost.
        messages["messages"] = ex.current_run._metrics.get_last_metrics()

    ex.run()
    assert ex.current_run is not None
    messages = messages["messages"]
    assert len(messages) == (END - START) / STEP_SIZE
    for i in range(len(messages) - 1):
        assert messages[i].step < messages[i + 1].step
        assert messages[i].step == START + i * STEP_SIZE
        assert messages[i].timestamp <= messages[i + 1].timestamp


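# Same scenario as above, but logged through the Experiment object rather
# than the run; ex.log_scalar is expected to forward to the active run's
# metrics logger, so the recorded entries should be identical.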
def test_log_scalar_metric_with_ex(ex):
    messages = {}
    START = 10
    END = 100
    STEP_SIZE = 5

    @ex.main
    def main_function(_run):
        for i in range(START, END, STEP_SIZE):
            val = i * i
            ex.log_scalar("training.loss", val, i)
        messages["messages"] = ex.current_run._metrics.get_last_metrics()

    ex.run()
    assert ex.current_run is not None
    messages = messages["messages"]
    assert len(messages) == (END - START) / STEP_SIZE
    for i in range(len(messages) - 1):
        assert messages[i].step < messages[i + 1].step
        assert messages[i].step == START + i * STEP_SIZE
        assert messages[i].timestamp <= messages[i + 1].timestamp


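# With no explicit step argument, log_scalar should auto-increment the step,
# starting from 0 (one per call), which is what the asserts below check.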
def test_log_scalar_metric_with_implicit_step(ex):
    messages = {}

    @ex.main
    def main_function(_run):
        for i in range(10):
            val = i * i
            ex.log_scalar("training.loss", val)
        messages["messages"] = ex.current_run._metrics.get_last_metrics()

    ex.run()
    assert ex.current_run is not None
    messages = messages["messages"]
    assert len(messages) == 10
    for i in range(len(messages) - 1):
        assert messages[i].step < messages[i + 1].step
        assert messages[i].step == i
        assert messages[i].timestamp <= messages[i + 1].timestamp


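# Implicit step counters are expected to be kept per metric name, so two
# interleaved metrics should each count 0..9 independently.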
def test_log_scalar_metrics_with_implicit_step(ex):
    messages = {}

    @ex.main
    def main_function(_run):
        for i in range(10):
            val = i * i
            ex.log_scalar("training.loss", val)
            ex.log_scalar("training.accuracy", val + 1)
        messages["messages"] = ex.current_run._metrics.get_last_metrics()

    ex.run()
    assert ex.current_run is not None
    messages = messages["messages"]
    tr_loss_messages = [m for m in messages if m.name == "training.loss"]
    tr_acc_messages = [m for m in messages if m.name == "training.accuracy"]

    assert len(tr_loss_messages) == 10
    # Both metrics should have 10 records.
    assert len(tr_acc_messages) == len(tr_loss_messages)
    for i in range(len(tr_loss_messages) - 1):
        assert tr_loss_messages[i].step < tr_loss_messages[i + 1].step
        assert tr_loss_messages[i].step == i
        assert tr_loss_messages[i].timestamp <= tr_loss_messages[i + 1].timestamp

        assert tr_acc_messages[i].step < tr_acc_messages[i + 1].step
        assert tr_acc_messages[i].step == i
        assert tr_acc_messages[i].timestamp <= tr_acc_messages[i + 1].timestamp


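# linearize_metrics should group a flat list of ScalarMetricLogEntry records
# into a dict keyed by metric name, holding parallel "steps", "values", and
# "timestamps" lists per metric, in logging order.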
def test_linearize_metrics():
    entries = [
        ScalarMetricLogEntry("training.loss", 10, datetime.datetime.utcnow(), 100),
        ScalarMetricLogEntry("training.accuracy", 5, datetime.datetime.utcnow(), 50),
        ScalarMetricLogEntry("training.loss", 20, datetime.datetime.utcnow(), 200),
        ScalarMetricLogEntry("training.accuracy", 10, datetime.datetime.utcnow(), 100),
        ScalarMetricLogEntry("training.accuracy", 15, datetime.datetime.utcnow(), 150),
        ScalarMetricLogEntry("training.accuracy", 30, datetime.datetime.utcnow(), 300),
    ]
    linearized = linearize_metrics(entries)
    assert type(linearized) == dict
    assert len(linearized.keys()) == 2
    assert "training.loss" in linearized
    assert "training.accuracy" in linearized
    assert len(linearized["training.loss"]["steps"]) == 2
    assert len(linearized["training.loss"]["values"]) == 2
    assert len(linearized["training.loss"]["timestamps"]) == 2
    assert len(linearized["training.accuracy"]["steps"]) == 4
    assert len(linearized["training.accuracy"]["values"]) == 4
    assert len(linearized["training.accuracy"]["timestamps"]) == 4
    assert linearized["training.accuracy"]["steps"] == [5, 10, 15, 30]
    assert linearized["training.accuracy"]["values"] == [50, 100, 150, 300]
    assert linearized["training.loss"]["steps"] == [10, 20]
    assert linearized["training.loss"]["values"] == [100, 200]
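

# A minimal usage sketch (not part of the test suite) showing the shape that
# linearize_metrics produces. It assumes only the
# ScalarMetricLogEntry(name, step, timestamp, value) signature exercised by
# the tests above; "demo.metric" and the values are made up for illustration.
if __name__ == "__main__":
    now = datetime.datetime.utcnow()
    demo_entries = [
        ScalarMetricLogEntry("demo.metric", step, now, value)
        for step, value in [(0, 1.0), (1, 0.5), (2, 0.25)]
    ]
    # Expected output shape:
    # {"demo.metric": {"steps": [0, 1, 2],
    #                  "values": [1.0, 0.5, 0.25],
    #                  "timestamps": [now, now, now]}}
    print(linearize_metrics(demo_entries))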