Completed
Push — master ( 5b00a3...5fddd3 )
by Rich
14:42
created

requires_h_depleted()   B

Complexity

Conditions 4

Size

Total Lines 26

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 1 Features 0
Metric Value
c 1
b 1
f 0
dl 0
loc 26
rs 8.5806
cc 4
1
#! /usr/bin/env python
2
#
3
# Copyright (C) 2016 Rich Lewis <[email protected]>
4
# License: 3-clause BSD
5
6
"""
7
# skchem.features.descriptors.decorators
8
9
Decorators for descriptors in scikit-chem.
10
"""
11
12
from functools import wraps
13
14
import numpy as np
15
from rdkit.Chem.rdmolops import GetDistanceMatrix, GetAdjacencyMatrix
16
17
18
def _add_cached_prop(func, prop):
19
20
    """ Annotate a function with the cached props. """
21
22
    if not hasattr(func, 'caches'):
23
        func.caches = []
24
    func.caches.append(prop)
25
26
27
def requires_h_depleted(func):
    """ Decorate a function so that it receives an h-depleted molecule.

    If the molecule passed in contains no atoms with atomic number 1, it is
    handed to the function unchanged.  Otherwise the result of
    ``mol.remove_hs()`` is memoized on the molecule as ``_h_depleted`` and
    that memoized molecule is passed to the function instead.

    Note:
        This decorator should be used first if in combination with dMat etc.
    """

    @wraps(func)
    def inner(mol, *args, **kwargs):
        target = mol
        hydrogen_count = (mol.atoms.atomic_number == 1).sum()

        # only swap in the depleted molecule when hydrogens are present
        if not (hydrogen_count == 0):
            if not hasattr(mol, '_h_depleted'):
                mol._h_depleted = mol.remove_hs()
            target = mol._h_depleted

        return func(target, *args, **kwargs)

    _add_cached_prop(inner, '_h_depleted')

    return inner
53
54
55
def requires_h_filled(func):
    """ Decorate a function so that it receives an h-filled molecule.

    If ``mol.atoms.n_total_hs`` sums to zero, the molecule is treated as
    already h-filled and is handed to the function unchanged.  Otherwise
    the result of ``mol.add_hs()`` is memoized on the molecule as
    ``_h_enriched`` and that memoized molecule is passed to the function
    instead.

    Note:
        This decorator should be used first if in combination with dMat etc.
    """

    @wraps(func)
    def inner(mol, *args, **kwargs):
        target = mol
        already_filled = mol.atoms.n_total_hs.sum() == 0

        # memoize the enriched molecule the first time it is needed
        if not already_filled:
            if not hasattr(mol, '_h_enriched'):
                mol._h_enriched = mol.add_hs()
            target = mol._h_enriched

        return func(target, *args, **kwargs)

    _add_cached_prop(inner, '_h_enriched')

    return inner
