Completed
Push — master ( dbc38f...56accc )
by Klaus
01:34
created

FileStorageObserver   B

Complexity

Total Complexity 40

Size/Duplication

Total Lines 183
Duplicated Lines 0 %

Importance

Changes 2
Bugs 0 Features 1
Metric Value
dl 0
loc 183
rs 8.2608
c 2
b 0
f 1
wmc 40

18 Methods

Rating   Name   Duplication   Size   Complexity  
A save_sources() 0 9 2
A artifact_event() 0 4 1
A failed_event() 0 6 1
B queued_event() 0 24 4
A save_cout() 0 3 2
B started_event() 0 33 6
A resource_event() 0 4 1
A save_json() 0 3 2
A completed_event() 0 7 1
A __init__() 0 10 1
A interrupted_event() 0 5 1
A save_file() 0 3 1
B create() 0 16 5
A __eq__() 0 4 2
A find_or_save() 0 10 3
A render_template() 0 12 4
A heartbeat_event() 0 8 2
A __ne__() 0 2 1

How to fix   Complexity   

Complex Class

Complex classes like FileStorageObserver often do many different things. To break such a class down, identify a cohesive component within it. A common way to find such a component is to look for fields and methods that share the same prefix or suffix.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
import json
5
import os
6
import os.path
7
import tempfile
8
9
from datetime import datetime
10
from shutil import copyfile
11
12
from sacred.commandline_options import CommandLineOption
13
from sacred.dependencies import get_digest
14
from sacred.observers.base import RunObserver
15
from sacred.utils import FileNotFoundError  # For compatibility with py2
16
from sacred import optional as opt
17
from sacred.serializer import flatten
18
19
20
def json_serial(obj):
    """JSON serializer for objects not serializable by default json code.

    Suitable as the ``default=`` hook of ``json.dump`` / ``json.dumps``:
    ``datetime`` objects are encoded as ISO 8601 strings via ``isoformat()``.

    :param obj: the object the json encoder could not serialize itself.
    :returns: the ISO 8601 string for a ``datetime``.
    :raises TypeError: for every other type; the message names the type so
        the offending value can be located.
    """
    if isinstance(obj, datetime):
        return obj.isoformat()
    raise TypeError("Type {} not serializable".format(type(obj).__name__))
