injectify.injectors.TailInjector.inject()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 2
dl 0
loc 4
rs 10
c 0
b 0
f 0
1
"""This module contains the model objects that power Injectify."""
2
3
import ast
4
import inspect
5
from abc import ABC, abstractmethod
6
from collections import defaultdict
7
8
import astunparse
9
10
from .inspect_mate import extract_wrapped, getsource
11
from .structures import listify
12
from .utils import (
13
    parse_object,
14
    tryattrs,
15
    get_defining_class,
16
    caninject,
17
)
18
19
20
def _count_visit(f):
    """Decorate a ``visit_*`` method so it also receives its call ordinal.

    The decorated method is called as ``f(self, node, count)`` where
    ``count`` is how many times a method with the same ``__name__`` has
    already been invoked on ``self``. The first call therefore receives
    ``0``. Counts are kept per instance in ``self._visit_counter`` and are
    never reset, so they keep growing if one instance visits several trees.

    Args:
        f: Method taking ``(self, node, count)``.

    Returns:
        A ``(self, node)`` wrapper that preserves ``f``'s name and docstring.
    """
    from functools import wraps

    @wraps(f)
    def wrapper(self, node):
        # Lazily create the per-instance counter on first use.
        if not hasattr(self, '_visit_counter'):
            self._visit_counter = defaultdict(int)

        # Claim the current ordinal and advance the counter *before* calling
        # ``f``, so a nested visit of the same kind gets the next ordinal.
        count = self._visit_counter[f.__name__]
        self._visit_counter[f.__name__] = count + 1
        return f(self, node, count)

    return wrapper
35
36
37
class BaseInjector(ABC, ast.NodeTransformer):
    """An abstract class that identifies an injection point.

    Subclasses choose *where* the handler's statements go by overriding the
    ``visit_*`` hooks inherited from :class:`ast.NodeTransformer` and by
    implementing :meth:`inject`.

    Args:
        save_state: Whether or not the target object should allow multiple
            injections.
    """

    def __init__(self, save_state=True):
        self.save_state = save_state

    def prepare(self, target, handler):
        """Prepares the injector with the given parameters."""
        self.prepare_target(target)
        self.prepare_handler(handler)

    def prepare_target(self, target):
        """Prepares the given target object.

        Raises:
            TypeError: If ``caninject(target)`` is truthy.
        """
        # NOTE(review): this raises when ``caninject`` returns a truthy value;
        # confirm the helper's polarity against its definition in ``utils``.
        if caninject(target):
            raise TypeError('cannot inject to type {!r}'.format(type(target)))

        # Fall back to the object itself when ``extract_wrapped`` yields a
        # falsy value (presumably: nothing is wrapped).
        wrapped = extract_wrapped(target)
        self.target = wrapped or target

    def prepare_handler(self, handler):
        """Prepares the given handler function.

        Stores ``self.handler`` as the statement list of the first top-level
        node parsed from ``handler``, i.e. the handler's body.
        """
        node = parse_object(handler)
        self.handler = node.body[0].body

    def visit_target(self):
        """Visit the AST node of the target object."""
        return self.visit(parse_object(self.target))

    def is_target_module(self):
        """Check whether the target object is a module."""
        return inspect.ismodule(self.target)

    def compile(self, tree):
        """Recompile the target object with the handler.

        Args:
            tree: A transformed AST (e.g. the result of :meth:`visit_target`)
                whose top-level body holds a definition named like the target.
        """

        def inject_code(f):
            # Record source on the object so it can be injected into more
            # than once: the on-disk source file (what ``inspect.findsource``
            # reads) does not reflect earlier injections. With ``save_state``
            # off, the original source is recorded instead.
            f.__inject_code__ = code if self.save_state else target_src

        target_name = self.target.__name__
        target_file = inspect.getfile(self.target)
        target_src = getsource(self.target)

        # Find the ast node with the same name as our target object and get
        # its source code.
        node = next(x for x in tree.body if getattr(x, 'name', None) == target_name)
        if hasattr(node, 'decorator_list'):
            # Don't want to compile the decorators
            node.decorator_list = []
        code = astunparse.unparse(node)

        # Execute the definition against the target's module namespace, so
        # that default values, annotations and (for classes) method bodies
        # resolve module-level names, and class ``__module__`` is correct.
        # A separate locals dict keeps the new top-level binding out of the
        # module itself.
        module = inspect.getmodule(self.target)
        _globals = vars(module) if module is not None else {}
        _locals = {}
        exec(compile(code, target_file, 'exec'), _globals, _locals)
        compiled_obj = _locals[target_name]

        # Replace the old code with the new code
        try:
            # If the object has an assignable code object, simply replace it.
            self.target.__code__ = compiled_obj.__code__
        except AttributeError:
            # Otherwise (e.g. classes) rebind the new object on its owner:
            # the defining class if there is one, else the module.
            meth_mod = get_defining_class(self.target)
            if not meth_mod:
                meth_mod = module

            inject_code(compiled_obj)
            setattr(meth_mod, target_name, compiled_obj)
        else:
            inject_code(self.target)

    @abstractmethod
    def inject(self, node):
        """Abstract method that merges the handler into the target."""
        pass
