Completed
Push — master ( 6ddd48...5bdf16 )
by Rich
16:04 queued 02:16
created

requires_h_depleted()   B

Complexity

Conditions 4

Size

Total Lines 26

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
c 1
b 0
f 0
dl 0
loc 26
rs 8.5806
cc 4
1
#! /usr/bin/env python
2
#
3
# Copyright (C) 2016 Rich Lewis <[email protected]>
4
# License: 3-clause BSD
5
6
"""
7
# skchem.features.descriptors.decorators
8
9
Decorators for descriptors in scikit-chem.
10
"""
11
import inspect
12
from functools import wraps
13
from collections import OrderedDict, defaultdict
14
15
import numpy as np
16
from rdkit.Chem.rdmolops import GetDistanceMatrix, GetAdjacencyMatrix
17
18
19
def _add_cached_prop(func, prop):
20
21
    """ Annotate a function with the cached props. """
22
23
    if not hasattr(func, 'caches'):
24
        func.caches = []
25
    func.caches.append(prop)
26
27
28
def requires_h_depleted(func):

    """ Decorate a function that requires an h-depleted graph.

    If any atom of the molecule passed in has atomic number 1, a copy with
    the hydrogens removed (``mol.remove_hs()``) is memoized on the molecule
    as ``_h_depleted`` and handed to `func` in its place.  Otherwise the
    molecule is passed through unchanged.

    Note:
        This decorator should be used first if in combination with dMat etc.
    """

    @wraps(func)
    def inner(mol, *args, **kwargs):
        hydrogen_count = (mol.atoms.atomic_number == 1).sum()

        if hydrogen_count != 0:
            # build the depleted copy once and reuse it on later calls
            if not hasattr(mol, '_h_depleted'):
                mol._h_depleted = mol.remove_hs()
            target = mol._h_depleted
        else:
            # no hydrogen atoms in the graph: already h-depleted
            target = mol

        return func(target, *args, **kwargs)

    _add_cached_prop(inner, '_h_depleted')
    return inner
54
55
56
def requires_h_filled(func):

    """ Decorate a function that requires a h-filled graph.

    When ``mol.atoms.n_total_hs`` sums to a non-zero value, a copy with
    hydrogens added (``mol.add_hs()``) is memoized on the molecule as
    ``_h_enriched`` and handed to `func` in its place.  Otherwise the
    molecule is passed through unchanged.

    Note:
        This decorator should be used first if in combination with dMat etc.
    """

    @wraps(func)
    def inner(mol, *args, **kwargs):
        # a zero total is taken to mean the graph is already h-filled
        if mol.atoms.n_total_hs.sum() != 0:
            # build the enriched copy once and reuse it on later calls
            if not hasattr(mol, '_h_enriched'):
                mol._h_enriched = mol.add_hs()
            mol = mol._h_enriched
        return func(mol, *args, **kwargs)

    _add_cached_prop(inner, '_h_enriched')
    return inner