82
83
84
class Cache(object):

    """ Function cache.

    Instances are used as decorators.  Decorating a function registers it
    under its ``__name__`` in ``self.cached`` and makes it store its result
    in a per-molecule dictionary, ``mol.cache``, under that same name.

    Note:
        Both the registry and the per-molecule cache are keyed only by the
        function's ``__name__``, so two cached functions sharing a name
        will overwrite one another.
    """

    def __init__(self):
        # maps function name -> caching wrapper created by __call__
        self.cached = {}

    def __call__(self, func):

        """ Decorate a function so that its result is cached on the molecule.

        This can be used for objects that are nontrivial to generate, but
        are used by many functions.

        Args:
            func (callable):
                The function to cache.  It is called as
                ``func(mol, *args, **kwargs)`` and its result is stored in
                ``mol.cache`` under ``func.__name__``.

        Returns:
            callable:
                The caching wrapper.  It accepts an extra keyword-only
                argument ``force`` (default ``False``); when true, the
                function is called again and the cached value replaced.
        """

        name = func.__name__

        @wraps(func)
        def inner(mol, *args, force=False, **kwargs):

            self.setup_cache(mol)

            # call function if it hasn't already been called.
            if force or name not in mol.cache:
                mol.cache[name] = func(mol, *args, **kwargs)

            # return the cached value
            return mol.cache[name]

        self.cached[name] = inner
        return inner

    def retrieve(self, *args_to_inject):

        """ Create a decorator that will inject cached values as arguments.

        The injected values are passed positionally, directly after the
        molecule and before any other positional arguments supplied by the
        caller.

        Args:
            args_to_inject (callable):
                The cached requirements of the decorated function.  Each
                must have a ``__name__`` under which a function has been
                registered with this cache.

        Returns:
            callable:
                A decorator for functions of the form
                ``func(mol, *injected, *args, **kwargs)``.

        Raises:
            KeyError:
                When the decorated function is called and a requirement's
                name has not been registered with this cache.
        """

        def outer(func):

            @wraps(func)
            def inner(mol, *args, **kwargs):

                self.setup_cache(mol)

                # The registered wrapper already returns the cached value
                # when present and computes (and stores) it otherwise, so
                # calling it is all that is required here.
                injected = tuple(self.cached[arg.__name__](mol)
                                 for arg in args_to_inject)

                return func(mol, *(injected + args), **kwargs)
            return inner
        return outer

    @staticmethod
    def setup_cache(mol):
        """ Give the molecule an empty ``cache`` dict if it has none. """
        if not hasattr(mol, 'cache'):
            mol.cache = {}

    @staticmethod
    def teardown_cache(mol):
        """ Delete the molecule's ``cache`` attribute if it has one. """
        if hasattr(mol, 'cache'):
            del mol.cache


cache = Cache()
159
160
161
def requires_dmat(func):
    """ Decorate a function that requires a distance matrix.

    Before the function is called, the matrix returned by
    ``GetDistanceMatrix(mol)`` is memoized on the molecule as ``_dMat``,
    with every entry greater than 100 replaced by NaN.
    """
    @wraps(func)
    def inner(mol, *args, **kwargs):
        if hasattr(mol, '_dMat'):
            return func(mol, *args, **kwargs)

        distances = GetDistanceMatrix(mol)
        # NOTE(review): values above 100 are presumably RDKit's placeholder
        # for unreachable atom pairs - confirm against the RDKit docs.
        distances[distances > 100] = np.nan
        mol._dMat = distances
        return func(mol, *args, **kwargs)

    _add_cached_prop(inner, '_dMat')

    return inner
174
175
176
def requires_amat(func):
    """ Decorate a function that requires an adjacency matrix.

    Before the function is called, the matrix returned by
    ``GetAdjacencyMatrix(mol)`` is memoized on the molecule as ``_adjMat``.
    """
    @wraps(func)
    def inner(mol, *args, **kwargs):
        if not hasattr(mol, '_adjMat'):
            mol._adjMat = GetAdjacencyMatrix(mol)
        return func(mol, *args, **kwargs)

    # Register the attribute name that is actually set on the molecule, as
    # the other decorators in this module do (this was previously 'amat',
    # which does not correspond to any attribute stored here).
    _add_cached_prop(inner, '_adjMat')

    return inner
187
188
189
def requires_bo_amat(func):
    """ Decorate a function that requires a bond order adjacency matrix.

    Before the function is called, the matrix returned by
    ``GetAdjacencyMatrix(mol, useBO=1)`` is memoized on the molecule as
    ``_bo_amat``.
    """
    @wraps(func)
    def inner(mol, *args, **kwargs):
        if hasattr(mol, '_bo_amat'):
            return func(mol, *args, **kwargs)

        mol._bo_amat = GetAdjacencyMatrix(mol, useBO=1)
        return func(mol, *args, **kwargs)

    _add_cached_prop(inner, '_bo_amat')

    return inner
200
201
202
def requires_degrees(func):
    """ Decorate a function that requires a degree vector.

    Before the function is called, an array holding the ``degree`` of each
    atom in ``mol.atoms`` is memoized on the molecule as ``_degrees``.
    """
    @wraps(func)
    def inner(mol, *args, **kwargs):
        if hasattr(mol, '_degrees'):
            return func(mol, *args, **kwargs)

        per_atom = []
        for atom in mol.atoms:
            per_atom.append(atom.degree)
        mol._degrees = np.array(per_atom)
        return func(mol, *args, **kwargs)

    _add_cached_prop(inner, '_degrees')

    return inner