119
120
121
class HeadInjector(BaseInjector):
    """An injector that injects code at the top of the object.

    **Usage**
        .. code-block:: python

            from injectify import inject, HeadInjector

            def file_write(filename, data):
                with open(filename, 'w') as f:
                    f.write(data)

            @inject(target=file_write, injector=HeadInjector())
            def handler():
                data = 'injected'

        After the injection happens, the function ``file_write`` has code that is
        equivalent to

        .. code-block:: python

            def file_write(filename, data):
                data = 'injected'
                with open(filename, 'w') as f:
                    f.write(data)
    """

    def visit_Module(self, node):
        """Visit a ``Module`` node.

        If the target object is a module then inject the handler in this node,
        else keep traversing. This is because the root of the AST will be this
        node for code parsed using the `exec` mode.
        """
        if self.is_target_module():
            return self._visit(node)
        self.generic_visit(node)
        return node

    def visit_ClassDef(self, node):
        """Visit a ``ClassDef`` node."""
        return self._visit(node)

    def visit_FunctionDef(self, node):
        """Visit a ``FunctionDef`` node."""
        return self._visit(node)

    visit_AsyncFunctionDef = visit_FunctionDef

    def _visit(self, node):
        # Inject, then fill in location info on any nodes that lack it.
        return ast.fix_missing_locations(self.inject(node))

    def inject(self, node):
        """Inject the handler at the top of the target object.

        The handler's statements are spliced into ``node.body`` individually,
        so the body stays a flat list of statement nodes (a valid AST).
        """
        node.body[:0] = self.handler
        return node
177
178
179
class TailInjector(BaseInjector):
    """An injector that injects code at the bottom of the object.

    **Usage**
        .. code-block:: python

            import os.path
            from injectify import inject, TailInjector

            def file_read(filename):
                if os.path.exists(filename):
                    with open(filename, 'r') as f:
                        return f.read()

            @inject(target=file_read, injector=TailInjector())
            def handler():
                raise FileNotFoundError('File does not exist')

        After the injection happens, the function ``file_read`` has code that is
        equivalent to

        .. code-block:: python

            def file_read(filename):
                if os.path.exists(filename):
                    with open(filename, 'r') as f:
                        return f.read()
                raise FileNotFoundError('File does not exist')
    """

    def visit_Module(self, node):
        """Visit a ``Module`` node.

        If the target object is a module then inject the handler in this node,
        else keep traversing. This is because the root of the AST will be this
        node for code parsed using the `exec` mode.
        """
        if self.is_target_module():
            return self._visit(node)
        self.generic_visit(node)
        return node

    def visit_ClassDef(self, node):
        """Visit a ``ClassDef`` node."""
        return self._visit(node)

    def visit_FunctionDef(self, node):
        """Visit a ``FunctionDef`` node."""
        return self._visit(node)

    visit_AsyncFunctionDef = visit_FunctionDef

    def _visit(self, node):
        # Inject, then fill in location info on any nodes that lack it.
        return ast.fix_missing_locations(self.inject(node))

    def inject(self, node):
        """Inject the handler at the bottom of the target object.

        The handler's statements are appended to ``node.body`` individually,
        so the body stays a flat list of statement nodes (a valid AST).
        """
        node.body.extend(self.handler)
        return node
