Completed
Pull Request — master (#384)
by
unknown
01:25
created

SamplingBase   A

Complexity

Total Complexity 12

Size/Duplication

Total Lines 44
Duplicated Lines 0 %

Importance

Changes 4
Bugs 0 Features 0
Metric Value
c 4
b 0
f 0
dl 0
loc 44
rs 10
wmc 12

4 Methods

Rating   Name   Duplication   Size   Complexity  
A __repr__() 0 8 3
A __str__() 0 8 3
A __eq__() 0 12 4
A get_dict_representation() 0 7 2
1
# -*- coding: utf-8 -*-
2
3
"""
4
This file contains the base class and the loader for all available Qudi sampling functions.
5
6
Qudi is free software: you can redistribute it and/or modify
7
it under the terms of the GNU General Public License as published by
8
the Free Software Foundation, either version 3 of the License, or
9
(at your option) any later version.
10
11
Qudi is distributed in the hope that it will be useful,
12
but WITHOUT ANY WARRANTY; without even the implied warranty of
13
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14
GNU General Public License for more details.
15
16
You should have received a copy of the GNU General Public License
17
along with Qudi. If not, see <http://www.gnu.org/licenses/>.
18
19
Copyright (c) the Qudi Developers. See the COPYRIGHT.txt file at the
20
top-level directory of this distribution and at <https://github.com/Ulm-IQO/qudi/>
21
"""
22
23
import os
24
import importlib
25
import sys
26
import inspect
27
import copy
28
from collections import OrderedDict
29
30
31
class SamplingBase:
    """
    Base class for all sampling functions.

    Subclasses declare their parameters in the class attribute ``params`` and store each
    parameter value as an instance attribute of the same name. The methods of this class read
    those attributes via ``getattr`` to build string/dict representations and to compare
    instances.

    Note: because ``__eq__`` is defined without ``__hash__``, instances are unhashable.
    """
    # Maps parameter name -> parameter definition dict. Each definition dict must provide at
    # least the key 'type' (read by __repr__). The base class default is empty; subclasses are
    # expected to override it.
    params = OrderedDict()

    def __repr__(self):
        """
        Return a constructor-like representation, e.g. ``Name(a=1, b='x')``.

        Values of parameters whose definition has ``'type'`` set to ``str`` are wrapped in single
        quotes; all other values are inserted via ``str.format``.

        @return str: representation of this sampling function
        """
        kwargs = []
        for param, def_dict in self.params.items():
            value = getattr(self, param)
            if def_dict['type'] is str:
                kwargs.append('{0}=\'{1}\''.format(param, value))
            else:
                kwargs.append('{0}={1}'.format(param, value))
        return '{0}({1})'.format(type(self).__name__, ', '.join(kwargs))

    def __str__(self):
        """
        Return a human-readable multi-line description listing one ``name=value`` per parameter.

        @return str: description; ends in ' None' if the sampling function has no parameters
        """
        return_str = 'Sampling Function: "{0}"\nParameters:'.format(type(self).__name__)
        if not self.params:
            return return_str + ' None'
        kwargs = ('='.join((param, str(getattr(self, param)))) for param in self.params)
        return return_str + '\n\t' + '\n\t'.join(kwargs)

    def __eq__(self, other):
        """
        Check whether other is a sampling function with the same class name and parameter values.

        Classes are compared by name, not identity. Presumably this is so that instances created
        before and after a module reload still compare equal (TODO confirm intent).
        Parameter values are compared with ``==`` rather than via ``hash()``. Distinct values
        with colliding hashes (e.g. -1 and -2 in CPython) are therefore not reported as equal,
        and unhashable values such as lists are supported.

        @param object other: object to compare with
        @return bool: True if both describe the same sampling function, False otherwise
        """
        if not isinstance(other, SamplingBase):
            return False
        if type(self).__name__ != type(other).__name__:
            return False
        self_values = tuple(getattr(self, param) for param in self.params)
        other_values = tuple(getattr(other, param) for param in other.params)
        return self_values == other_values

    def get_dict_representation(self):
        """
        Return a plain dict describing this sampling function.

        @return dict: {'name': <class name>, 'params': {<param name>: <value>, ...}}
        """
        return {'name': type(self).__name__,
                'params': {param: getattr(self, param) for param in self.params}}
75
76
77
class SamplingFunctions:
    """
    Registry of all available sampling function classes.

    import_sampling_functions() scans directories for Python modules and imports (or reloads)
    them. For every sampling function class found, it adds a class attribute of the same name
    that constructs an instance of that class. The parameter definitions of all registered
    sampling functions are kept in ``parameters``.
    """
    # Maps sampling function class name -> deep copy of that class' ``params``.
    parameters = dict()

    @classmethod
    def import_sampling_functions(cls, path_list):
        """
        (Re-)import all sampling function classes from the Python modules in the given directories.

        Every ``*.py`` file directly inside each directory is imported as a top-level module.
        The directory is appended to ``sys.path`` if not already present. Modules that are
        already loaded are reloaded so that changes on disk take effect. Sampling functions
        registered by a previous call that are no longer found are removed from this class.

        @param iterable path_list: directory paths to search; paths that do not exist or are not
                                   directories are skipped
        """
        param_dict = dict()
        for path in path_list:
            # Skip missing paths and plain files (os.listdir would raise on a file).
            if not os.path.isdir(path):
                continue
            # Get all python modules to import from. The list is sorted so that, if several
            # modules define a sampling function with the same name, the one registered last is
            # deterministic.
            module_list = sorted(
                name[:-3] for name in os.listdir(path)
                if name.endswith('.py') and os.path.isfile(os.path.join(path, name)))

            # append import path to sys.path
            if path not in sys.path:
                sys.path.append(path)

            # Go through all modules and get all sampling function classes.
            for module_name in module_list:
                mod = sys.modules.get(module_name)
                if mod is None:
                    # First import: the module is executed once and has no stale names.
                    mod = importlib.import_module(module_name)
                else:
                    # Delete all remaining references to sampling functions before reloading.
                    # reload() keeps the old module namespace, so a sampling function class that
                    # was removed from the source file would otherwise survive the reload.
                    for attr in cls.parameters:
                        if hasattr(mod, attr):
                            delattr(mod, attr)
                    mod = importlib.reload(mod)
                # get all sampling function class references defined in the module
                for name, ref in inspect.getmembers(mod, cls.is_sampling_function_class):
                    setattr(cls, name, cls.__get_sf_method(ref))
                    param_dict[name] = copy.deepcopy(ref.params)

        # Remove sampling functions registered by a previous call that were not found this time.
        for func in cls.parameters:
            if func not in param_dict:
                delattr(cls, func)

        cls.parameters = param_dict

    @staticmethod
    def __get_sf_method(sf_ref):
        """
        Wrap a sampling function class in a plain function that forwards all arguments to it.

        NOTE(review): the wrapper is stored as a plain function attribute on the class. Calling
        it on the class works as a factory; calling it through an instance would pass the
        instance as the first positional argument.

        @param class sf_ref: sampling function class to wrap
        @return function: callable returning sf_ref(*args, **kwargs)
        """
        return lambda *args, **kwargs: sf_ref(*args, **kwargs)

    @staticmethod
    def is_sampling_function_class(obj):
        """
        Helper method to check if an object is a valid sampling function class.

        A valid sampling function class is a class whose only direct base class is SamplingBase.

        @param object obj: object to check
        @return bool: True if obj is a valid sampling function class, False otherwise
        """
        return inspect.isclass(obj) and obj.__bases__ == (SamplingBase,)
135