Test Failed
Pull Request — master (#708)
by Daniil
03:33
created

scripts.dawn_runner.run_savu.process_init()   A

Complexity

Conditions 2

Size

Total Lines 10
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 10
nop 1
dl 0
loc 10
rs 9.9
c 0
b 0
f 0
1
'''
2
run_savu
3
This is a refactor of the code that used to be contained in dawn.
4
It's used to mock up a runner for individual savu plugins from a python shell.
5
It is currently very early in development and will be subject to massive refactor in the future.
6
'''
7
from savu.data.experiment_collection import Experiment
8
from savu.data.meta_data import MetaData
9
from savu.plugins.utils import get_plugin
10
import savu.plugins.loaders.utils.yaml_utils as yaml
11
import os, sys
12
import numpy as np
13
from copy import deepcopy as copy
14
import time
15
from collections import OrderedDict
16
17
18
def get_output_rank(path2plugin, inputs, params, persistence):
    '''
    Set up the plugin and return the number of core dimensions of its
    first output dataset.

    path2plugin  - is the path to the plugin file, handed to _savu_setup
    inputs       - is a dictionary of input objects, handed to _savu_setup
    params       - are the savu parameters; the "value" item of each entry
                   is used as the parameter value
    persistence  - is a dictionary of state shared between calls. Its
                   'sys_path_0_lock' entry is held while the plugin is set
                   up, and the new plugin is stored under 'plugin_object'
    '''
    sys_path_0_lock = persistence['sys_path_0_lock']
    sys_path_0_lock.acquire()
    try:
        parameters = {}
        # slight repack here
        for key in params.keys():
            val = params[key]["value"]
            if isinstance(val, str):
                val = val.replace('\n', '').strip()
            # Every parameter is kept, not only the string-valued ones,
            # so the plugin is configured the same way runSavu does it.
            parameters[key] = val
        plugin = _savu_setup(path2plugin, inputs, parameters)
        persistence['plugin_object'] = plugin
    finally:
        sys_path_0_lock.release()

    return len(plugin.get_plugin_out_datasets()[0].get_core_dimensions())
35
36
37
def runSavu(path2plugin, params, metaOnly, inputs, persistence):
    '''
    Run a savu plugin on the data held in inputs and return the result.

    path2plugin  - is the path to the user script that should be run
    params       - are the savu parameters; the "value" item of each entry
                   is used as the parameter value
    metaOnly     - a boolean for whether the data is kept in metadata or is
                   passed as data. It is forced to True during the first
                   initialisation if any output axis holds string values
    inputs       - is a dictionary of input objects; 'data' is read here
    persistence  - is a dictionary of state kept between calls. The entries
                   'sys_path_0_lock', 'sys_path_0_set', 'plugin_object' and
                   'aux' are read. 'sys_path_0_set' is written on every
                   call, and 'plugin_object', 'parameters', 'axis_labels',
                   'axis_values', 'string_key' and 'aux' are written when
                   the plugin is initialised

    Returns a deep copy of inputs. When metaOnly is False, 'data' holds the
    plugin output (plus axis entries on the initialising call). Otherwise
    'data' is the untouched input and 'auxiliary' holds the plugin output
    split over the keys of aux.

    Raises Exception if an earlier call already set sys.path[0] and
    path2plugin lives in a different directory.
    '''
    t1 = time.time()
    sys_path_0_lock = persistence['sys_path_0_lock']
    sys_path_0_set = persistence['sys_path_0_set']
    plugin_object = persistence['plugin_object']
    aux = persistence['aux']
    sys_path_0_lock.acquire()
    try:
        result = copy(inputs)

        scriptDir = os.path.dirname(path2plugin)
        sys_path_0 = sys.path[0]
        if sys_path_0_set and scriptDir != sys_path_0:
            raise Exception("runSavu attempted to change sys.path[0] in a way that "
                            "could cause a race condition. Current sys.path[0] is {!r}, "
                            "trying to set to {!r}".format(sys_path_0, scriptDir))
        sys.path[0] = scriptDir
        # Record the change in the shared state; a local flag would be
        # forgotten on return and the guard above could never trigger.
        persistence['sys_path_0_set'] = True

        if not plugin_object:
            parameters = {}
            # slight repack here
            for key in params.keys():
                val = params[key]["value"]
                if isinstance(val, str):
                    val = val.replace('\n', '').strip()
                parameters[key] = val
                print("val: {}".format(val))
            plugin_object = _savu_setup(path2plugin, inputs, parameters)
            persistence['plugin_object'] = plugin_object
            axis_labels, axis_values = process_init(plugin_object)
            chkstring = [any(isinstance(ix, str) for ix in axis_values[label])
                         for label in axis_labels]
            if any(chkstring):
                # are any axis values strings we instead make this an aux out
                metaOnly = True
                string_key = axis_labels[chkstring.index(True)]
                aux = OrderedDict.fromkeys(axis_values[string_key])
            else:
                string_key = axis_labels[0]  # will it always be the first one?

            # Keep what was worked out here so later calls, which skip this
            # branch, see the same values.
            persistence['parameters'] = parameters
            persistence['axis_labels'] = axis_labels
            persistence['axis_values'] = axis_values
            persistence['string_key'] = string_key
            persistence['aux'] = aux

            if not metaOnly:
                if len(axis_labels) == 1:
                    result['xaxis'] = axis_values[axis_labels[0]]
                    result['xaxis_title'] = axis_labels[0]
                elif len(axis_labels) == 2:
                    x = axis_labels[0]
                    y = axis_labels[1]
                    result['xaxis_title'] = x
                    result['yaxis_title'] = y
                    result['yaxis'] = axis_values[y]
                    result['xaxis'] = axis_values[x]
    finally:
        sys_path_0_lock.release()

    # we need to get round this since we are frame independent
    if plugin_object.get_max_frames() > 1:
        data = np.expand_dims(inputs['data'], 0)
    else:
        data = inputs['data']

    print("metaOnly: {}".format(metaOnly))

    if not metaOnly:
        result['data'] = plugin_object.process_frames([data])
    else:
        result['data'] = inputs['data']
        # NOTE(review): this branch passes the raw input rather than the
        # possibly expanded `data` above - confirm that is intended.
        out_array = plugin_object.process_frames([inputs['data']])
        for k, key in enumerate(aux.keys()):
            aux[key] = np.array([out_array[k]])
        result['auxiliary'] = aux
    t2 = time.time()
    print("time to runSavu = " + str((t2 - t1)))
    return result
139
140
141
def _savu_setup(path2plugin, inputs, parameters):
    '''
    Load the plugin found at path2plugin and prepare it for processing.

    path2plugin  - is the path to the plugin file; whatever extension it
                   carries is replaced by '.py' before loading
    inputs       - is a dictionary of input objects; 'dataset_name' and
                   'data' are read here
    parameters   - is the dictionary of plugin parameters. It is modified
                   in place: 'in_datasets' and 'out_datasets' are both set
                   to a one-element list holding the dataset name

    Returns the plugin after its experiment, parameters and plugin
    datasets have been set and its setup() method has been called.
    '''
    print("running _savu_setup")
    dataset_name = inputs['dataset_name']
    parameters['in_datasets'] = [dataset_name]
    parameters['out_datasets'] = [dataset_name]
    # Only the file extension is normalised. Splitting the whole path on
    # '.py' would cut it short if a directory name contained '.py'.
    plugin_file = os.path.splitext(path2plugin)[0] + '.py'
    plugin = get_plugin(plugin_file)
    plugin.exp = setup_exp_and_data(inputs, inputs['data'], plugin)
    plugin._set_parameters(parameters)
    plugin._set_plugin_datasets()
    plugin.setup()
    return plugin
151
152
153
def process_init(plugin):
    '''
    Collect the output axis information of a plugin and run its
    pre-processing steps.

    plugin - is the plugin object to initialise

    Returns a tuple (axis_labels, axis_values): the axis label keys of the
    first output dataset with 'idx' removed, and a dictionary mapping each
    of those labels to the value stored for it in that dataset's meta data.
    '''
    labels = plugin.get_out_datasets()[0].get_axis_label_keys()
    labels.remove('idx')  # get the labels

    # Must run before the meta data is read; the original author notes
    # that this call copies the metadata.
    plugin._clean_up()

    values = {
        name: plugin.get_out_datasets()[0].meta_data.get(name)
        for name in labels
    }

    plugin.base_pre_process()
    plugin.pre_process()
    return labels, values
163
164
165
def _axis_title_missing(title):
    '''Return True when an axis title is None, empty or only whitespace.'''
    return title is None or not title.strip()


def setup_exp_and_data(inputs, data, plugin):
    '''
    Create a DawnExperiment holding one input data object that describes
    the data to be processed.

    inputs  - is a dictionary of input objects. 'dataset_name', 'data',
              'xaxis_title' and 'xaxis' are read, plus 'yaxis_title' and
              'yaxis' for two-dimensional data. Missing axis titles are
              filled in: the title becomes 'x' or 'y' and the axis becomes
              an index range, so this dictionary is modified in place
    data    - is the data array; its shape, with a leading dimension of
              length 1 added, is used as the shape of the data object
    plugin  - is the plugin object, asked for its pattern name

    Returns the experiment. Axis labels and patterns are only set up for
    one- and two-dimensional inputs['data'].
    '''
    exp = DawnExperiment(get_options())
    data_obj = exp.create_data_object('in_data', inputs['dataset_name'])
    data_obj.data = None

    shape = inputs['data'].shape
    core_dims = None
    if len(shape) == 1:
        # An empty title is treated like a missing one, otherwise the axis
        # label would be just '.units' and the meta data key would be ''.
        if _axis_title_missing(inputs['xaxis_title']):
            inputs['xaxis_title'] = 'x'
            inputs['xaxis'] = np.arange(shape[0])
        data_obj.set_axis_labels('idx.units', inputs['xaxis_title'] + '.units')
        data_obj.meta_data.set('idx', np.array([1]))
        data_obj.meta_data.set(str(inputs['xaxis_title']), inputs['xaxis'])
        core_dims = (1,)
    elif len(shape) == 2:
        if _axis_title_missing(inputs['xaxis_title']):
            print("set x")
            inputs['xaxis_title'] = 'x'
            inputs['xaxis'] = np.arange(shape[0])
        if _axis_title_missing(inputs['yaxis_title']):
            print("set y")
            inputs['yaxis_title'] = 'y'
            inputs['yaxis'] = np.arange(shape[1])
        data_obj.set_axis_labels('idx.units',
                                 inputs['xaxis_title'] + '.units',
                                 inputs['yaxis_title'] + '.units')
        data_obj.meta_data.set('idx', np.array([1]))
        data_obj.meta_data.set(str(inputs['xaxis_title']), inputs['xaxis'])
        data_obj.meta_data.set(str(inputs['yaxis_title']), inputs['yaxis'])
        core_dims = (1, 2,)

    if core_dims is not None:
        # The plugin's own pattern, plus SINOGRAM and PROJECTION which are
        # good to add on too; all slice over the leading dimension.
        for pattern in (plugin.get_plugin_pattern(), 'SINOGRAM', 'PROJECTION'):
            data_obj.add_pattern(pattern, core_dims=core_dims, slice_dims=(0, ))

    data_obj.set_shape((1, ) + data.shape)  # need to add for now for slicing...
    data_obj.get_preview().set_preview([])
    return exp
202
203
class DawnExperiment(Experiment):
    '''
    Experiment variant used by the dawn runner.

    NOTE(review): Experiment.__init__ is not invoked; only the three
    attributes below are set - confirm that this is intended.
    '''
    def __init__(self, options):
        '''
        options - is the options dictionary used to build the meta data.
                  When it is None or empty, get_options() is used instead.
        '''
        # Use the options that were handed in rather than building a second
        # set, which would read the system parameter file again.
        if not options:
            options = get_options()
        self.index = {"in_data": {}, "out_data": {}, "mapping": {}}
        self.meta_data = MetaData(options)
        self.nxs_file = None
208
209
def get_options():
    '''
    Build the options dictionary used when running savu from dawn.

    Returns a new dictionary on every call. The file and path related
    entries are empty strings, and 'system_params' holds whatever
    _set_system_params() returns.
    '''
    options = {
        'dawn_runner': True,
        'transport': 'hdf5',
        'process_names': 'CPU0',
        'processes': 'CPU0',
    }
    # No files or paths are configured here.
    blank_entries = ('data_file', 'process_file', 'out_path',
                     'inter_path', 'log_path', 'run_type')
    options.update((name, '') for name in blank_entries)
    options['verbose'] = 'True'
    options['system_params'] = _set_system_params()
    return options
224
225
def _set_system_params():
    '''
    Read the system parameters file that belongs to the imported savu
    package.

    The 'system_files' directory is looked up next to the savu package
    directory. If it has more than one entry the 'dls' folder is used,
    otherwise its only entry is used.

    Returns the result of reading 'system_parameters.yml' from that folder.
    '''
    # look next to the savu package that is actually being used
    package_dir = sys.modules['savu'].__path__[0]
    install_root = os.path.dirname(package_dir)
    sys_files = os.path.join(install_root, 'system_files')

    entries = os.listdir(sys_files)
    if len(entries) > 1:
        sys_folder = 'dls'
    else:
        sys_folder = entries[0]

    sys_file = os.path.join(sys_files, sys_folder, 'system_parameters.yml')
    return yaml.read_yaml(sys_file)
235