238
239
240
class ReturnInjector(BaseInjector):
    """An injector that places the handler's code just before ``return``.

    Note: The ``ReturnInjector`` can only be used when the target is a
    `function` or `method`.

    Args:
        ordinal: Optional zero-based index selecting which ``return``
            statement(s) to inject before, counted in visit order. A list of
            indices may be given. If ``listify(ordinal)`` is empty, every
            ``return`` is targeted.

    **Usage**
        .. code-block:: python

            import statistics
            from injectify import inject, ReturnInjector

            def stat(operation, seq):
                if operation == 'mean':
                    return statistics.mean(seq)
                elif operation == 'median':
                    return statistics.median(seq)
                elif operation == 'mode':
                    return statistics.mode(seq)

            @inject(target=stat, injector=ReturnInjector(ordinal=[1, 2]))
            def handler():
                seq = list(seq)
                seq.append(10)

        After the injection happens, the function ``stat`` has code that is
        equivalent to

        .. code-block:: python

            def stat(operation, seq):
                if operation == 'mean':
                    return statistics.mean(seq)
                elif operation == 'median':
                    seq = list(seq)
                    seq.append(10)
                    return statistics.median(seq)
                elif operation == 'mode':
                    seq = list(seq)
                    seq.append(10)
                    return statistics.mode(seq)
    """

    def __init__(self, ordinal=None, *args, **kwargs):
        self.ordinal = listify(ordinal)

        super().__init__(*args, **kwargs)

    @_count_visit
    def visit_Return(self, node, visit_count):
        """Visit a ``Return`` node.

        ``visit_count`` is this node's zero-based position among the
        ``Return`` nodes visited so far, supplied by ``_count_visit``.
        """
        # Skip only when specific ordinals were requested and this isn't one.
        skip = self.ordinal and visit_count not in self.ordinal
        if skip:
            self.generic_visit(node)
            return node
        return ast.copy_location(self.inject(node), node)

    def inject(self, node):
        """Inject the handler before each return statement in the target object.

        Returns an ``ast.Module`` holding the handler followed by ``node``,
        which takes the original ``Return`` node's place in the tree.
        """
        replacement_body = [self.handler, node]
        return ast.Module(body=replacement_body)