26
27
28
class FileStorageObserver(RunObserver):
    """Observer that writes each run into its own directory below ``basedir``.

    A run directory holds ``run.json`` (the run entry and its status),
    ``config.json``, ``cout.txt`` (captured output), ``info.json`` (only
    written while the info dict is non-empty), stored artifacts and, if a
    template is configured and mako is available, a rendered
    ``report<ext>``. Sources and resources are copied into ``source_dir`` /
    ``resource_dir`` under a ``<stem>_<digest><ext>`` name, so each
    (name, digest) combination is stored only once.
    """

    VERSION = 'FileStorageObserver-0.7.0'

    @classmethod
    def create(cls, basedir, resource_dir=None, source_dir=None,
               template=None):
        """Build an observer, creating ``basedir`` if necessary.

        :param basedir: directory under which all runs are stored.
        :param resource_dir: where resources are stored; defaults to
            ``<basedir>/_resources``.
        :param source_dir: where sources are stored; defaults to
            ``<basedir>/_sources``.
        :param template: optional report template file. If omitted,
            ``<basedir>/template.html`` is used when that file exists.
        :raises FileNotFoundError: if an explicit ``template`` does not exist.
        """
        cls._ensure_dir(basedir)
        resource_dir = resource_dir or os.path.join(basedir, '_resources')
        source_dir = source_dir or os.path.join(basedir, '_sources')
        if template is not None:
            if not os.path.exists(template):
                raise FileNotFoundError("Couldn't find template file '{}'"
                                        .format(template))
        else:
            template = os.path.join(basedir, 'template.html')
            if not os.path.exists(template):
                template = None
        return cls(basedir, resource_dir, source_dir, template)

    def __init__(self, basedir, resource_dir, source_dir, template):
        """Store the configured paths; per-run state is set by the events."""
        self.basedir = basedir
        self.resource_dir = resource_dir
        self.source_dir = source_dir
        self.template = template
        # Per-run state, (re)initialised by queued_event / started_event.
        self.dir = None
        self.run_entry = None
        self.config = None
        self.info = None
        self.cout = ""

    @staticmethod
    def _ensure_dir(path):
        """Create ``path`` (with parents) unless something already exists there.

        Equivalent to "makedirs if it does not exist", but without the window
        between the check and the creation in which another process may
        create the same directory (Python 2's ``os.makedirs`` has no
        ``exist_ok``).
        """
        import errno  # local import: only needed for the EEXIST check
        try:
            os.makedirs(path)
        except OSError as err:
            if err.errno != errno.EEXIST:
                raise

    def _create_numbered_run_dir(self, max_attempts=1000):
        """Create ``basedir/<n>`` for the next free run number and return n.

        n starts at one more than the largest numeric directory name in
        ``basedir`` (1 if there is none). If that directory already exists
        because a concurrently started run took the number first, the next
        number is tried, up to ``max_attempts`` times.
        """
        import errno  # local import: only needed for the EEXIST check
        existing = [int(d) for d in os.listdir(self.basedir)
                    if os.path.isdir(os.path.join(self.basedir, d)) and
                    d.isdigit()]
        run_id = max(existing + [0]) + 1
        for attempt in range(max_attempts):
            run_dir = os.path.join(self.basedir, str(run_id))
            try:
                os.mkdir(run_dir)
            except OSError as err:
                # Only a name collision is worth retrying; give up (re-raise)
                # on any other error or once the attempts are exhausted.
                if err.errno != errno.EEXIST or attempt == max_attempts - 1:
                    raise
                run_id += 1
            else:
                self.dir = run_dir
                return run_id

    def queued_event(self, ex_info, command, queue_time, config, meta_info,
                     _id):
        """Create the run directory for a queued run and record it.

        Writes ``run.json`` (status ``QUEUED``) and ``config.json`` and copies
        the experiment's source files into the run directory.

        :returns: the run directory relative to ``basedir`` when ``_id`` is
            None, otherwise ``_id``.
        """
        if _id is None:
            self.dir = tempfile.mkdtemp(prefix='run_', dir=self.basedir)
        else:
            self.dir = os.path.join(self.basedir, str(_id))
            os.mkdir(self.dir)

        self.run_entry = {
            'experiment': dict(ex_info),
            'command': command,
            'meta': meta_info,
            'status': 'QUEUED',
        }
        self.config = config
        self.info = {}

        self.save_json(self.run_entry, 'run.json')
        self.save_json(self.config, 'config.json')

        # Source names are relative to the experiment's base_dir (as in
        # save_sources), so resolve them against it rather than the CWD.
        # Absolute names are unaffected by os.path.join.
        base_dir = ex_info.get('base_dir', '')
        for source_name, _ in ex_info['sources']:
            self.save_file(os.path.join(base_dir, source_name))

        return os.path.relpath(self.dir, self.basedir) if _id is None else _id

    def save_sources(self, ex_info):
        """Store every experiment source file in ``source_dir``.

        :param ex_info: experiment info with ``base_dir`` and a ``sources``
            list of ``(name, digest)`` pairs, names relative to ``base_dir``.
        :returns: list of ``[name, stored_path]`` pairs, with ``stored_path``
            relative to ``basedir``.
        """
        base_dir = ex_info['base_dir']
        source_info = []
        # NOTE: the digest recorded in ex_info is not compared with the
        # digest of the stored file.
        for source_name, _ in ex_info['sources']:
            abspath = os.path.join(base_dir, source_name)
            store_path, _ = self.find_or_save(abspath, self.source_dir)
            source_info.append([source_name,
                                os.path.relpath(store_path, self.basedir)])
        return source_info

    def started_event(self, ex_info, command, host_info, start_time, config,
                      meta_info, _id):
        """Create the run directory for a started run and record it.

        When ``_id`` is None the run gets the next free number (see
        ``_create_numbered_run_dir``). Stores the sources, then writes
        ``run.json`` (status ``RUNNING``), ``config.json`` and an empty
        ``cout.txt``.

        :returns: the run id.
        """
        if _id is None:
            _id = self._create_numbered_run_dir()
        else:
            self.dir = os.path.join(self.basedir, str(_id))
            os.mkdir(self.dir)

        # Build the stored experiment entry on a copy: the caller's ex_info
        # is left untouched, since it may be shared (presumably with other
        # observers) that expect the original (name, digest) source pairs.
        experiment = dict(ex_info)
        experiment['sources'] = self.save_sources(ex_info)

        self.run_entry = {
            'experiment': experiment,
            'command': command,
            'host': dict(host_info),
            'start_time': start_time.isoformat(),
            'meta': meta_info,
            'status': 'RUNNING',
            'resources': [],
            'artifacts': [],
            'heartbeat': None
        }
        self.config = config
        self.info = {}
        self.cout = ""

        self.save_json(self.run_entry, 'run.json')
        self.save_json(self.config, 'config.json')
        self.save_cout()

        return _id

    def find_or_save(self, filename, store_dir):
        """Copy ``filename`` into ``store_dir`` unless that copy already exists.

        The stored name is ``<stem>_<digest><ext>`` (digest from
        ``get_digest``), so a file with the same base name and digest is only
        copied once.

        :returns: ``(store_path, digest)``.
        """
        self._ensure_dir(store_dir)
        stem, ext = os.path.splitext(os.path.basename(filename))
        md5sum = get_digest(filename)
        store_path = os.path.join(store_dir, stem + '_' + md5sum + ext)
        if not os.path.exists(store_path):
            copyfile(filename, store_path)
        return store_path, md5sum

    def save_json(self, obj, filename):
        """Write ``flatten(obj)`` as indented, key-sorted JSON into the run dir."""
        with open(os.path.join(self.dir, filename), 'w') as f:
            json.dump(flatten(obj), f, sort_keys=True, indent=2)

    def save_file(self, filename, target_name=None):
        """Copy ``filename`` into the run dir (default name: its basename)."""
        target_name = target_name or os.path.basename(filename)
        copyfile(filename, os.path.join(self.dir, target_name))

    def save_cout(self):
        """Overwrite ``cout.txt`` in the run dir with the captured output."""
        # NOTE(review): text mode with the platform default encoding;
        # non-ASCII output may fail to encode on some platforms / Python 2.
        with open(os.path.join(self.dir, 'cout.txt'), 'w') as f:
            f.write(self.cout)

    def render_template(self):
        """Render the report template into the run directory, if possible.

        Does nothing unless mako is available and a template is configured.
        The report is written as ``report<ext>``, where ``<ext>`` is the
        template's own file extension.
        """
        if opt.has_mako and self.template:
            from mako.template import Template
            template = Template(filename=self.template)
            report = template.render(run=self.run_entry,
                                     config=self.config,
                                     info=self.info,
                                     cout=self.cout,
                                     savedir=self.dir)
            _, ext = os.path.splitext(self.template)
            with open(os.path.join(self.dir, 'report' + ext), 'w') as f:
                f.write(report)

    def heartbeat_event(self, info, captured_out, beat_time):
        """Persist the latest info, captured output and heartbeat time.

        ``info.json`` is only (re)written while ``info`` is non-empty.
        """
        self.info = info
        self.run_entry['heartbeat'] = beat_time.isoformat()
        self.cout = captured_out
        self.save_cout()
        self.save_json(self.run_entry, 'run.json')
        if self.info:
            self.save_json(self.info, 'info.json')

    def completed_event(self, stop_time, result):
        """Mark the run ``COMPLETED``, store its result and render the report."""
        self.run_entry['stop_time'] = stop_time.isoformat()
        self.run_entry['result'] = result
        self.run_entry['status'] = 'COMPLETED'

        self.save_json(self.run_entry, 'run.json')
        self.render_template()

    def interrupted_event(self, interrupt_time, status):
        """Record the interruption time and the caller-supplied ``status``."""
        self.run_entry['stop_time'] = interrupt_time.isoformat()
        self.run_entry['status'] = status
        self.save_json(self.run_entry, 'run.json')
        self.render_template()

    def failed_event(self, fail_time, fail_trace):
        """Mark the run ``FAILED``, store the trace and render the report."""
        self.run_entry['stop_time'] = fail_time.isoformat()
        self.run_entry['status'] = 'FAILED'
        self.run_entry['fail_trace'] = fail_trace
        self.save_json(self.run_entry, 'run.json')
        self.render_template()

    def resource_event(self, filename):
        """Store a resource file in ``resource_dir`` and list it in run.json."""
        store_path, _ = self.find_or_save(filename, self.resource_dir)
        # NOTE(review): unlike save_sources, store_path is recorded as
        # returned by find_or_save (not made relative to basedir) - confirm.
        self.run_entry['resources'].append([filename, store_path])
        self.save_json(self.run_entry, 'run.json')

    def artifact_event(self, name, filename):
        """Copy ``filename`` into the run dir as ``name`` and list it."""
        self.save_file(filename, name)
        self.run_entry['artifacts'].append(name)
        self.save_json(self.run_entry, 'run.json')

    def __eq__(self, other):
        """Two observers are equal when they write to the same ``basedir``."""
        if isinstance(other, FileStorageObserver):
            return self.basedir == other.basedir
        return False

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        # Defining __eq__ alone makes instances unhashable on Python 3; hash
        # the same field __eq__ compares so equal observers hash equally.
        return hash(self.basedir)
211
212
213
class FileStorageOption(CommandLineOption):
    """Add a file-storage observer to the experiment."""
    # NOTE(review): the docstring above is presumably shown as this option's
    # help text by CommandLineOption - keep its wording deliberate.

    # Single-letter flag, the name of its value, and that value's help text.
    short_flag = 'F'
    arg = 'BASEDIR'
    arg_description = "Base-directory to write the runs to"

    @classmethod
    def apply(cls, args, run):
        # ``args`` is the BASEDIR value; attach one observer writing there.
        observer = FileStorageObserver.create(args)
        run.observers.append(observer)
223