83
84
85
class Cache(object):

    """ Function cache.

    Decorating a function with an instance (``@cache``) memoizes its return
    value on the molecule passed as its first argument, keyed by the
    function name and the keyword arguments in effect.  `inject` builds
    decorators that feed those memoized values into other functions.
    """

    def __init__(self):
        # function name -> (cached wrapper, names of its defaulted params)
        self.cached = {}

    @staticmethod
    def _defaults(func):

        """ Return an OrderedDict of `func`'s parameters that have defaults,
        mapped to those default values. """

        # `is not` rather than `!=`: defaults such as numpy arrays do not
        # produce a single truth value from `!=`.
        return OrderedDict((k, v.default) for k, v in
                           inspect.signature(func).parameters.items()
                           if v.default is not inspect.Parameter.empty)

    def __call__(self, func):

        """ Create a decorator to identify a function as returning a cached
        value.

        This can be used for objects that are nontrivial to generate, but
        are used by many functions.

        The wrapper accepts an extra keyword-only argument `force`.  When it
        is true, the value is recomputed and the cache entry overwritten.
        Only the keyword arguments (merged over the function's defaults)
        form the cache key.  Extra positional arguments are passed to the
        function but are not part of the key.
        """

        name = func.__name__

        # get the key word arguments and the default values of the function
        defaults = self._defaults(func)

        @wraps(func)
        def inner(mol, *args, force=False, **kwargs):

            self.setup_cache(mol)

            # full set of keywords to use, including defaults.  A fresh dict
            # is built per call so keywords passed to one call do not leak
            # into the defaults seen by later calls.
            kwds = OrderedDict(defaults)
            kwds.update(kwargs)
            kw_to_save = tuple(sorted(kwds.items()))

            # call function if it hasn't already been called on this mol
            # with these keywords, or if told to.
            if force or kw_to_save not in mol.cache.get(name, {}):
                res = func(mol, *args, **kwargs)

                # cache the value with the args used.
                mol.cache[name][kw_to_save] = res

            # return the cached value
            return mol.cache[name][kw_to_save]

        self.cached[name] = inner, tuple(defaults.keys())

        return inner

    def inject(self, *args_to_inject):

        """ Create a decorator that will inject cached values as arguments.

        Args:
            args_to_inject (callables):
                Functions previously decorated with this cache.  Their
                (memoized) results are passed, in order, as the leading
                positional arguments after the molecule.

        Keyword arguments of the decorated function (its defaults, overridden
        by those given at call time) that are also defaulted parameters of an
        injected function are forwarded to that function.
        """

        def outer(func):

            # extract the defaults for the func
            defaults = self._defaults(func)

            @wraps(func)
            def inner(mol, *args, **kwargs):

                # per-call copy of the defaults, augmented with the keywords
                # given, so that calls do not contaminate each other.
                kwds = dict(defaults)
                kwds.update(kwargs)

                self.setup_cache(mol)

                injected = []
                for arg in args_to_inject:
                    # lookup function to inject
                    inj_func, params = self.cached[arg.__name__]

                    # get the kwargs required
                    inj_kwargs = {param: kwds[param] for param in params
                                  if param in kwds}

                    # the cached wrapper returns the memoized value,
                    # computing and storing it first if needed.
                    injected.append(inj_func(mol, **inj_kwargs))

                # put injected args at start of arg list
                return func(mol, *(tuple(injected) + args), **kwargs)

            return inner

        return outer

    @staticmethod
    def setup_cache(mol):

        """ Set up a cache on e.g. a `Mol`. """

        if not hasattr(mol, 'cache'):
            mol.cache = defaultdict(dict)

    @staticmethod
    def teardown_cache(mol):

        """ Tear down a cache on e.g. a `Mol`. """

        if hasattr(mol, 'cache'):
            del mol.cache


# shared module-level Cache instance
cache = Cache()
210
211
212
def requires_dmat(func):

    """ Decorate a function that requires a distance matrix.

    On first use the matrix from RDKit's `GetDistanceMatrix` is memoized on
    the molecule as ``_dMat``, with every entry greater than 100 replaced
    by NaN.
    """

    @wraps(func)
    def inner(mol, *args, **kwargs):
        if not hasattr(mol, '_dMat'):
            distances = GetDistanceMatrix(mol)
            # NOTE(review): presumably intended to mask RDKit's large
            # "no path" value for atoms in disconnected fragments; genuine
            # paths longer than 100 bonds are masked too — confirm threshold.
            too_far = distances > 100
            distances[too_far] = np.nan
            mol._dMat = distances
        return func(mol, *args, **kwargs)

    _add_cached_prop(inner, '_dMat')
    return inner
225
226
227
def requires_amat(func):

    """ Decorate a function that requires an adjacency matrix.

    On first use the matrix from RDKit's `GetAdjacencyMatrix` is memoized on
    the molecule as ``_adjMat``.  That attribute name is what gets recorded
    in the wrapper's ``caches`` list, matching the other ``requires_*``
    decorators, which register the attribute they actually set.
    """

    @wraps(func)
    def inner(mol, *args, **kwargs):
        if not hasattr(mol, '_adjMat'):
            mol._adjMat = GetAdjacencyMatrix(mol)
        return func(mol, *args, **kwargs)

    # previously registered as 'amat', which is not the attribute set above
    _add_cached_prop(inner, '_adjMat')

    return inner
238
239
240
def requires_bo_amat(func):

    """ Decorate a function that requires a bond order adjacency matrix.

    On first use the matrix from RDKit's `GetAdjacencyMatrix` with
    ``useBO`` enabled is memoized on the molecule as ``_bo_amat``.
    """

    @wraps(func)
    def inner(mol, *args, **kwargs):
        if not hasattr(mol, '_bo_amat'):
            mol._bo_amat = GetAdjacencyMatrix(mol, useBO=True)
        return func(mol, *args, **kwargs)

    _add_cached_prop(inner, '_bo_amat')

    return inner
251
252
253
def requires_degrees(func):

    """ Decorate a function that requires a degree vector.

    On first use, a numpy array holding ``atom.degree`` for each atom in
    ``mol.atoms`` is memoized on the molecule as ``_degrees``.
    """

    @wraps(func)
    def inner(mol, *args, **kwargs):
        if not hasattr(mol, '_degrees'):
            per_atom = [a.degree for a in mol.atoms]
            mol._degrees = np.array(per_atom)
        return func(mol, *args, **kwargs)

    _add_cached_prop(inner, '_degrees')
    return inner