ScopeList.find_definition()   A
last analyzed

Complexity

Conditions 3

Size

Total Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 3
c 1
b 0
f 0
dl 0
loc 4
rs 10
1
import ast
2
from contextlib import contextmanager
3
4
5
def add_scope_context(node):
    """Attach scope information to ``node`` and every node beneath it.

    A fresh ScopeTransformer walks the tree, and its result for ``node`` is
    returned.
    """
    transformer = ScopeTransformer()
    annotated = transformer.visit(node)
    return annotated
8
9
10
class ScopeMixin(object):
    """
    Track the stack of scope nodes a visitor is currently inside.

    Module, FunctionDef, For, If and With nodes count as scopes (see
    ``_is_scopable_node``). ``enter_scope`` pushes such a node for the
    duration of a ``with`` block, and ``scope`` returns the innermost one.
    """

    # Node types that open a new scope for the purposes of this mixin.
    _SCOPE_NODE_TYPES = (ast.Module, ast.FunctionDef, ast.For, ast.If,
                         ast.With)

    # The stack is stored per instance. A class-level ``scopes = []`` would be
    # shared by every visitor instance and subclass, so one visitor's state
    # (including nodes left behind by an exception) would leak into others.
    @property
    def scopes(self):
        """Entered scope nodes, outermost first (per-instance list).

        The list is created lazily, so the mixin works without a cooperating
        ``__init__`` (e.g. when combined with ``ast.NodeTransformer``).
        """
        try:
            return self._scopes
        except AttributeError:
            self._scopes = []
            return self._scopes

    @scopes.setter
    def scopes(self, value):
        self._scopes = value

    @contextmanager
    def enter_scope(self, node):
        """Push ``node`` onto ``scopes`` for the ``with`` body if it is a scope.

        Non-scope nodes leave the stack untouched. The pop happens in a
        ``finally`` block so the stack stays balanced even if the body raises.
        """
        if not self._is_scopable_node(node):
            yield
            return
        self.scopes.append(node)
        try:
            yield
        finally:
            self.scopes.pop()

    @property
    def scope(self):
        """Innermost entered scope node, or ``None`` when outside any scope."""
        scopes = self.scopes
        return scopes[-1] if scopes else None

    def _is_scopable_node(self, node):
        """Return True if ``node`` is one of the scope-opening node types."""
        return isinstance(node, self._SCOPE_NODE_TYPES)
36
37
38
class ScopeList(list):
    """
    A list of scope nodes with helpers for resolving names.

    ``find`` locates a variable or import-alias definition recorded on the
    scopes; ``find_import`` locates an import by its name.
    """

    def find(self, lookup):
        """Return the first definition of the name ``lookup``, or ``None``.

        Scopes are searched in list order. Within one scope, ``vars`` is
        searched first (it must exist, otherwise AttributeError propagates),
        then ``body_vars`` and ``orelse_vars`` if the scope has them. A
        candidate matches when it is an ``ast.Name`` whose ``id`` equals
        ``lookup`` or an ``ast.alias`` whose ``name`` equals ``lookup``.
        """
        def names_lookup(candidate):
            if isinstance(candidate, ast.alias):
                return candidate.name == lookup
            if isinstance(candidate, ast.Name):
                return candidate.id == lookup
            return False

        def pools(scope):
            # ``vars`` is mandatory; the branch-specific lists are optional
            # and only consulted after the previous pool had no match.
            yield scope.vars
            for optional in ("body_vars", "orelse_vars"):
                if hasattr(scope, optional):
                    yield getattr(scope, optional)

        # NOTE(review): scopes are appended as they are entered, so list order
        # is outermost first. A module-level binding therefore wins over a
        # same-named binding inside a function. find_import searches innermost
        # first. Confirm this asymmetry is intended.
        for scope in self:
            for pool in pools(scope):
                for candidate in pool:
                    if names_lookup(candidate):
                        return candidate
        return None

    def find_import(self, lookup):
        """Return the innermost import whose ``name`` equals ``lookup``.

        Scopes are searched from the end of the list backwards. Scopes
        without an ``imports`` attribute are skipped. Returns ``None`` when
        nothing matches.
        """
        for scope in self[::-1]:
            for imported in getattr(scope, "imports", ()):
                if imported.name == lookup:
                    return imported
        return None
69
70
71
class ScopeTransformer(ast.NodeTransformer, ScopeMixin):
    """
    Record the enclosing scopes on every node of a tree.

    Each visited node receives a ``scopes`` attribute: a ScopeList copy of
    the scope stack taken after entering the node, so a scope node's own
    list ends with itself. Scope nodes are the ones accepted by
    ``_is_scopable_node`` (module, function, for loop, if, with).
    Children are reached through NodeTransformer's generic traversal,
    which calls ``visit`` again for each of them.
    """

    def visit(self, node):
        delegate = super(ScopeTransformer, self).visit
        with self.enter_scope(node):
            node.scopes = ScopeList(self.scopes)
            visited = delegate(node)
        return visited
82