303
304
305
class FieldInjector(BaseInjector):
    """An injector that places the handler's code around a field assignment.

    Args:
        field: The field to inject at. It is matched against each assignment
            target's ``id``/``attr`` via ``tryattrs``. If falsy, the target
            object itself is used for the comparison instead.
        ordinal: Zero-based index (or list of indices) choosing which matching
            assignment(s) to inject at. If ``listify(ordinal)`` is empty,
            every matching assignment is targeted.
        insert: Where to insert the handler's code relative to the target.
            ``'after'`` places it after the assignment; any other value,
            including the default ``None``, places it before.

    **Usage**
        .. code-block:: python

            from injectify import inject, FieldInjector

            def get_rank(year):
                if year == 1:
                    rank = 'Freshman'
                elif year == 2:
                    rank = 'Sophomore'
                elif year == 3:
                    rank = 'Junior'
                else:
                    rank = 'Senor'
                return rank

            @inject(target=get_rank,
                    injector=FieldInjector('rank', ordinal=3, insert='after'))
            def handler():
                rank = 'Senior'

        After the injection happens, the function ``get_rank`` has code that is
        equivalent to

        .. code-block:: python

            def get_rank(year):
                if year == 1:
                    rank = 'Freshman'
                elif year == 2:
                    rank = 'Sophomore'
                elif year == 3:
                    rank = 'Junior'
                else:
                    rank = 'Senor'
                    rank = 'Senior'
                return rank
    """

    def __init__(self, field, ordinal=None, insert=None, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.field = field
        self.ordinal = listify(ordinal)
        self.insert = insert
        # Per-field count of matching assignments seen so far.
        self._field_counter = defaultdict(int)

    def visit_Assign(self, node):
        """Visit an ``Assign`` node."""
        wanted = self.field or self.target

        assigned_names = (tryattrs(t, 'id', 'attr') for t in node.targets)
        if any(wanted == name for name in assigned_names):
            seen = self._field_counter[wanted]
            self._field_counter[wanted] = seen + 1
            if not self.ordinal or seen in self.ordinal:
                return ast.copy_location(self.inject(node), node)

        self.generic_visit(node)
        return node

    def inject(self, node):
        """Inject the handler at the assignment of the given field.

        Returns an ``ast.Module`` holding the handler and ``node`` in the
        order selected by ``insert``.
        """
        if self.insert == 'after':
            ordered = [node, self.handler]
        else:
            ordered = [self.handler, node]
        return ast.Module(body=ordered)
380
381
382
class NestedInjector(BaseInjector):
    """An injector that injects code in a nested function.

    Note: The ``NestedInjector`` can only be used when the target is a
    `function` or `method`.

    Args:
        nested: Name of the nested function.
        injector: Injector to use in the nested function. It traverses the
            nested function's AST exactly as it would a top-level target.

    **Usage**
        .. code-block:: python

            from time import time
            from injectify import inject, NestedInjector, ReturnInjector

            def timing(f):
                def wrapper(*args, **kwargs):
                    ts = time()
                    result = f(*args, **kwargs)
                    te = time()
                    return result
                return wrapper

            @inject(target=timing,
                    injector=NestedInjector('wrapper', ReturnInjector()))
            def handler():
                print('func:{!r} args:[{!r}, {!r}] took: {:.2f} sec'.format(
                        f.__name__, args, kwargs, te-ts))

        After the injection happens, the function ``timing`` has code that is
        equivalent to

        .. code-block:: python

            def timing(f):
                def wrapper(*args, **kwargs):
                    ts = time()
                    result = f(*args, **kwargs)
                    te = time()
                    print('func:{!r} args:[{!r}, {!r}] took: {:.2f} sec'.format(
                        f.__name__, args, kwargs, te-ts))
                    return result
                return wrapper
    """

    def __init__(self, nested, injector, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.nested = nested
        self.injector = injector

    def prepare(self, target, handler):
        """Prepares the injector and the nested injector with the given
        parameters."""
        super().prepare(target, handler)
        self.injector.prepare(target, handler)

    def visit_FunctionDef(self, node):
        """Visit a ``FunctionDef`` node.

        The function named ``nested`` is handed to the nested injector; any
        other function is traversed further in search of it.
        """
        if node.name == self.nested:
            return ast.fix_missing_locations(self.inject(node))
        self.generic_visit(node)
        return node

    visit_AsyncFunctionDef = visit_FunctionDef

    def inject(self, node):
        """Inject the handler into the nested function with the given injector.

        The nested injector *visits* the nested function rather than having
        its ``inject`` called on the ``FunctionDef`` directly. This lets
        visitor-driven injectors such as ``ReturnInjector`` and
        ``FieldInjector`` locate their injection points inside it, instead of
        wrapping the whole nested ``def``. Head/tail injectors behave the same
        either way, because their ``visit_FunctionDef`` calls ``inject``.
        """
        return self.injector.visit(node)
450