Completed
Push — master ( aa85c8...9f4bcb )
by Klaus
01:00
created

FileStorageObserver.log_metrics()   B

Complexity

Conditions 5

Size

Total Lines 27

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 5
c 1
b 0
f 0
dl 0
loc 27
rs 8.0894
1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
import json
5
import os
6
import os.path
7
import tempfile
8
9
from shutil import copyfile
10
11
from sacred.commandline_options import CommandLineOption
12
from sacred.dependencies import get_digest
13
from sacred.observers.base import RunObserver
14
from sacred.utils import FileNotFoundError  # For compatibility with py2
15
from sacred import optional as opt
16
from sacred.serializer import flatten
17
18
19
# Priority given to a FileStorageObserver when the caller does not pass one.
DEFAULT_FILE_STORAGE_PRIORITY = 20
20
21
22
class FileStorageObserver(RunObserver):
    """Observer that records a run as plain files below a base directory.

    Every run gets its own sub-directory of ``basedir``. The observer writes
    ``run.json``, ``config.json`` and ``cout.txt`` there and, as the
    corresponding events arrive, ``info.json``, ``metrics.json``, artifacts
    and an optional rendered report.
    """

    VERSION = 'FileStorageObserver-0.7.0'

    @classmethod
    def create(cls, basedir, resource_dir=None, source_dir=None,
               template=None, priority=DEFAULT_FILE_STORAGE_PRIORITY):
        """Build an observer, creating ``basedir`` and filling in defaults.

        ``resource_dir`` and ``source_dir`` default to ``_resources`` and
        ``_sources`` inside ``basedir``. If ``template`` is omitted,
        ``basedir/template.html`` is used when that file exists; otherwise
        no template is used.

        Raises:
            FileNotFoundError: if an explicitly given ``template`` is missing.
        """
        if not os.path.exists(basedir):
            os.makedirs(basedir)
        resource_dir = resource_dir or os.path.join(basedir, '_resources')
        source_dir = source_dir or os.path.join(basedir, '_sources')
        if template is not None:
            if not os.path.exists(template):
                raise FileNotFoundError("Couldn't find template file '{}'"
                                        .format(template))
        else:
            # Fall back to an optional template stored next to the runs.
            template = os.path.join(basedir, 'template.html')
            if not os.path.exists(template):
                template = None
        return cls(basedir, resource_dir, source_dir, template, priority)

    def __init__(self, basedir, resource_dir, source_dir, template,
                 priority=DEFAULT_FILE_STORAGE_PRIORITY):
        """Store the configured paths; per-run state starts out empty."""
        self.basedir = basedir
        self.resource_dir = resource_dir
        self.source_dir = source_dir
        self.template = template
        self.priority = priority
        # Per-run state, populated by queued_event / started_event.
        self.dir = None
        self.run_entry = None
        self.config = None
        self.info = None
        self.cout = ""

    def queued_event(self, ex_info, command, host_info, queue_time, config,
                     meta_info, _id):
        """Create the run directory for a queued run and write its metadata.

        Without an ``_id`` a fresh ``run_*`` directory is created via
        ``tempfile.mkdtemp``; otherwise ``basedir/<_id>`` is created.

        Returns:
            The run directory relative to ``basedir`` when ``_id`` is None,
            else ``_id`` unchanged.
        """
        if _id is None:
            self.dir = tempfile.mkdtemp(prefix='run_', dir=self.basedir)
        else:
            self.dir = os.path.join(self.basedir, str(_id))
            os.mkdir(self.dir)

        self.run_entry = {
            'experiment': dict(ex_info),
            'command': command,
            'host': dict(host_info),
            'meta': meta_info,
            'status': 'QUEUED',
        }
        self.config = config
        self.info = {}

        self.save_json(self.run_entry, 'run.json')
        self.save_json(self.config, 'config.json')

        # Copy every source file into the run directory itself.
        for s, _ in ex_info['sources']:
            self.save_file(s)

        return os.path.relpath(self.dir, self.basedir) if _id is None else _id

    def save_sources(self, ex_info):
        """Store the experiment's sources in ``source_dir``.

        Returns:
            A list of ``[source, stored_path]`` pairs, where ``stored_path``
            is relative to ``basedir``.
        """
        base_dir = ex_info['base_dir']
        source_info = []
        for s, _ in ex_info['sources']:
            abspath = os.path.join(base_dir, s)
            store_path, md5sum = self.find_or_save(abspath, self.source_dir)
            # assert m == md5sum
            source_info.append([s, os.path.relpath(store_path, self.basedir)])
        return source_info

    def started_event(self, ex_info, command, host_info, start_time, config,
                      meta_info, _id):
        """Create the run directory for a started run and write its metadata.

        Without an ``_id`` the next free numeric directory name below
        ``basedir`` is claimed; creation is retried if another process takes
        the same number first. ``ex_info['sources']`` is replaced by the
        result of :meth:`save_sources`.

        Returns:
            The id of the run (the chosen number when none was given).
        """
        if _id is None:
            # Local import keeps this block self-contained.
            import errno
            for i in range(200):
                dir_nrs = [int(d) for d in os.listdir(self.basedir)
                           if os.path.isdir(os.path.join(self.basedir, d)) and
                           d.isdigit()]
                _id = max(dir_nrs + [0]) + 1
                self.dir = os.path.join(self.basedir, str(_id))
                try:
                    os.mkdir(self.dir)
                    break
                except OSError as err:
                    # FileExistsError does not exist on Python 2; checking
                    # errno catches the same race condition on both versions.
                    # Any other OS error is re-raised immediately, and after
                    # some tries we assume something else went wrong.
                    if err.errno != errno.EEXIST or i > 100:
                        raise
        else:
            self.dir = os.path.join(self.basedir, str(_id))
            os.mkdir(self.dir)

        ex_info['sources'] = self.save_sources(ex_info)

        self.run_entry = {
            'experiment': dict(ex_info),
            'command': command,
            'host': dict(host_info),
            'start_time': start_time.isoformat(),
            'meta': meta_info,
            'status': 'RUNNING',
            'resources': [],
            'artifacts': [],
            'heartbeat': None
        }
        self.config = config
        self.info = {}
        self.cout = ""

        self.save_json(self.run_entry, 'run.json')
        self.save_json(self.config, 'config.json')
        self.save_cout()

        return os.path.relpath(self.dir, self.basedir) if _id is None else _id

    def find_or_save(self, filename, store_dir):
        """Copy ``filename`` into ``store_dir`` unless already stored there.

        The stored name is ``<name>_<digest><ext>``, with the digest obtained
        from ``get_digest``; an existing file of that name is not overwritten.

        Returns:
            A ``(store_path, md5sum)`` tuple.
        """
        if not os.path.exists(store_dir):
            os.makedirs(store_dir)
        source_name, ext = os.path.splitext(os.path.basename(filename))
        md5sum = get_digest(filename)
        store_name = source_name + '_' + md5sum + ext
        store_path = os.path.join(store_dir, store_name)
        if not os.path.exists(store_path):
            copyfile(filename, store_path)
        return store_path, md5sum

    def save_json(self, obj, filename):
        """Write ``flatten(obj)`` as indented, key-sorted JSON into the run dir."""
        with open(os.path.join(self.dir, filename), 'w') as f:
            json.dump(flatten(obj), f, sort_keys=True, indent=2)

    def save_file(self, filename, target_name=None):
        """Copy ``filename`` into the run dir, optionally under a new name."""
        target_name = target_name or os.path.basename(filename)
        copyfile(filename, os.path.join(self.dir, target_name))

    def save_cout(self):
        """Write the captured output to ``cout.txt`` as UTF-8 bytes."""
        with open(os.path.join(self.dir, 'cout.txt'), 'wb') as f:
            f.write(self.cout.encode('utf-8'))

    def render_template(self):
        """Render the report template into the run dir.

        Does nothing unless mako is available (``opt.has_mako``) and a
        template is configured. The report keeps the template's extension.
        """
        if opt.has_mako and self.template:
            from mako.template import Template
            template = Template(filename=self.template)
            report = template.render(run=self.run_entry,
                                     config=self.config,
                                     info=self.info,
                                     cout=self.cout,
                                     savedir=self.dir)
            _, ext = os.path.splitext(self.template)
            with open(os.path.join(self.dir, 'report' + ext), 'w') as f:
                f.write(report)

    def heartbeat_event(self, info, captured_out, beat_time, result):
        """Persist the latest info, captured output, result and beat time."""
        self.info = info
        self.run_entry['heartbeat'] = beat_time.isoformat()
        self.run_entry['result'] = result
        self.cout = captured_out
        self.save_cout()
        self.save_json(self.run_entry, 'run.json')
        # info.json is only written once there is something to store.
        if self.info:
            self.save_json(self.info, 'info.json')

    def completed_event(self, stop_time, result):
        """Mark the run as COMPLETED, store the result and render the report."""
        self.run_entry['stop_time'] = stop_time.isoformat()
        self.run_entry['result'] = result
        self.run_entry['status'] = 'COMPLETED'

        self.save_json(self.run_entry, 'run.json')
        self.render_template()

    def interrupted_event(self, interrupt_time, status):
        """Record the given interruption ``status`` and render the report."""
        self.run_entry['stop_time'] = interrupt_time.isoformat()
        self.run_entry['status'] = status
        self.save_json(self.run_entry, 'run.json')
        self.render_template()

    def failed_event(self, fail_time, fail_trace):
        """Mark the run as FAILED, store the trace and render the report."""
        self.run_entry['stop_time'] = fail_time.isoformat()
        self.run_entry['status'] = 'FAILED'
        self.run_entry['fail_trace'] = fail_trace
        self.save_json(self.run_entry, 'run.json')
        self.render_template()

    def resource_event(self, filename):
        """Store a resource in ``resource_dir`` and list it in ``run.json``."""
        store_path, md5sum = self.find_or_save(filename, self.resource_dir)
        self.run_entry['resources'].append([filename, store_path])
        self.save_json(self.run_entry, 'run.json')

    def artifact_event(self, name, filename):
        """Copy an artifact into the run dir as ``name`` and list it."""
        self.save_file(filename, name)
        self.run_entry['artifacts'].append(name)
        self.save_json(self.run_entry, 'run.json')

    def log_metrics(self, metrics_by_name, info):
        """Store new measurements into metrics.json.

        Existing content of ``metrics.json`` is loaded and the new values,
        steps and timestamps of every metric are appended to it. ``info`` is
        not used by this observer.
        """
        metrics_path = os.path.join(self.dir, "metrics.json")
        try:
            # Use a context manager so the file handle is always closed.
            with open(metrics_path, 'r') as f:
                saved_metrics = json.load(f)
        except IOError:
            # We haven't recorded anything yet. Start collecting.
            saved_metrics = {}

        for metric_name, metric_ptr in metrics_by_name.items():

            if metric_name not in saved_metrics:
                saved_metrics[metric_name] = {"values": [],
                                              "steps": [],
                                              "timestamps": []}

            saved_metrics[metric_name]["values"] += metric_ptr["values"]
            saved_metrics[metric_name]["steps"] += metric_ptr["steps"]

            # Manually convert them to avoid passing a datetime dtype handler
            # when we're trying to convert into json.
            timestamps_norm = [ts.isoformat()
                               for ts in metric_ptr["timestamps"]]
            saved_metrics[metric_name]["timestamps"] += timestamps_norm

        self.save_json(saved_metrics, 'metrics.json')

    def __eq__(self, other):
        """Two file-storage observers are equal when they share a basedir."""
        if isinstance(other, FileStorageObserver):
            return self.basedir == other.basedir
        return False

    def __ne__(self, other):
        """Inverse of :meth:`__eq__` (needed explicitly on Python 2)."""
        return not self.__eq__(other)
247
248
249
class FileStorageOption(CommandLineOption):
    """Add a file-storage observer to the experiment."""
    # NOTE(review): the docstring above may be used as help text by
    # CommandLineOption, so it is kept verbatim; confirm before editing.

    short_flag = 'F'
    arg = 'BASEDIR'
    arg_description = "Base-directory to write the runs to"

    @classmethod
    def apply(cls, args, run):
        # ``args`` is passed to FileStorageObserver.create as the basedir.
        observer = FileStorageObserver.create(args)
        run.observers.append(observer)
259