Completed
Pull Request — master (#384)
by
unknown
01:25
created

PulseBlock.__setitem__()   F

Complexity

Conditions 19

Size

Total Lines 50

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 19
dl 0
loc 50
rs 0.5999
c 1
b 0
f 0

How to fix   Complexity   

Complexity

Complex methods like PulseBlock.__setitem__() often do many different things. To break such a method down, we need to identify cohesive pieces of logic within it. A common approach to find such a piece is to look for repeated groups of statements, or for fields/methods that share the same prefixes or suffixes.

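Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster. For a single complex method such as this one, the Extract Method refactoring (moving each cohesive group of statements into a private helper method) is usually the most direct option.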

1
# -*- coding: utf-8 -*-
2
3
"""
4
This file contains the Qudi data object classes needed for pulse sequence generation.
5
6
Qudi is free software: you can redistribute it and/or modify
7
it under the terms of the GNU General Public License as published by
8
the Free Software Foundation, either version 3 of the License, or
9
(at your option) any later version.
10
11
Qudi is distributed in the hope that it will be useful,
12
but WITHOUT ANY WARRANTY; without even the implied warranty of
13
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14
GNU General Public License for more details.
15
16
You should have received a copy of the GNU General Public License
17
along with Qudi. If not, see <http://www.gnu.org/licenses/>.
18
19
Copyright (c) the Qudi Developers. See the COPYRIGHT.txt file at the
20
top-level directory of this distribution and at <https://github.com/Ulm-IQO/qudi/>
21
"""
22
23
import copy
24
import os
25
import sys
26
import inspect
27
import importlib
28
from collections import OrderedDict
29
30
from logic.pulsed.sampling_functions import SamplingFunctions
31
from core.util.modules import get_main_dir
32
33
34
class PulseBlockElement(object):
    """
    Object representing a single atomic element in a pulse block.

    This class can build waiting times, sine waves, etc. The pulse block may
    contain many Pulse_Block_Element Objects. These objects can be displayed in
    a GUI as single rows of a Pulse_Block.
    """
    def __init__(self, init_length_s=10e-9, increment_s=0, pulse_function=None, digital_high=None):
        """
        The constructor for a Pulse_Block_Element needs to have:

        @param float init_length_s: an initial length of the element, this parameters should not be
                                    zero but must have a finite value.
        @param float increment_s: the number which will be incremented during each repetition of
                                  this element.
        @param dict pulse_function: dictionary with keys being the qudi analog channel string
                                    descriptors ('a_ch1', 'a_ch2' etc.) and the corresponding
                                    objects being instances of the mathematical function objects
                                    provided by SamplingFunctions class.
        @param dict digital_high: dictionary with keys being the qudi digital channel string
                                  descriptors ('d_ch1', 'd_ch2' etc.) and the corresponding objects
                                  being boolean values describing if the channel should be logical
                                  low (False) or high (True).
                                  For 3 digital channel it may look like:
                                  {'d_ch1': True, 'd_ch2': False, 'd_ch5': False}
        """
        # FIXME: Sanity checks need to be implemented here
        self.init_length_s = init_length_s
        self.increment_s = increment_s
        # The passed dicts are stored by reference (not copied).
        self.pulse_function = OrderedDict() if pulse_function is None else pulse_function
        self.digital_high = OrderedDict() if digital_high is None else digital_high

        # Channel sets are derived once from the dict keys at construction time. They are NOT
        # refreshed if pulse_function / digital_high are modified afterwards.
        self.analog_channels = set(self.pulse_function)
        self.digital_channels = set(self.digital_high)
        self.channel_set = self.analog_channels.union(self.digital_channels)

    def __repr__(self):
        # Each sampling function repr is prefixed with "SamplingFunctions." so that the analog
        # part reads like the corresponding constructor calls.
        func_str = ', '.join(
            '\'{0}\': {1}'.format(channel, 'SamplingFunctions.' + repr(sampling_func))
            for channel, sampling_func in self.pulse_function.items())
        repr_str = 'PulseBlockElement(init_length_s={0}, increment_s={1}, pulse_function='.format(
            self.init_length_s, self.increment_s)
        repr_str += '{' + func_str + '}, '
        repr_str += 'digital_high={0})'.format(repr(dict(self.digital_high)))
        return repr_str

    def __str__(self):
        pulse_func_dict = {chnl: type(func).__name__ for chnl, func in self.pulse_function.items()}
        return_str = 'PulseBlockElement\n\tinitial length: {0}s\n\tlength increment: {1}s\n\t' \
                     'analog channels: {2}\n\tdigital channels: {3}'.format(self.init_length_s,
                                                                            self.increment_s,
                                                                            pulse_func_dict,
                                                                            dict(self.digital_high))
        return return_str

    def __eq__(self, other):
        """ Compare two elements by channel usage, lengths, digital states and analog functions.

        Note: defining __eq__ without __hash__ makes instances unhashable in Python 3.

        @param object other: object to compare with
        @return bool: True if other is an equivalent PulseBlockElement, False otherwise
        """
        if not isinstance(other, PulseBlockElement):
            return False
        if self is other:
            return True
        if self.channel_set != other.channel_set:
            return False
        if (self.init_length_s, self.increment_s) != (other.init_length_s, other.increment_s):
            return False
        if set(self.digital_high.items()) != set(other.digital_high.items()):
            return False
        # Both elements must drive exactly the same analog channels before comparing functions,
        # so the per-channel lookup below cannot raise KeyError.
        if set(self.pulse_function) != set(other.pulse_function):
            return False
        # items() is required: iterating the dict itself yields only channel name strings.
        for chnl, func in self.pulse_function.items():
            if func != other.pulse_function[chnl]:
                return False
        return True

    def get_dict_representation(self):
        """ Return a plain dict describing this element (see element_from_dict for the inverse).

        Sampling functions are converted via their own get_dict_representation(); digital_high is
        included by reference.

        @return dict: keys 'init_length_s', 'increment_s', 'digital_high', 'pulse_function'
        """
        dict_repr = dict()
        dict_repr['init_length_s'] = self.init_length_s
        dict_repr['increment_s'] = self.increment_s
        dict_repr['digital_high'] = self.digital_high
        dict_repr['pulse_function'] = {chnl: func.get_dict_representation()
                                       for chnl, func in self.pulse_function.items()}
        return dict_repr

    @staticmethod
    def element_from_dict(element_dict):
        """ Create a PulseBlockElement from a dict as produced by get_dict_representation.

        The passed dict is not modified.

        @param dict element_dict: element description; each 'pulse_function' entry must be a dict
                                  with 'name' (SamplingFunctions attribute name) and 'params'
                                  (keyword arguments for that class)
        @return PulseBlockElement: the reconstructed element
        """
        kwargs = dict(element_dict)
        kwargs['pulse_function'] = {
            chnl: getattr(SamplingFunctions, sample_dict['name'])(**sample_dict['params'])
            for chnl, sample_dict in element_dict['pulse_function'].items()}
        return PulseBlockElement(**kwargs)
131
132
133
class PulseBlock(object):
    """
    Collection of Pulse_Block_Elements which is called a Pulse_Block.

    The block keeps running totals of the initial length and the length increment of all its
    elements. It also enforces that every element uses the same set of analog and digital channels.
    """
    def __init__(self, name, element_list=None):
        """
        The constructor for a Pulse_Block needs to have:

        @param str name: chosen name for the Pulse_Block
        @param list element_list: which contains the Pulse_Block_Element Objects forming a
                                  Pulse_Block, e.g. [Pulse_Block_Element, Pulse_Block_Element, ...]
        """
        self.name = name
        # NOTE: element_list is stored by reference here, while insert/__setitem__ store deep
        # copies of the passed elements.
        self.element_list = list() if element_list is None else element_list
        self.init_length_s = 0.0
        self.increment_s = 0.0
        self.analog_channels = set()
        self.digital_channels = set()
        self.channel_set = set()
        self.refresh_parameters()
        return

    def __repr__(self):
        repr_str = 'PulseBlock(name=\'{0}\', element_list=['.format(self.name)
        repr_str += ', '.join((repr(elem) for elem in self.element_list)) + '])'
        return repr_str

    def __str__(self):
        return_str = 'PulseBlock "{0}"\n\tnumber of elements: {1}\n\t'.format(
            self.name, len(self.element_list))
        return_str += 'initial length: {0}s\n\tlength increment: {1}s\n\t'.format(
            self.init_length_s, self.increment_s)
        return_str += 'active analog channels: {0}\n\tactive digital channels: {1}'.format(
            sorted(self.analog_channels), sorted(self.digital_channels))
        return return_str

    def __len__(self):
        return len(self.element_list)

    def __getitem__(self, key):
        if not isinstance(key, (slice, int)):
            raise TypeError('PulseBlock indices must be int or slice, not {0}'.format(type(key)))
        return self.element_list[key]

    def __setitem__(self, key, value):
        """ Replace element(s) of the block by deep copies of value.

        All checks (key type, value types, channel sets, index range, extended-slice size) are
        performed before any block attribute is modified. A failed assignment therefore leaves
        the block unchanged.

        @param int|slice key: index or slice into the element list
        @param PulseBlockElement|iterable value: a single element for int keys, an iterable of
                                                 elements for slice keys

        @raise TypeError: invalid key type or non-PulseBlockElement value
        @raise ValueError: channel set mismatch, or size mismatch for extended slices
        @raise IndexError: int key out of range
        """
        if isinstance(key, int):
            new_elements = [value]
        elif isinstance(key, slice):
            # Materialize once; value may be a one-shot iterator such as a generator.
            new_elements = list(value)
        else:
            raise TypeError('PulseBlock indices must be int or slice, not {0}'.format(type(key)))

        channel_set = self._check_elements(new_elements)
        new_copies = copy.deepcopy(new_elements)

        # The list operations below raise (IndexError / ValueError) before any other state has
        # been touched.
        if isinstance(key, int):
            old_elements = [self.element_list[key]]
            self.element_list[key] = new_copies[0]
        else:
            old_elements = self.element_list[key]
            self.element_list[key] = new_copies

        if not self.channel_set:
            self._set_channels(channel_set)
        self._update_lengths(removed=old_elements, added=new_copies)
        return

    def __delitem__(self, key):
        if not isinstance(key, (slice, int)):
            raise TypeError('PulseBlock indices must be int or slice, not {0}'.format(type(key)))

        removed = [self.element_list[key]] if isinstance(key, int) else self.element_list[key]
        del self.element_list[key]
        self._update_lengths(removed=removed)
        return

    def __eq__(self, other):
        if not isinstance(other, PulseBlock):
            return False
        if self is other:
            return True
        if self.channel_set != other.channel_set:
            return False
        if (self.init_length_s, self.increment_s) != (other.init_length_s, other.increment_s):
            return False
        if len(self) != len(other):
            return False
        for i, element in enumerate(self.element_list):
            if element != other[i]:
                return False
        return True

    def _check_elements(self, elements):
        """ Validate elements that are about to be added, without modifying the block.

        @param iterable elements: candidate PulseBlockElement instances

        @return set: the channel set of the block after adding the elements, i.e. the current
                     channel set or (if the block has none yet) the first element's one

        @raise TypeError: if an entry is not a PulseBlockElement
        @raise ValueError: if the channel sets do not match
        """
        channel_set = self.channel_set
        for element in elements:
            if not isinstance(element, PulseBlockElement):
                raise TypeError('PulseBlock element list entries must be of type '
                                'PulseBlockElement, not {0}'.format(type(element)))
            if not channel_set:
                channel_set = element.channel_set
            elif element.channel_set != channel_set:
                raise ValueError('Usage of different sets of analog and digital channels in the '
                                 'same PulseBlock is prohibited. Used channel sets are:\n{0}\n{1}'
                                 ''.format(channel_set, element.channel_set))
        return channel_set

    def _set_channels(self, channel_set):
        """ Store a copy of channel_set and split it into analog ('a...') / digital ('d...'). """
        self.channel_set = set(channel_set)
        self.analog_channels = {chnl for chnl in self.channel_set if chnl.startswith('a')}
        self.digital_channels = {chnl for chnl in self.channel_set if chnl.startswith('d')}

    def _update_lengths(self, removed=(), added=()):
        """ Adjust the running length totals after elements were removed and/or added.

        If the block ends up empty, lengths and channels are reset as in clear().
        """
        for element in removed:
            self.init_length_s -= element.init_length_s
            self.increment_s -= element.increment_s
        for element in added:
            self.init_length_s += element.init_length_s
            self.increment_s += element.increment_s
        if not self.element_list:
            # Drop floating point residue and stale channel information.
            self.init_length_s = 0.0
            self.increment_s = 0.0
            self._set_channels(set())

    def refresh_parameters(self):
        """ Initialize the parameters which describe this Pulse_Block object.

        The information is gained from all the Pulse_Block_Element objects,
        which are attached in the element_list. If the elements use different channel sets a
        ValueError is raised and the block parameters are left unchanged.
        """
        init_length_s = 0.0
        increment_s = 0.0
        channel_set = set()

        for elem in self.element_list:
            init_length_s += elem.init_length_s
            increment_s += elem.increment_s

            if not channel_set:
                channel_set = elem.channel_set
            elif channel_set != elem.channel_set:
                raise ValueError('Usage of different sets of analog and digital channels in the '
                                 'same PulseBlock is prohibited.\nPulseBlock creation failed!\n'
                                 'Used channel sets are:\n{0}\n{1}'.format(channel_set,
                                                                           elem.channel_set))
        self.init_length_s = init_length_s
        self.increment_s = increment_s
        # _set_channels copies the set, so the block does not alias an element's channel_set.
        self._set_channels(channel_set)
        return

    def pop(self, position=None):
        """ Remove and return the element at position (default: the last one).

        @param int position: optional index, negative values count from the end

        @return PulseBlockElement: the removed element
        """
        if len(self.element_list) == 0:
            raise IndexError('pop from empty PulseBlock')

        if position is None:
            position = len(self.element_list) - 1
        elif not isinstance(position, int):
            raise TypeError('PulseBlock.pop position argument expects integer, not {0}'
                            ''.format(type(position)))
        elif position < 0:
            position = len(self.element_list) + position

        if len(self.element_list) <= position or position < 0:
            raise IndexError('PulseBlock element list index out of range')

        element = self.element_list.pop(position)
        self._update_lengths(removed=[element])
        return element

    def insert(self, position, element):
        """ Insert a PulseBlockElement at the given position. The old element at this position and
        all consecutive elements after that will be shifted to higher indices.

        @param int position: position in the element list
        @param PulseBlockElement element: PulseBlockElement instance
        """
        # NOTE: raises ValueError (not TypeError like __setitem__) for wrong types; kept for
        # callers relying on it.
        if not isinstance(element, PulseBlockElement):
            raise ValueError('PulseBlock elements must be of type PulseBlockElement, not {0}'
                             ''.format(type(element)))

        if position < 0:
            position = len(self.element_list) + position

        if len(self.element_list) < position or position < 0:
            raise IndexError('PulseBlock element list index out of range')

        channel_set = self._check_elements([element])
        element_copy = copy.deepcopy(element)
        self.element_list.insert(position, element_copy)

        if not self.channel_set:
            self._set_channels(channel_set)
        self._update_lengths(added=[element_copy])
        return

    def append(self, element):
        """ Append a deep copy of a PulseBlockElement at the end of the block.

        @param PulseBlockElement element: PulseBlockElement instance
        """
        self.insert(position=len(self.element_list), element=element)
        return

    def extend(self, iterable):
        for element in iterable:
            self.append(element=element)
        return

    def clear(self):
        del self.element_list[:]
        self.init_length_s = 0.0
        self.increment_s = 0.0
        self._set_channels(set())
        return

    def reverse(self):
        self.element_list.reverse()
        return

    def get_dict_representation(self):
        dict_repr = dict()
        dict_repr['name'] = self.name
        dict_repr['element_list'] = [element.get_dict_representation()
                                     for element in self.element_list]
        return dict_repr

    @staticmethod
    def block_from_dict(block_dict):
        """ Create a PulseBlock from a dict as produced by get_dict_representation.

        The element list of the passed dict is not replaced in place.

        @param dict block_dict: keyword arguments for PulseBlock with 'element_list' holding
                                element dicts

        @return PulseBlock: the reconstructed block
        """
        kwargs = dict(block_dict)
        kwargs['element_list'] = [PulseBlockElement.element_from_dict(element_dict)
                                  for element_dict in block_dict['element_list']]
        return PulseBlock(**kwargs)
381
382
383
class PulseBlockEnsemble(object):
    """
    Represents a collection of PulseBlock objects which is called a PulseBlockEnsemble.

    This object is used as a construction plan to create one sampled file.
    """
    def __init__(self, name, block_list=None, rotating_frame=True):
        """
        The constructor for a Pulse_Block_Ensemble needs to have:

        @param str name: chosen name for the PulseBlockEnsemble
        @param list block_list: contains the PulseBlock names with their number of repetitions,
                                e.g. [(name, repetitions), (name, repetitions), ...]
        @param bool rotating_frame: indicates whether the phase should be preserved for all the
                                    functions.
        """
        # FIXME: Sanity checking needed here
        self.name = name
        self.rotating_frame = rotating_frame
        if isinstance(block_list, list):
            self.block_list = block_list
        else:
            self.block_list = list()

        # Dictionary container to store information related to the actually sampled
        # Waveform like pulser settings used during sampling (sample_rate, activation_config etc.)
        # and additional information about the discretization of the waveform (timebin positions of
        # the PulseBlockElement transitions etc.) as well as the names of the created waveforms.
        # This container will be populated during sampling and will be emptied upon deletion of the
        # corresponding waveforms from the pulse generator
        self.sampling_information = dict()
        # Dictionary container to store additional information about for measurement settings
        # (ignore_lasers, controlled_variable, alternating etc.).
        # This container needs to be populated by the script creating the PulseBlockEnsemble
        # before saving it. (e.g. in generate methods in PulsedObjectGenerator class)
        self.measurement_information = dict()
        return

    def __repr__(self):
        repr_str = 'PulseBlockEnsemble(name=\'{0}\', block_list={1}, rotating_frame={2})'.format(
            self.name, repr(self.block_list), self.rotating_frame)
        return repr_str

    def __str__(self):
        return_str = 'PulseBlockEnsemble "{0}"\n\trotating frame: {1}\n\t' \
                     'has been sampled: {2}\n\t<block name>\t<repetitions>\n\t'.format(
                         self.name, self.rotating_frame, bool(self.sampling_information))
        return_str += '\n\t'.join(('{0}\t{1}'.format(name, reps) for name, reps in self.block_list))
        return return_str

    def __eq__(self, other):
        if not isinstance(other, PulseBlockEnsemble):
            return False
        if self is other:
            return True
        if (self.name, self.rotating_frame) != (other.name, other.rotating_frame):
            return False
        if self.block_list != other.block_list:
            return False
        if self.measurement_information != other.measurement_information:
            return False
        return True

    def __len__(self):
        return len(self.block_list)

    def __getitem__(self, key):
        if not isinstance(key, (slice, int)):
            raise TypeError('PulseBlockEnsemble indices must be int or slice, not {0}'
                            ''.format(type(key)))
        return self.block_list[key]

    @staticmethod
    def _check_block_entry(entry):
        """ Validate a single (PulseBlock name, repetitions) block list entry.

        @param tuple|list entry: candidate entry

        @raise TypeError: if entry is not a tuple/list of length 2
        @raise ValueError: if the name is not a str or repetitions is not an int >= 0
        """
        if not isinstance(entry, (tuple, list)) or len(entry) != 2:
            raise TypeError('PulseBlockEnsemble block list entries must be a tuple or list of '
                            'length 2')
        if not isinstance(entry[0], str):
            raise ValueError('PulseBlockEnsemble element tuple index 0 must contain str, '
                             'not {0}'.format(type(entry[0])))
        if not isinstance(entry[1], int) or entry[1] < 0:
            raise ValueError('PulseBlockEnsemble element tuple index 1 must contain int >= 0')

    def _invalidate_information(self):
        """ Discard sampling and measurement information after the block list has changed. """
        self.sampling_information = dict()
        self.measurement_information = dict()

    def __setitem__(self, key, value):
        """ Replace block list entries. Every stored entry is converted to a tuple.

        @param int|slice key: index or slice into the block list
        @param tuple|list|iterable value: a single (name, repetitions) entry for int keys, an
                                          iterable of such entries for slice keys
        """
        if isinstance(key, int):
            self._check_block_entry(value)
            self.block_list[key] = tuple(value)
        elif isinstance(key, slice):
            # Materialize once (value may be a one-shot iterator) and validate EACH entry's own
            # length before anything is assigned.
            entries = list(value)
            for entry in entries:
                self._check_block_entry(entry)
            self.block_list[key] = [tuple(entry) for entry in entries]
        else:
            raise TypeError('PulseBlockEnsemble indices must be int or slice, not {0}'
                            ''.format(type(key)))
        self._invalidate_information()
        return

    def __delitem__(self, key):
        if not isinstance(key, (slice, int)):
            raise TypeError('PulseBlockEnsemble indices must be int or slice, not {0}'
                            ''.format(type(key)))

        del self.block_list[key]
        self._invalidate_information()
        return

    def pop(self, position=None):
        """ Remove and return the entry at position (default: the last one).

        @param int position: optional index, negative values count from the end

        @return tuple: the removed (PulseBlock name, repetitions) entry
        """
        if len(self.block_list) == 0:
            raise IndexError('pop from empty PulseBlockEnsemble')

        if position is None:
            position = len(self.block_list) - 1
        elif not isinstance(position, int):
            raise TypeError('PulseBlockEnsemble.pop position argument expects integer, not {0}'
                            ''.format(type(position)))
        elif position < 0:
            position = len(self.block_list) + position

        if len(self.block_list) <= position or position < 0:
            raise IndexError('PulseBlockEnsemble block list index out of range')

        entry = self.block_list.pop(position)
        self._invalidate_information()
        return entry

    def insert(self, position, element):
        """ Insert a (PulseBlock.name, repetitions) tuple at the given position. The old element
        at this position and all consecutive elements after that will be shifted to higher indices.

        @param int position: position in the element list
        @param tuple element: (PulseBlock name (str), repetitions (int))
        """
        self._check_block_entry(element)

        if position < 0:
            position = len(self.block_list) + position
        if len(self.block_list) < position or position < 0:
            raise IndexError('PulseBlockEnsemble block list index out of range')

        self.block_list.insert(position, tuple(element))
        self._invalidate_information()
        return

    def append(self, element):
        """ Append a (PulseBlock name, repetitions) entry at the end of the block list.

        @param tuple element: (PulseBlock name (str), repetitions (int))
        """
        self.insert(position=len(self), element=element)
        return

    def extend(self, iterable):
        for element in iterable:
            self.append(element=element)
        return

    def clear(self):
        del self.block_list[:]
        self._invalidate_information()
        return

    def reverse(self):
        self.block_list.reverse()
        self._invalidate_information()
        return

    def get_dict_representation(self):
        dict_repr = dict()
        dict_repr['name'] = self.name
        dict_repr['rotating_frame'] = self.rotating_frame
        dict_repr['block_list'] = self.block_list
        dict_repr['sampling_information'] = self.sampling_information
        dict_repr['measurement_information'] = self.measurement_information
        return dict_repr

    @staticmethod
    def ensemble_from_dict(ensemble_dict):
        new_ens = PulseBlockEnsemble(name=ensemble_dict['name'],
                                     block_list=ensemble_dict['block_list'],
                                     rotating_frame=ensemble_dict['rotating_frame'])
        new_ens.sampling_information = ensemble_dict['sampling_information']
        new_ens.measurement_information = ensemble_dict['measurement_information']
        return new_ens
583
584
585
class PulseSequence(object):
    """
    Higher order object for sequence capability.

    Represents a playback procedure for a number of PulseBlockEnsembles. Unused for pulse
    generator hardware without sequencing functionality.

    The ensemble list holds (PulseBlockEnsemble name, seq_param dict) tuples. Every mutating
    method of this class resets sampling_information and measurement_information to empty dicts
    and keeps the is_finite flag consistent with the 'repetitions' parameter of all entries.
    """
    __default_seq_params = {'repetitions': 0,
                            'go_to': -1,
                            'event_jump_to': -1,
                            'event_trigger': 'OFF',
                            'wait_for': 'OFF',
                            'flag_trigger': 'OFF',
                            'flag_high': 'OFF'}

    def __init__(self, name, ensemble_list=None, rotating_frame=False):
        """
        The constructor for a PulseSequence objects needs to have:

        @param str name: the actual name of the sequence
        @param list ensemble_list: list containing a tuple of two entries:
                                          [(PulseBlockEnsemble name, seq_param),
                                           (PulseBlockEnsemble name, seq_param), ...]
                                          The seq_param is a dictionary, where the various sequence
                                          parameters are saved with their keywords and the
                                          according parameter (as item).
                                          Available parameters are:
                                          'repetitions': The number of repetitions for that sequence
                                                         step. (Default 0)
                                                         0 meaning the step is played once.
                                                         Set to -1 for infinite looping.
                                          'go_to':   The sequence step index to jump to after
                                                     having played all repetitions. (Default -1)
                                                     Indices starting at 1 for first step.
                                                     Set to 0 or -1 to follow up with the next step.
                                          'event_jump_to': The sequence step to jump to
                                                           (starting from 1) in case of a trigger
                                                           event (see event_trigger).
                                                           Setting it to 0 or -1 means jump to next
                                                           step. Ignored if event_trigger is 'OFF'.
                                          'event_trigger': The trigger input to listen to in order
                                                           to perform sequence jumps. Set to 'OFF'
                                                           (default) in order to ignore triggering.
                                          'wait_for': The trigger input to wait for before playing
                                                      this sequence step. Set to 'OFF' (default)
                                                      in order to play the current step immediately.
                                          'flag_trigger': The flag to trigger when this sequence
                                                          step starts playing. Select 'OFF'
                                                          (default) for no flag trigger.
                                          'flag_high': The flag to set to high while this step is
                                                       playing. Select 'OFF' (default) to set all
                                                       flags to low.

                                          If only 'repetitions' are in the dictionary, then the dict
                                          will look like:
                                            seq_param = {'repetitions': 41}
                                          and so the respective sequence step will play 42 times.
        @param bool rotating_frame: indicates, whether the phase has to be preserved in all
                                    analog signals ACROSS different waveforms
        """
        self.name = name
        self.rotating_frame = rotating_frame
        # NOTE: a given list is used as-is (not copied), so the caller shares it with this object.
        self.ensemble_list = list() if ensemble_list is None else ensemble_list
        # refresh_parameters() initializes is_finite from the ensemble list.
        self.refresh_parameters()

        # Dictionary container to store information related to the actually sampled
        # Waveforms like pulser settings used during sampling (sample_rate, activation_config etc.)
        # and additional information about the discretization of the waveform (timebin positions of
        # the PulseBlockElement transitions etc.)
        # This container is not necessary for the sampling process but serves only the purpose of
        # holding optional information for different modules.
        self.sampling_information = dict()
        # Dictionary container to store additional information about for measurement settings
        # (ignore_lasers, controlled_values, alternating etc.).
        # This container needs to be populated by the script creating the PulseSequence
        # before saving it.
        self.measurement_information = dict()
        return

    def refresh_parameters(self):
        """ Recompute is_finite: False if any entry has negative 'repetitions' (infinite loop). """
        self.is_finite = not any(params['repetitions'] < 0 for _, params in self.ensemble_list)
        return

    def _to_ensemble_entry(self, element):
        """ Validate a single ensemble list entry and return it as a tuple.

        @param tuple|list|str element: (PulseBlockEnsemble name, seq_param dict) or a bare
                                       PulseBlockEnsemble name, which gets a copy of the default
                                       sequence parameters

        @return tuple: (PulseBlockEnsemble name, seq_param dict)

        @raises TypeError: if element is not a str or a tuple/list of length 2
        @raises ValueError: if the name is not a str or the parameters are not a dict
        """
        if isinstance(element, str):
            element = (element, self.__default_seq_params.copy())
        if not isinstance(element, (tuple, list)) or len(element) != 2:
            raise TypeError('PulseSequence ensemble list entries must be a tuple or list of '
                            'length 2')
        if not isinstance(element[0], str):
            raise ValueError('PulseSequence element tuple index 0 must contain str, not {0}'
                             ''.format(type(element[0])))
        if not isinstance(element[1], dict):
            raise ValueError('PulseSequence element tuple index 1 must contain dict, not {0}'
                             ''.format(type(element[1])))
        return tuple(element)

    def __repr__(self):
        repr_str = 'PulseSequence(name=\'{0}\', ensemble_list={1}, rotating_frame={2})'.format(
            self.name, self.ensemble_list, self.rotating_frame)
        return repr_str

    def __str__(self):
        return_str = 'PulseSequence "{0}"\n\trotating frame: {1}\n\t' \
                     'has finite length: {2}\n\thas been sampled: {3}\n\t<ensemble name>\t' \
                     '<sequence parameters>\n\t'.format(self.name,
                                                        self.rotating_frame,
                                                        self.is_finite,
                                                        bool(self.sampling_information))
        return_str += '\n\t'.join(('{0}\t{1}'.format(name, param) for name, param in self))
        return return_str

    def __eq__(self, other):
        """ Equal if name, rotating frame, finiteness, ensemble list and measurement info match.

        sampling_information is deliberately not compared.
        """
        if not isinstance(other, PulseSequence):
            return False
        if self is other:
            return True
        if (self.name, self.rotating_frame, self.is_finite) != (other.name, other.rotating_frame, other.is_finite):
            return False
        if self.ensemble_list != other.ensemble_list:
            return False
        if self.measurement_information != other.measurement_information:
            return False
        return True

    def __len__(self):
        return len(self.ensemble_list)

    def __getitem__(self, key):
        if not isinstance(key, (slice, int)):
            raise TypeError('PulseSequence indices must be int or slice, not {0}'.format(type(key)))
        return self.ensemble_list[key]

    def __setitem__(self, key, value):
        """ Replace one entry (int key) or a range of entries (slice key) of the ensemble list.

        Entries may be (PulseBlockEnsemble name, seq_param dict) tuples/lists or bare ensemble
        name strings (these get a copy of the default sequence parameters). Entries are stored as
        tuples.

        @param int|slice key: index or slice into the ensemble list
        @param value: a single entry for an int key, an iterable of entries for a slice key

        @raises TypeError: if key is neither int nor slice, or an entry is malformed
        @raises ValueError: if an entry holds a wrongly typed name or parameter dict
        """
        if isinstance(key, int):
            new_entry = self._to_ensemble_entry(value)
            new_is_infinite = new_entry[1]['repetitions'] < 0
            # Read the replaced entry before mutating anything; an out-of-range key raises
            # IndexError here and leaves the object untouched.
            replaced_is_infinite = self.ensemble_list[key][1]['repetitions'] < 0
            self.ensemble_list[key] = new_entry
            if new_is_infinite:
                self.is_finite = False
            elif replaced_is_infinite:
                # The removed entry may have been the only infinite one.
                self.refresh_parameters()
        elif isinstance(key, slice):
            # Validate every entry (each one individually, so str and tuple entries can be mixed)
            # before touching the list.
            new_entries = [self._to_ensemble_entry(element) for element in value]
            self.ensemble_list[key] = new_entries
            # Both the removed and the inserted entries affect finiteness; recompute from scratch.
            self.refresh_parameters()
        else:
            raise TypeError('PulseSequence indices must be int or slice, not {0}'.format(type(key)))
        self.sampling_information = dict()
        self.measurement_information = dict()
        return

    def __delitem__(self, key):
        if isinstance(key, slice):
            stage_refresh = any(element[1]['repetitions'] < 0
                                for element in self.ensemble_list[key])
        elif isinstance(key, int):
            stage_refresh = self.ensemble_list[key][1]['repetitions'] < 0
        else:
            raise TypeError('PulseSequence indices must be int or slice, not {0}'.format(type(key)))
        del self.ensemble_list[key]
        self.sampling_information = dict()
        self.measurement_information = dict()
        # Only removing an infinite entry can change the finiteness of the sequence.
        if stage_refresh:
            self.refresh_parameters()
        return

    def pop(self, position=None):
        """ Remove and return the ensemble list entry at position (default: the last entry).

        @param int position: index of the entry to remove; negative values count from the end

        @return tuple: the removed (PulseBlockEnsemble name, seq_param dict) entry

        @raises IndexError: if the sequence is empty or position is out of range
        @raises TypeError: if position is not an int
        """
        if len(self.ensemble_list) == 0:
            raise IndexError('pop from empty PulseSequence')

        if position is None:
            position = len(self.ensemble_list) - 1

        if not isinstance(position, int):
            raise TypeError('PulseSequence.pop position argument expects integer, not {0}'
                            ''.format(type(position)))

        if position < 0:
            position = len(self.ensemble_list) + position

        if len(self.ensemble_list) <= position or position < 0:
            raise IndexError('PulseSequence ensemble list index out of range')

        self.sampling_information = dict()
        self.measurement_information = dict()
        popped_element = self.ensemble_list.pop(position)
        # Check the entry actually removed (not the last one) for infinite repetitions.
        if popped_element[1]['repetitions'] < 0:
            self.refresh_parameters()
        return popped_element

    def insert(self, position, element):
        """ Insert a (PulseBlockEnsemble name, seq_param dict) tuple at the given position. The old
        element at this position and all consecutive elements after that will be shifted to higher
        indices.

        @param int position: position in the ensemble list (negative values count from the end)
        @param tuple|str element: PulseBlockEnsemble name (str)[, seq_parameters (dict)]

        @raises TypeError: if element is not a str or a tuple/list of length 2
        @raises ValueError: if the name is not a str or the parameters are not a dict
        @raises IndexError: if position is out of range
        """
        element = self._to_ensemble_entry(element)

        if position < 0:
            position = len(self.ensemble_list) + position
        if len(self.ensemble_list) < position or position < 0:
            raise IndexError('PulseSequence ensemble list index out of range')

        is_infinite = element[1]['repetitions'] < 0
        self.ensemble_list.insert(position, element)
        if is_infinite:
            self.is_finite = False
        self.sampling_information = dict()
        self.measurement_information = dict()
        return

    def append(self, element):
        """ Append an entry (same formats as accepted by insert) to the end of the sequence.

        @param tuple|str element: PulseBlockEnsemble name (str)[, seq_parameters (dict)]
        """
        self.insert(position=len(self), element=element)
        return

    def extend(self, iterable):
        """ Append all entries of iterable in order (same formats as accepted by insert). """
        for element in iterable:
            self.append(element=element)
        return

    def clear(self):
        """ Remove all entries; an empty sequence is finite. """
        del self.ensemble_list[:]
        self.sampling_information = dict()
        self.measurement_information = dict()
        self.is_finite = True
        return

    def reverse(self):
        """ Reverse the order of the ensemble list in place. """
        self.ensemble_list.reverse()
        self.sampling_information = dict()
        self.measurement_information = dict()
        return

    def get_dict_representation(self):
        """ Return a dict of name, rotating_frame, ensemble_list and both information containers.

        The containers are stored by reference, not copied.
        """
        dict_repr = dict()
        dict_repr['name'] = self.name
        dict_repr['rotating_frame'] = self.rotating_frame
        dict_repr['ensemble_list'] = self.ensemble_list
        dict_repr['sampling_information'] = self.sampling_information
        dict_repr['measurement_information'] = self.measurement_information
        return dict_repr

    @staticmethod
    def sequence_from_dict(sequence_dict):
        """ Build a PulseSequence from a dict as produced by get_dict_representation(). """
        new_seq = PulseSequence(name=sequence_dict['name'],
                                ensemble_list=sequence_dict['ensemble_list'],
                                rotating_frame=sequence_dict['rotating_frame'])
        new_seq.sampling_information = sequence_dict['sampling_information']
        new_seq.measurement_information = sequence_dict['measurement_information']
        return new_seq
874
875
876
class PredefinedGeneratorBase:
    """
    Base class for PulseObjectGenerator and predefined generator classes containing the actual
    "generate_"-methods.

    This class holds a protected reference to the SequenceGeneratorLogic and provides read-only
    access via properties to various attributes of the logic module.
    SequenceGeneratorLogic logger is also accessible via this base class and can be used as in any
    qudi module (e.g. self.log.error(...)).
    Also provides helper methods to simplify sequence/ensemble generation.
    """
    def __init__(self, sequencegeneratorlogic):
        # Keep a name-mangled (private) reference to the SequenceGeneratorLogic
        self.__sequencegeneratorlogic = sequencegeneratorlogic

    @property
    def log(self):
        """ Logger of the SequenceGeneratorLogic. """
        return self.__sequencegeneratorlogic.log

    @property
    def pulse_generator_settings(self):
        """ Pulse generator settings dict of the SequenceGeneratorLogic. """
        return self.__sequencegeneratorlogic.pulse_generator_settings

    @property
    def generation_parameters(self):
        """ Generation parameters dict of the SequenceGeneratorLogic. """
        return self.__sequencegeneratorlogic.generation_parameters

    @property
    def channel_set(self):
        """ Active channel names: item 1 of the 'activation_config' pulse generator setting.

        Returns an empty set if 'activation_config' is missing or None.
        """
        channels = self.pulse_generator_settings.get('activation_config')
        if channels is None:
            channels = ('', set())
        return channels[1]

    @property
    def analog_channels(self):
        """ Active channels whose name starts with 'a'. """
        return {chnl for chnl in self.channel_set if chnl.startswith('a')}

    @property
    def digital_channels(self):
        """ Active channels whose name starts with 'd'. """
        return {chnl for chnl in self.channel_set if chnl.startswith('d')}

    @property
    def laser_channel(self):
        return self.generation_parameters.get('laser_channel')

    @property
    def sync_channel(self):
        """ Sync channel name, or None if the parameter is missing or an empty string. """
        channel = self.generation_parameters.get('sync_channel')
        return None if channel == '' else channel

    @property
    def gate_channel(self):
        """ Gate channel name, or None if the parameter is missing or an empty string. """
        channel = self.generation_parameters.get('gate_channel')
        return None if channel == '' else channel

    @property
    def analog_trigger_voltage(self):
        return self.generation_parameters.get('analog_trigger_voltage')

    @property
    def laser_delay(self):
        return self.generation_parameters.get('laser_delay')

    @property
    def microwave_channel(self):
        """ Microwave channel name, or None if the parameter is missing or an empty string. """
        channel = self.generation_parameters.get('microwave_channel')
        return None if channel == '' else channel

    @property
    def microwave_frequency(self):
        return self.generation_parameters.get('microwave_frequency')

    @property
    def microwave_amplitude(self):
        return self.generation_parameters.get('microwave_amplitude')

    @property
    def laser_length(self):
        return self.generation_parameters.get('laser_length')

    @property
    def wait_time(self):
        return self.generation_parameters.get('wait_time')

    @property
    def rabi_period(self):
        return self.generation_parameters.get('rabi_period')

    ################################################################################################
    #                                   Helper methods                                          ####
    ################################################################################################
    def _switch_channel_high(self, pulse_function, digital_high, channel):
        """
        Set a single channel to "high" in the given element settings dicts (modified in place).

        Digital channels (name starting with 'd') are set True in digital_high. Any other channel
        gets a DC sampling function at analog_trigger_voltage in pulse_function.

        @param dict pulse_function: analog channel -> sampling function mapping to modify
        @param dict digital_high: digital channel -> bool mapping to modify
        @param str channel: the channel to switch high
        """
        if channel.startswith('d'):
            digital_high[channel] = True
        else:
            pulse_function[channel] = SamplingFunctions.DC(voltage=self.analog_trigger_voltage)

    def _get_idle_element(self, length, increment):
        """
        Creates an idle pulse PulseBlockElement

        All active analog channels get an Idle sampling function and all active digital
        channels are low.

        @param float length: idle duration in seconds
        @param float increment: idle duration increment in seconds

        @return: PulseBlockElement, the generated idle element
        """
        # Create idle element
        return PulseBlockElement(
            init_length_s=length,
            increment_s=increment,
            pulse_function={chnl: SamplingFunctions.Idle() for chnl in self.analog_channels},
            digital_high={chnl: False for chnl in self.digital_channels})

    def _get_trigger_element(self, length, increment, channels):
        """
        Creates a trigger PulseBlockElement

        The given channels are high (digital) or at analog_trigger_voltage (analog); all other
        active channels are idle/low.

        @param float length: trigger duration in seconds
        @param float increment: trigger duration increment in seconds
        @param str|list channels: The pulser channel(s) to be triggered.

        @return: PulseBlockElement, the generated trigger element
        """
        if isinstance(channels, str):
            channels = [channels]

        # input params for element generation
        pulse_function = {chnl: SamplingFunctions.Idle() for chnl in self.analog_channels}
        digital_high = {chnl: False for chnl in self.digital_channels}

        # Determine analogue or digital trigger channel and set channels accordingly.
        for channel in channels:
            self._switch_channel_high(pulse_function, digital_high, channel)

        # return trigger element
        return PulseBlockElement(init_length_s=length,
                                 increment_s=increment,
                                 pulse_function=pulse_function,
                                 digital_high=digital_high)

    def _get_laser_element(self, length, increment):
        """
        Creates laser trigger PulseBlockElement

        @param float length: laser pulse duration in seconds
        @param float increment: laser pulse duration increment in seconds

        @return: PulseBlockElement, a single trigger element with the laser channel switched on
        """
        return self._get_trigger_element(length=length,
                                         increment=increment,
                                         channels=self.laser_channel)

    def _get_laser_gate_element(self, length, increment):
        """
        Creates a laser trigger PulseBlockElement that additionally switches on the gate channel
        (if one is configured).

        @param float length: laser pulse duration in seconds
        @param float increment: laser pulse duration increment in seconds

        @return: PulseBlockElement, the generated laser (and gate) element
        """
        laser_gate_element = self._get_laser_element(length=length,
                                                     increment=increment)
        if self.gate_channel:
            self._switch_channel_high(laser_gate_element.pulse_function,
                                      laser_gate_element.digital_high,
                                      self.gate_channel)
        return laser_gate_element

    def _get_delay_element(self):
        """
        Creates an idle element of length of the laser delay

        @return PulseBlockElement: The delay element
        """
        return self._get_idle_element(length=self.laser_delay,
                                      increment=0)

    def _get_delay_gate_element(self):
        """
        Creates a gate trigger of length of the laser delay.
        If no gate channel is specified will return a simple idle element.

        @return PulseBlockElement: The delay element
        """
        if self.gate_channel:
            return self._get_trigger_element(length=self.laser_delay,
                                             increment=0,
                                             channels=self.gate_channel)
        else:
            return self._get_delay_element()

    def _get_sync_element(self):
        """
        Creates a 50 ns trigger PulseBlockElement on the sync channel.

        @return PulseBlockElement: The sync element
        """
        return self._get_trigger_element(length=50e-9,
                                         increment=0,
                                         channels=self.sync_channel)

    def _get_mw_element(self, length, increment, amp=None, freq=None, phase=None):
        """
        Creates a MW pulse PulseBlockElement

        For a digital MW channel (name starting with 'd') amp/freq/phase are ignored and the
        channel is simply switched high.

        @param float length: MW pulse duration in seconds
        @param float increment: MW pulse duration increment in seconds
        @param float freq: MW frequency in case of analogue MW channel in Hz
        @param float amp: MW amplitude in case of analogue MW channel in V
        @param float phase: MW phase in case of analogue MW channel in deg

        @return: PulseBlockElement, the generated MW element
        """
        if self.microwave_channel.startswith('d'):
            mw_element = self._get_trigger_element(
                length=length,
                increment=increment,
                channels=self.microwave_channel)
        else:
            mw_element = self._get_idle_element(
                length=length,
                increment=increment)
            mw_element.pulse_function[self.microwave_channel] = SamplingFunctions.Sin(
                amplitude=amp,
                frequency=freq,
                phase=phase)
        return mw_element

    def _get_multiple_mw_element(self, length, increment, amps=None, freqs=None, phases=None):
        """
        Creates single, double or triple sine mw element.

        For an analog MW channel the number of tones is the shortest length of amps, freqs and
        phases: fewer than 2 -> Sin, 2 -> DoubleSin, 3 or more -> TripleSin using only the first
        three values of each list. The analog branch therefore requires all three arguments
        (scalars or non-empty sequences). For a digital MW channel the channel is simply switched
        high and amps/freqs/phases are ignored.

        @param float length: MW pulse duration in seconds
        @param float increment: MW pulse duration increment in seconds
        @param amps: list containing the amplitudes (a single int/float is also accepted)
        @param freqs: list containing the frequencies (a single int/float is also accepted)
        @param phases: list containing the phases (a single int/float is also accepted)
        @return: PulseBlockElement, the generated MW element
        """
        if isinstance(amps, (int, float)):
            amps = [amps]
        if isinstance(freqs, (int, float)):
            freqs = [freqs]
        if isinstance(phases, (int, float)):
            phases = [phases]

        if self.microwave_channel.startswith('d'):
            mw_element = self._get_trigger_element(
                length=length,
                increment=increment,
                channels=self.microwave_channel)
        else:
            mw_element = self._get_idle_element(
                length=length,
                increment=increment)

            sine_number = min(len(amps), len(freqs), len(phases))

            if sine_number < 2:
                mw_element.pulse_function[self.microwave_channel] = SamplingFunctions.Sin(
                    amplitude=amps[0],
                    frequency=freqs[0],
                    phase=phases[0])
            elif sine_number == 2:
                mw_element.pulse_function[self.microwave_channel] = SamplingFunctions.DoubleSin(
                    amplitude_1=amps[0],
                    amplitude_2=amps[1],
                    frequency_1=freqs[0],
                    frequency_2=freqs[1],
                    phase_1=phases[0],
                    phase_2=phases[1])
            else:
                mw_element.pulse_function[self.microwave_channel] = SamplingFunctions.TripleSin(
                    amplitude_1=amps[0],
                    amplitude_2=amps[1],
                    amplitude_3=amps[2],
                    frequency_1=freqs[0],
                    frequency_2=freqs[1],
                    frequency_3=freqs[2],
                    phase_1=phases[0],
                    phase_2=phases[1],
                    phase_3=phases[2])
        return mw_element

    def _get_mw_laser_element(self, length, increment, amp=None, freq=None, phase=None):
        """
        Creates a MW pulse PulseBlockElement (see _get_mw_element) with the laser channel
        switched on at the same time.

        @param float length: pulse duration in seconds
        @param float increment: pulse duration increment in seconds
        @param float amp: MW amplitude in case of analogue MW channel in V
        @param float freq: MW frequency in case of analogue MW channel in Hz
        @param float phase: MW phase in case of analogue MW channel in deg
        @return: PulseBlockElement, the generated MW + laser element
        """
        mw_laser_element = self._get_mw_element(length=length,
                                                increment=increment,
                                                amp=amp,
                                                freq=freq,
                                                phase=phase)
        self._switch_channel_high(mw_laser_element.pulse_function,
                                  mw_laser_element.digital_high,
                                  self.laser_channel)
        return mw_laser_element

    def _get_ensemble_count_length(self, ensemble, created_blocks):
        """
        Returns the length in seconds to use as count length for the given ensemble.

        With a gate channel configured this is laser_length + laser_delay. Otherwise it is the
        total duration of the ensemble: for each (block name, repetitions) entry of
        ensemble.block_list the block is played repetitions + 1 times, growing by increment_s
        each time.

        @param ensemble: PulseBlockEnsemble whose block_list holds (block name, repetitions)
        @param created_blocks: iterable of PulseBlocks; must contain every block named in the
                               ensemble's block_list
        @return float: the count length in seconds
        """
        if self.gate_channel:
            length = self.laser_length + self.laser_delay
        else:
            blocks = {block.name: block for block in created_blocks}
            length = 0.0
            for block_name, reps in ensemble.block_list:
                length += blocks[block_name].init_length_s * (reps + 1)
                # sum of the increments over all repetitions: increment_s * (0 + 1 + ... + reps)
                length += blocks[block_name].increment_s * ((reps ** 2 + reps) / 2)
        return length
1192
1193
1194
class PulseObjectGenerator(PredefinedGeneratorBase):
    """
    Collects the predefined generate methods of all generator classes found in the default
    directory <main dir>/logic/pulsed/predefined_generate_methods and, optionally, in an
    additional directory (sequencegeneratorlogic.additional_methods_dir).

    A generator class is a class whose only direct base class is PredefinedGeneratorBase
    (see is_generator_class). One instance of each generator class is created, and every bound
    method of it whose name starts with "generate_" is registered under its name with that
    prefix stripped.
    """
    def __init__(self, sequencegeneratorlogic):
        """
        @param sequencegeneratorlogic: object passed to the base class and to the constructor of
                                       every generator class found. If its attribute
                                       additional_methods_dir is a str, that directory is
                                       searched for generator modules as well.
        """
        # Initialize base class
        super().__init__(sequencegeneratorlogic)

        # dictionary containing references to all generation methods imported from generator class
        # modules. The keys are the method names excluding the prefix "generate_".
        self._generate_methods = dict()
        # nested dictionary with keys being the generation method names and values being a
        # dictionary containing all keyword arguments as keys with their default value
        self._generate_method_parameters = dict()

        # import path for generator modules from default dir (logic/pulsed/predefined_generate_methods)
        path_list = [os.path.join(get_main_dir(), 'logic', 'pulsed', 'predefined_generate_methods')]
        # import path for generator modules from non-default directory if a path has been given
        if isinstance(sequencegeneratorlogic.additional_methods_dir, str):
            path_list.append(sequencegeneratorlogic.additional_methods_dir)

        # Import predefined generator modules and get a list of generator classes
        generator_classes = self.__import_external_generators(paths=path_list)

        # create an instance of each class and put them in a temporary list
        generator_instances = [cls(sequencegeneratorlogic) for cls in generator_classes]

        # add references to all generate methods in each instance to a dict
        self.__populate_method_dict(instance_list=generator_instances)

        # populate parameters dictionary from generate method signatures
        self.__populate_parameter_dict()

    @property
    def predefined_generate_methods(self):
        """
        @return dict: method name (without "generate_" prefix) -> bound generate method.
                      NOTE: this is the internal dict itself, not a copy.
        """
        return self._generate_methods

    @property
    def predefined_method_parameters(self):
        """
        @return dict: method name -> {parameter name: default value (None if no default)}.
                      Shallow copy: the nested parameter dicts are shared with this object.
        """
        return self._generate_method_parameters.copy()

    def __import_external_generators(self, paths):
        """
        Helper method to import all modules from directories contained in paths.
        Find all classes in those modules that inherit exclusively from PredefinedGeneratorBase
        class and return a list of them.

        Each existing directory is appended to sys.path (only once) so that its *.py files can
        be imported by plain module name. Modules are imported in sorted file name order and
        then reloaded, so changes on disk are picked up when this is called again.
        A non-existent directory or a module that fails to import is reported via self.log.error
        and skipped; the remaining modules are still imported.

        @param iterable paths: iterable containing paths to import modules from
        @return list: A list of imported valid generator classes
        """
        class_list = list()
        for path in paths:
            if not os.path.exists(path):
                self.log.error('Unable to import generate methods from "{0}".\n'
                               'Path does not exist.'.format(path))
                continue
            # Get all python modules to import from.
            # The assumption is that in the path, there are *.py files,
            # which contain only generator classes!
            # Sorted so the import order (which decides the winner of generate method name
            # clashes) does not depend on the arbitrary order of os.listdir.
            module_list = [name[:-3] for name in sorted(os.listdir(path)) if
                           os.path.isfile(os.path.join(path, name)) and name.endswith('.py')]

            # Make the directory importable. Add it only once; appending unconditionally would
            # grow sys.path by a duplicate entry every time a PulseObjectGenerator is created.
            if path not in sys.path:
                sys.path.append(path)

            # Go through all modules and collect the generator classes found in them.
            for module_name in module_list:
                try:
                    mod = importlib.import_module(module_name)
                    # re-execute the module so that edits made since a previous import are used
                    importlib.reload(mod)
                except Exception as err:
                    # A single broken generator module must not prevent all other generate
                    # methods from being loaded; handle it like a non-existent path.
                    self.log.error('Unable to import generate methods from module "{0}" in '
                                   '"{1}":\n{2!r}'.format(module_name, path, err))
                    continue
                # get all generator class references present in the module namespace
                class_list.extend(
                    member for _, member in inspect.getmembers(mod, self.is_generator_class))
        return class_list

    def __populate_method_dict(self, instance_list):
        """
        Helper method to populate the dictionaries containing all references to callable generate
        methods contained in generator instances passed to this method.

        Keys are the method names with the "generate_" prefix removed. If several instances
        provide a method with the same name, the one from the instance appearing last in
        instance_list is kept.

        @param list instance_list: List containing instances of generator classes
        """
        prefix = 'generate_'
        self._generate_methods = dict()
        for instance in instance_list:
            for method_name, method_ref in inspect.getmembers(instance, inspect.ismethod):
                if method_name.startswith(prefix):
                    self._generate_methods[method_name[len(prefix):]] = method_ref

    def __populate_parameter_dict(self):
        """
        Helper method to populate the dictionary containing all possible keyword arguments from all
        generate methods.

        For every registered generate method, maps each parameter name of its signature to the
        parameter's default value, or None if the parameter has no default. The methods are bound
        methods, so "self" is not part of their signature.
        """
        self._generate_method_parameters = dict()
        for method_name, method in self._generate_methods.items():
            parameters = inspect.signature(method).parameters
            self._generate_method_parameters[method_name] = {
                name: None if param.default is param.empty else param.default
                for name, param in parameters.items()}

    @staticmethod
    def is_generator_class(obj):
        """
        Helper method to check if an object is a valid generator class.

        Valid means: obj is a class and PredefinedGeneratorBase is its one and only direct base.

        @param object obj: object to check
        @return bool: True if obj is a valid generator class, False otherwise
        """
        if inspect.isclass(obj):
            return PredefinedGeneratorBase in obj.__bases__ and len(obj.__bases__) == 1
        return False
1310
1311
1312
1313