Completed
Pull Request — master (#1)
by Valentin
02:24
created

ScopeList.is_match()   A

Complexity

Conditions 1

Size

Total Lines 3

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 1
dl 0
loc 3
rs 10
1
import ast
2
from contextlib import contextmanager
3
4
from .analysis import get_id
5
6
def add_scope_context(node):
    """Attach a ``scopes`` attribute (a ScopeList) to ``node`` and the nodes below it.

    Returns the result of ``ScopeTransformer().visit(node)``.
    """
    transformer = ScopeTransformer()
    annotated = transformer.visit(node)
    return annotated
9
10
11
class ScopeMixin(object):
    """
    Adds a scope property with the current scope (module, function, for
    loop, if statement or with statement) a node is part of.

    The stack of enclosing scope nodes is kept per instance, so separate
    visitors never see each other's scopes.
    """

    # Node types that open a new scope when entered.
    _SCOPE_NODE_TYPES = (ast.Module, ast.FunctionDef, ast.For, ast.If, ast.With)

    @property
    def scopes(self):
        """Per-instance stack of enclosing scope nodes, outermost first."""
        try:
            return self._scope_stack
        except AttributeError:
            # Created lazily so mixin users are not required to call an __init__.
            self._scope_stack = []
            return self._scope_stack

    @scopes.setter
    def scopes(self, value):
        # Keeps assignments such as ``self.scopes = []`` in subclasses working.
        self._scope_stack = value

    @contextmanager
    def enter_scope(self, node):
        """Push ``node`` onto the scope stack for the duration of the block.

        Nodes that do not open a scope leave the stack untouched. The pop
        happens in ``finally`` so an exception raised inside the ``with``
        body cannot leave a stale scope behind.
        """
        if not self._is_scopable_node(node):
            yield
            return
        self.scopes.append(node)
        try:
            yield
        finally:
            self.scopes.pop()

    @property
    def scope(self):
        """Innermost enclosing scope node, or None when no scope is active."""
        try:
            return self.scopes[-1]
        except IndexError:
            return None

    def _is_scopable_node(self, node):
        """Return True if ``node`` is one of the scope-opening node types."""
        return isinstance(node, self._SCOPE_NODE_TYPES)
37
38
39
class ScopeList(list):
    """
    A list of scope nodes with helpers for resolving names against them.

    In the lists built by ScopeTransformer the first entry is the outermost
    scope and the last entry is the innermost one.
    """

    def find(self, lookup):
        """Return the first variable node whose id equals ``lookup``, or None.

        Scopes are searched in list order. Within each scope, ``vars`` is
        searched first (every scope must have it), then ``body_vars`` and
        ``orelse_vars`` when the scope defines them.

        NOTE(review): this searches outermost-first while ``find_import``
        searches innermost-first; confirm that is intended.
        """
        def first_match(scope, attr_name):
            candidates = getattr(scope, attr_name)
            return next((v for v in candidates if get_id(v) == lookup), None)

        for scope in self:
            found = first_match(scope, "vars")
            for extra_attr in ("body_vars", "orelse_vars"):
                if found:
                    break
                if hasattr(scope, extra_attr):
                    found = first_match(scope, extra_attr)
            if found:
                return found
        return None

    def find_import(self, lookup):
        """Return the import whose ``name`` equals ``lookup``, or None.

        Scopes are searched from the last entry back to the first; scopes
        without an ``imports`` attribute are skipped.
        """
        for scope in reversed(self):
            for candidate in getattr(scope, "imports", ()):
                if candidate.name == lookup:
                    return candidate
        return None
66
67
68
class ScopeTransformer(ast.NodeTransformer, ScopeMixin):
    """
    Adds a ``scopes`` attribute to each visited node.

    The attribute is a ScopeList snapshot of the enclosing scope nodes
    (module, function, for loop, if statement, with statement), outermost
    first. A node that opens a scope appears as the last entry of its own
    list.
    """

    def visit(self, node):
        with self.enter_scope(node):
            # Copy the live stack: it keeps changing as traversal continues.
            snapshot = ScopeList(self.scopes)
            node.scopes = snapshot
            result = super(ScopeTransformer, self).visit(node)
        return result
79