Completed
Pull Request — master (#384)
by
unknown
01:25
created

PulseExtractor.__init__()   C

Complexity

Conditions 9

Size

Total Lines 46

Duplication

Lines 0
Ratio 0 %

Importance

Changes 3
Bugs 0 Features 0
Metric Value
cc 9
dl 0
loc 46
rs 6.4339
c 3
b 0
f 0
1
# -*- coding: utf-8 -*-
2
"""
3
This file contains the Qudi helper classes for the extraction of laser pulses.
4
5
Qudi is free software: you can redistribute it and/or modify
6
it under the terms of the GNU General Public License as published by
7
the Free Software Foundation, either version 3 of the License, or
8
(at your option) any later version.
9
10
Qudi is distributed in the hope that it will be useful,
11
but WITHOUT ANY WARRANTY; without even the implied warranty of
12
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
GNU General Public License for more details.
14
15
You should have received a copy of the GNU General Public License
16
along with Qudi. If not, see <http://www.gnu.org/licenses/>.
17
18
Copyright (c) the Qudi Developers. See the COPYRIGHT.txt file at the
19
top-level directory of this distribution and at <https://github.com/Ulm-IQO/qudi/>
20
"""
21
22
import os
23
import sys
24
import inspect
25
import importlib
26
27
from core.util.modules import get_main_dir
28
29
30 View Code Duplication
class PulseExtractorBase:
    """
    Common parent class for all extractor classes.

    Extractor classes to import from must inherit exclusively from this class. It keeps a private
    reference to the PulsedMeasurementLogic instance handed over at construction and exposes
    selected attributes of it through properties that have no setter.

    See BasicPulseExtractor class for an example usage.
    """
    def __init__(self, pulsedmeasurementlogic):
        """
        @param pulsedmeasurementlogic: logic instance whose attributes are exposed by the
                                       properties of this class
        """
        self.__pulsedmeasurementlogic = pulsedmeasurementlogic

    @property
    def is_gated(self):
        """
        Entry 'is_gated' of the fast counter settings of the logic, looked up via .get().
        """
        counter_settings = self.__pulsedmeasurementlogic.fast_counter_settings
        return counter_settings.get('is_gated')

    @property
    def measurement_settings(self):
        """
        The "measurement_settings" attribute of the logic.
        """
        logic = self.__pulsedmeasurementlogic
        return logic.measurement_settings

    @property
    def sampling_information(self):
        """
        The "sampling_information" attribute of the logic.
        """
        logic = self.__pulsedmeasurementlogic
        return logic.sampling_information

    @property
    def fast_counter_settings(self):
        """
        The "fast_counter_settings" attribute of the logic.
        """
        logic = self.__pulsedmeasurementlogic
        return logic.fast_counter_settings

    @property
    def log(self):
        """
        The "log" attribute of the logic (used by subclasses for error/warning messages).
        """
        logic = self.__pulsedmeasurementlogic
        return logic.log
60
61
62
class PulseExtractor(PulseExtractorBase):
    """
    Management class to automatically combine and interface extraction methods and associated
    parameters from extractor classes defined in several modules.

    Extractor class to import from must comply to the following rules:
    1) Exclusive inheritance from PulseExtractorBase class
    2) No direct access to PulsedMeasurementLogic instance except through properties defined in
       base class (read-only access)
    3) Extraction methods must be bound instance methods
    4) Extraction methods must be named starting with "ungated_" or "gated_" accordingly
    5) Extraction methods must have as first argument "count_data"
    6) Apart from "count_data" extraction methods must have exclusively keyword arguments with
       default values of the right data type. (e.g. differentiate between 42 (int) and 42.0 (float))
    7) Make sure that no two extraction methods in any module share a keyword argument of different
       default data type.
    8) The keyword "method" must not be used in the extraction method parameters

    See BasicPulseExtractor class for an example usage.
    """

    def __init__(self, pulsedmeasurementlogic):
        """
        Import all extractor classes, collect their extraction methods and parameters, select a
        default extraction method and apply the settings stored in the logic (if any).

        @param pulsedmeasurementlogic: logic instance providing "extraction_import_path",
                                       "extraction_parameters" and the attributes exposed by
                                       PulseExtractorBase
        """
        # Init base class
        super().__init__(pulsedmeasurementlogic)

        # Dictionaries holding references to the extraction methods
        self._gated_extraction_methods = dict()
        self._ungated_extraction_methods = dict()
        # dictionary containing all possible parameters that can be used by the extraction methods
        self._parameters = dict()
        # Currently selected extraction method
        self._current_extraction_method = None

        # import path for extraction modules from default directory (logic.pulse_extraction_methods)
        path_list = [os.path.join(get_main_dir(), 'logic', 'pulsed', 'pulse_extraction_methods')]
        # import path for extraction modules from non-default directory if a path has been given
        if isinstance(pulsedmeasurementlogic.extraction_import_path, str):
            path_list.append(pulsedmeasurementlogic.extraction_import_path)

        # Import extraction modules and get a list of extractor classes
        extractor_classes = self.__import_external_extractors(paths=path_list)

        # create an instance of each class and put them in a temporary list
        extractor_instances = [cls(pulsedmeasurementlogic) for cls in extractor_classes]

        # add references to all extraction methods in each instance to a dict
        self.__populate_method_dicts(instance_list=extractor_instances)

        # populate "_parameters" dictionary from extraction method signatures
        self.__populate_parameter_dict()

        # Set default extraction method (alphabetically first one for the current gating mode).
        # If nothing could be imported, report it instead of failing with a bare IndexError.
        available_methods = sorted(self.extraction_methods)
        if available_methods:
            self._current_extraction_method = available_methods[0]
        else:
            self.log.error('No {0} extraction methods found in PulseExtractor.'
                           ''.format('gated' if self.is_gated else 'ungated'))

        # Update from parameter_dict if handed over
        stored_parameters = pulsedmeasurementlogic.extraction_parameters
        if isinstance(stored_parameters, dict):
            # Delete unused parameters. The names are collected first because the dictionary
            # must not change size while being iterated.
            unused = [p for p in stored_parameters if
                      p not in self._parameters and p != 'method']
            for param in unused:
                del stored_parameters[param]
            # Update parameter dict and current method
            self.extraction_settings = stored_parameters

    @property
    def extraction_settings(self):
        """
        This property holds all parameters needed for the currently selected extraction_method as
        well as the currently selected method name.

        If the currently selected method is not available for the current gating mode, only the
        "method" entry is returned.

        @return dict: dictionary with keys being the parameter name and values being the parameter
        """
        # Get reference to the extraction method
        method = self.extraction_methods.get(self._current_extraction_method)

        # Get keyword arguments for the currently selected method
        if method is None:
            settings_dict = dict()
        else:
            settings_dict = self._get_extraction_method_kwargs(method)

        # Attach current extraction method name
        settings_dict['method'] = self._current_extraction_method
        return settings_dict

    @extraction_settings.setter
    def extraction_settings(self, settings_dict):
        """
        Update parameters contained in self._parameters by values in settings_dict.
        Also sets the current extraction method by passing its name using key "method".
        Parameters not included in self._parameters (except "method") will be ignored.

        @param dict settings_dict: dictionary containing the parameters to set (name, value)
        """
        if not isinstance(settings_dict, dict):
            return

        # go through all key-value pairs in settings_dict and update self._parameters and
        # self._current_extraction_method accordingly. Ignore unknown parameters.
        for parameter, value in settings_dict.items():
            if parameter == 'method':
                # only accept method names that exist for the current gating mode
                if value in self.extraction_methods:
                    self._current_extraction_method = value
                else:
                    self.log.error('Extraction method "{0}" could not be found in PulseExtractor.'
                                   ''.format(value))
            elif parameter in self._parameters:
                self._parameters[parameter] = value
            else:
                self.log.warning('No extraction parameter "{0}" found in PulseExtractor.\n'
                                 'Parameter will be ignored.'.format(parameter))

    @property
    def extraction_methods(self):
        """
        Return available extraction methods depending on if the fast counter is gated or not.

        @return dict: Dictionary with keys being the method names and values being the methods.
        """
        if self.is_gated:
            return self._gated_extraction_methods
        return self._ungated_extraction_methods

    @property
    def full_settings_dict(self):
        """
        Returns the full set of parameters for all methods as well as the currently selected method
        in order to store them in a StatusVar in PulsedMeasurementLogic.

        @return dict: full set of parameters and currently selected extraction method.
        """
        settings_dict = self._parameters.copy()
        settings_dict['method'] = self._current_extraction_method
        return settings_dict

    def extract_laser_pulses(self, count_data):
        """
        Wrapper method to call the currently selected extraction method with count_data and the
        appropriate keyword arguments.

        A mismatch between the dimensionality of count_data and the "is_gated" flag is logged as
        an error; the extraction method is called nevertheless.

        @param numpy.ndarray count_data: 1D (ungated) or 2D (gated) numpy array (dtype='int64')
                                         containing the timetrace to extract laser pulses from.
        @return dict: result dictionary of the extraction method
        """
        if count_data.ndim > 1 and not self.is_gated:
            self.log.error('"is_gated" flag is set to False but the count data to extract laser '
                           'pulses from is in the format of a gated timetrace (2D numpy array).')
        elif count_data.ndim == 1 and self.is_gated:
            self.log.error('"is_gated" flag is set to True but the count data to extract laser '
                           'pulses from is in the format of an ungated timetrace (1D numpy array).')

        # Raises KeyError if the selected method does not exist for the current gating mode
        extraction_method = self.extraction_methods[self._current_extraction_method]
        kwargs = self._get_extraction_method_kwargs(extraction_method)
        return extraction_method(count_data=count_data, **kwargs)

    def _get_extraction_method_kwargs(self, method):
        """
        Get the proper values for keyword arguments other than "count_data" for <method>.
        Try to take the values from self._parameters. If the keyword is missing in the dictionary
        or the stored value has a different type than the default, take the default value from the
        method signature. Arguments without a default value are skipped.

        @param method: reference to a callable extraction method
        @return dict: A dictionary containing the argument keywords for <method> and corresponding
                      values from self._parameters.
        """
        kwargs_dict = dict()
        for name, parameter in inspect.signature(method).parameters.items():
            if name == 'count_data':
                continue

            default = parameter.default
            # Arguments without default have nothing meaningful to fall back on; do not
            # store the inspect "empty" marker as if it was a parameter value.
            if default is inspect.Parameter.empty:
                continue

            recalled = self._parameters.get(name)
            # Exact type match on purpose: e.g. 42 (int) must not replace 42.0 (float)
            if recalled is not None and type(recalled) is type(default):
                kwargs_dict[name] = recalled
            else:
                kwargs_dict[name] = default
        return kwargs_dict

    def __import_external_extractors(self, paths):
        """
        Helper method to import all modules from directories contained in paths.
        Find all classes in those modules that inherit exclusively from PulseExtractorBase class
        and return a list of them.

        @param iterable paths: iterable containing paths to import modules from
        @return list: A list of imported valid extractor classes
        """
        class_list = list()
        for path in paths:
            if not os.path.exists(path):
                self.log.error('Unable to import extraction methods from "{0}".\n'
                               'Path does not exist.'.format(path))
                continue
            # Get all python modules to import from.
            # The assumption is that in the directory pulse_extraction_methods, there are
            # *.py files, which contain only extractor classes!
            # Sorted, because os.listdir returns the entries in arbitrary order.
            module_list = sorted(name[:-3] for name in os.listdir(path) if
                                 os.path.isfile(os.path.join(path, name)) and name.endswith('.py'))

            # append import path to sys.path (only once, otherwise sys.path would grow with
            # every PulseExtractor instance created)
            if path not in sys.path:
                sys.path.append(path)

            # Go through all modules and collect each extractor class found.
            for module_name in module_list:
                # import module and reload it to pick up changes made since a previous import
                mod = importlib.import_module(module_name)
                importlib.reload(mod)
                # get all extractor class references defined in the module
                class_list.extend(
                    cls for _, cls in inspect.getmembers(mod, self.is_extractor_class))
        return class_list

    def __populate_method_dicts(self, instance_list):
        """
        Helper method to populate the dictionaries containing all references to callable extraction
        methods contained in extractor instances passed to this method.
        The prefix "gated_" or "ungated_" is stripped from the method name to form the key.

        @param list instance_list: List containing instances of extractor classes
        """
        self._ungated_extraction_methods = dict()
        self._gated_extraction_methods = dict()
        for instance in instance_list:
            for method_name, method_ref in inspect.getmembers(instance, inspect.ismethod):
                if method_name.startswith('gated_'):
                    self._gated_extraction_methods[method_name[len('gated_'):]] = method_ref
                elif method_name.startswith('ungated_'):
                    self._ungated_extraction_methods[method_name[len('ungated_'):]] = method_ref

    def __populate_parameter_dict(self):
        """
        Helper method to populate the dictionary containing all possible keyword arguments from all
        extraction methods (ungated as well as gated ones).
        """
        self._parameters = dict()
        # Ungated methods first, then gated methods.
        for method_dict in (self._ungated_extraction_methods, self._gated_extraction_methods):
            for method in method_dict.values():
                self._parameters.update(self._get_extraction_method_kwargs(method=method))

    @staticmethod
    def is_extractor_class(obj):
        """
        Helper method to check if an object is a valid extractor class, i.e. a class whose only
        direct base class is PulseExtractorBase.

        @param object obj: object to check
        @return bool: True if obj is a valid extractor class, False otherwise
        """
        if inspect.isclass(obj):
            return PulseExtractorBase in obj.__bases__ and len(obj.__bases__) == 1
        return False
325