Completed
Push — master ( 98b503...0dd2b7 )
by Rich
24:39 queued 09:33
created

requires_bo_amat()   A

Complexity

Conditions 3

Size

Total Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
c 1
b 0
f 0
dl 0
loc 11
rs 9.4285
cc 3
1
#! /usr/bin/env python
2
#
3
# Copyright (C) 2016 Rich Lewis <[email protected]>
4
# License: 3-clause BSD
5
6
"""
7
# skchem.features.descriptors.decorators
8
9
Decorators for descriptors in scikit-chem.
10
"""
11
import inspect
12
from functools import wraps
13
from collections import OrderedDict, defaultdict
14
15
import numpy as np
16
from rdkit.Chem.rdmolops import GetDistanceMatrix, GetAdjacencyMatrix
17
18
19
def requires_h_depleted(func):

    """ Decorate a function so that it always receives an h-depleted graph.

    If the molecule passed in contains no hydrogen atoms (no atom with
    atomic number 1) it is forwarded to ``func`` unchanged.  Otherwise a
    hydrogen-depleted copy is built with ``mol.remove_hs()`` the first time
    it is needed, memoized on the molecule as ``_h_depleted``, and that copy
    is passed to ``func`` instead.
    """

    @wraps(func)
    def inner(mol, *args, **kwargs):
        target = mol
        if (mol.atoms.atomic_number == 1).sum() != 0:
            # explicit hydrogens present: build the depleted graph once and
            # reuse it on subsequent calls
            if not hasattr(mol, '_h_depleted'):
                mol._h_depleted = mol.remove_hs()
            target = mol._h_depleted
        return func(target, *args, **kwargs)

    return inner
42
43
44
def requires_h_filled(func):

    """ Decorate a function so that it always receives an h-filled graph.

    When ``mol.atoms.n_total_hs`` sums to zero the molecule is treated as
    already h-filled and handed to ``func`` unchanged.  Otherwise
    ``mol.add_hs()`` is called the first time it is needed, its result is
    memoized on the molecule as ``_h_enriched``, and that copy is passed to
    ``func`` instead.

    Note:
        This decorator should be used first if in combination with dMat etc.
    """

    @wraps(func)
    def inner(mol, *args, **kwargs):
        target = mol
        if mol.atoms.n_total_hs.sum() != 0:
            # some atoms still report hydrogens (presumably implicit ones):
            # build the enriched graph once and reuse it afterwards
            if not hasattr(mol, '_h_enriched'):
                mol._h_enriched = mol.add_hs()
            target = mol._h_enriched
        return func(target, *args, **kwargs)

    return inner
69
70
71
class Cache(object):

    """ Function cache.

    Results are stored on each molecule in a ``cache`` attribute (created on
    demand), keyed first by function name and then by the full set of
    keyword arguments (defaults plus overrides) used for the call.
    """

    def __init__(self):
        # function name -> (cached wrapper, names of its keyword parameters)
        self.cached = {}

    @staticmethod
    def _keyword_defaults(func):

        """ Return an OrderedDict of ``func``'s parameters that have defaults. """

        return OrderedDict(
            (name, param.default)
            for name, param in inspect.signature(func).parameters.items()
            if param.default is not inspect.Parameter.empty)

    def __call__(self, func):

        """ Create a decorator to identify a function as returning a cached
        value.

        This can be used for objects that are nontrivial to generate, but
        are used by many functions.

        The decorated function accepts an extra keyword-only ``force``
        argument; ``force=True`` recomputes and overwrites the cached value.

        Note:
            Only keyword arguments form part of the cache key; positional
            arguments after ``mol`` are passed through but not keyed.
        """

        name = func.__name__

        # keyword parameters of the function and their default values
        defaults = self._keyword_defaults(func)

        @wraps(func)
        def inner(mol, *args, force=False, **kwargs):

            self.setup_cache(mol)

            # build the keywords for *this* call on a fresh copy, so that one
            # call's overrides never leak into the defaults of later calls.
            call_kwds = OrderedDict(defaults)
            call_kwds.update(kwargs)
            key = tuple(sorted(call_kwds.items()))

            # call function if it hasn't already been called with these
            # arguments, or if told to.
            if force or key not in mol.cache.get(name, {}):
                res = func(mol, *args, **kwargs)
                mol.cache.setdefault(name, {})[key] = res

            # return the cached value
            return mol.cache[name][key]

        self.cached[name] = inner, tuple(defaults.keys())

        return inner

    def inject(self, *args_to_inject):

        """ Create a decorator that will inject cached values as arguments.

        Args:
            *args_to_inject:
                Functions previously decorated with this cache.  Their
                (cached) results for the molecule are passed to the decorated
                function as leading positional arguments, in the given order.

        Keyword arguments of the decorated function (its defaults, updated
        with any given at call time) whose names match keyword parameters of
        an injected function are forwarded to that function.
        """

        def outer(func):

            # extract the defaults for the func
            defaults = self._keyword_defaults(func)

            @wraps(func)
            def inner(mol, *args, **kwargs):

                # per-call copy: never mutate the shared defaults
                call_kwds = OrderedDict(defaults)
                call_kwds.update(kwargs)

                self.setup_cache(mol)

                injected = ()
                for dep in args_to_inject:
                    dep_func, params = self.cached[dep.__name__]
                    dep_kwargs = {param: call_kwds[param] for param in params
                                  if param in call_kwds}
                    # the cached wrapper performs the lookup (or computation
                    # and caching) itself, so a cached ``None`` is reused too.
                    injected += (dep_func(mol, **dep_kwargs),)

                # put injected args at start of arg list
                return func(mol, *(injected + args), **kwargs)

            return inner

        return outer

    @staticmethod
    def setup_cache(mol):

        """ Set up a cache on e.g. a `Mol`, unless one already exists. """

        if not hasattr(mol, 'cache'):
            mol.cache = defaultdict(dict)

    @staticmethod
    def teardown_cache(mol):

        """ Tear down a cache on e.g. a `Mol`, if it has one. """

        if hasattr(mol, 'cache'):
            del mol.cache
195
# Shared, module-level instance of ``Cache``.
cache = Cache()
196