Completed
Pull Request — master (#384)
by
unknown
02:57
created

SamplingBase   A

Complexity

Total Complexity 2

Size/Duplication

Total Lines 13
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
c 0
b 0
f 0
dl 0
loc 13
rs 10
wmc 2

1 Method

Rating   Name   Duplication   Size   Complexity  
A get_dict_representation() 0 7 2
1
# -*- coding: utf-8 -*-
2
3
"""
4
This file contains the base class of all Qudi sampling functions and the loader that imports them.
5
6
Qudi is free software: you can redistribute it and/or modify
7
it under the terms of the GNU General Public License as published by
8
the Free Software Foundation, either version 3 of the License, or
9
(at your option) any later version.
10
11
Qudi is distributed in the hope that it will be useful,
12
but WITHOUT ANY WARRANTY; without even the implied warranty of
13
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14
GNU General Public License for more details.
15
16
You should have received a copy of the GNU General Public License
17
along with Qudi. If not, see <http://www.gnu.org/licenses/>.
18
19
Copyright (c) the Qudi Developers. See the COPYRIGHT.txt file at the
20
top-level directory of this distribution and at <https://github.com/Ulm-IQO/qudi/>
21
"""
22
23
import os
24
import importlib
25
import sys
26
import inspect
27
import copy
28
from collections import OrderedDict
29
30
31
class SamplingBase:
    """
    Base class for all sampling functions.

    Iterating ``params`` yields the parameter names of a sampling function; the
    current value of each parameter is read back from the instance attribute of
    the same name (presumably subclasses override ``params`` — verify in the
    sampling function modules).
    """
    params = OrderedDict()

    def get_dict_representation(self):
        """
        Describe this sampling function instance as plain dictionaries.

        @return dict: {'name': <class name>, 'params': {<param name>: <attribute value>, ...}}
        """
        param_values = {param_name: getattr(self, param_name) for param_name in self.params}
        return {'name': self.__class__.__name__, 'params': param_values}
44
45
46
class SamplingFunctions:
    """
    Registry of all sampling function classes found in a set of directories.

    ``import_sampling_functions`` scans directories for Python modules. For every
    sampling function class found (a class whose only direct base is
    ``SamplingBase``), it attaches a same-named class attribute that calls the
    class. It also records a deep copy of that class's ``params`` in ``parameters``.
    """
    # Maps sampling function class name -> deep copy of that class's ``params``.
    parameters = dict()

    @classmethod
    def import_sampling_functions(cls, path_list):
        """
        (Re-)import all sampling function classes from the given directories.

        Entries of path_list that are not existing directories are skipped.
        Each directory is appended to sys.path (if not already present) so its
        modules can be imported by name. Modules already present in sys.modules
        are reloaded so that changes on disk take effect. Sampling functions
        registered by a previous call that are no longer found are removed from
        the class.

        @param iterable path_list: directory paths to search for sampling function modules
        """
        param_dict = dict()
        # Source files may have been created since the interpreter started; make
        # sure the import system's finders notice them.
        importlib.invalidate_caches()
        for path in path_list:
            # os.listdir() below needs a directory; plain files and missing paths are skipped.
            if not os.path.isdir(path):
                continue
            # Get all python modules to import from. Sorted, so that the import order
            # (and thus which class wins a name clash) does not depend on os.listdir().
            module_list = sorted(
                name[:-3] for name in os.listdir(path)
                if name.endswith('.py') and os.path.isfile(os.path.join(path, name)))

            # append import path to sys.path
            if path not in sys.path:
                sys.path.append(path)

            # Go through all modules and get all sampling function classes.
            for module_name in module_list:
                mod = sys.modules.get(module_name)
                if mod is None:
                    # First import: executing the module once is enough.
                    mod = importlib.import_module(module_name)
                else:
                    # Delete all remaining references to sampling functions.
                    # reload() re-executes the module in its existing namespace, so this is
                    # necessary for a removed sampling function class to really disappear.
                    for attr in cls.parameters:
                        if hasattr(mod, attr):
                            delattr(mod, attr)
                    importlib.reload(mod)
                # get all sampling function class references defined in the module
                for name, ref in inspect.getmembers(mod, cls.is_sampling_function_class):
                    setattr(cls, name, cls.__get_sf_method(ref))
                    param_dict[name] = copy.deepcopy(ref.params)

        # Remove sampling functions registered earlier that were not found anymore.
        for func in cls.parameters:
            if func not in param_dict:
                delattr(cls, func)

        cls.parameters = param_dict

    @staticmethod
    def __get_sf_method(sf_ref):
        """
        Wrap a sampling function class in a plain function.

        @param class sf_ref: sampling function class to wrap
        @return function: calls sf_ref with all positional and keyword arguments it receives
        """
        return lambda *args, **kwargs: sf_ref(*args, **kwargs)

    @staticmethod
    def is_sampling_function_class(obj):
        """
        Helper method to check if an object is a valid sampling function class.

        Only classes whose single direct base class is SamplingBase qualify;
        indirect subclasses and classes with additional bases are rejected.

        @param object obj: object to check
        @return bool: True if obj is a valid sampling function class, False otherwise
        """
        if inspect.isclass(obj):
            return SamplingBase in obj.__bases__ and len(obj.__bases__) == 1
        return False
104