FunctionCollector.visit_FunctionDef()   B
last analyzed

Complexity

Conditions 8

Size

Total Lines 43
Code Lines 37

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 8
eloc 37
nop 2
dl 0
loc 43
rs 7.1253
c 0
b 0
f 0
1
import dataclasses
2
import json
3
import sys
4
from dataclasses import dataclass
5
from typing import Optional
6
7
import libcst as cst
8
from libcst.metadata import PositionProvider, WhitespaceInclusivePositionProvider
9
10
# Values stored in ``AstNode.kind``.
KIND_FN = "fn"  # a whole function definition
KIND_PARAM = "param"  # a single parameter of a signature
KIND_SIGNATURE = "sig"  # the "def name(params) -> type" part of a function

# Keys used in ``AstNode.props``.
PROP_NAME = "name"
PROP_TYPE = "type"
PROP_RETURN_TYPE = "return_type"
PROP_DEFAULT = "default"
18
19
20
class DataClassJsonEncoder(json.JSONEncoder):
    """JSON encoder that also serializes ``as_dict()`` objects and dataclass instances."""

    def default(self, o):
        """Return a JSON-serializable stand-in for *o*.

        Resolution order:

        1. If *o* exposes a callable ``as_dict`` attribute, its result is used.
        2. If *o* is a dataclass *instance*, ``dataclasses.asdict(o)`` is used.
        3. Otherwise the base class is asked, which raises ``TypeError``.
        """
        as_dict = getattr(o, "as_dict", None)
        if callable(as_dict):
            return as_dict()
        # is_dataclass() is also true for dataclass *types*; asdict() only
        # accepts instances, so let classes fall through to the standard
        # "not JSON serializable" error instead.
        if dataclasses.is_dataclass(o) and not isinstance(o, type):
            return dataclasses.asdict(o)
        return super().default(o)
27
28
29
@dataclass
class SourcePosition:
    """A line/column location in a piece of source text."""

    line: int
    col: int

    def render(self):
        """Return the position formatted as ``"<line>:<col>"``."""
        return "{}:{}".format(self.line, self.col)
36
37
38
@dataclass
class SourceRange:
    """A span of source text delimited by a start and an end position."""

    start: SourcePosition
    end: SourcePosition

    def render(self):
        """Return the range formatted as ``"<start>-<end>"`` using each position's ``render()``."""
        begin = self.start.render()
        finish = self.end.render()
        return "{}-{}".format(begin, finish)
45
46
47
@dataclass
class AstNode:
    """A node of the simplified, JSON-friendly tree built by this module.

    A node is created with only a kind and a source range; ``props``,
    ``children`` and ``text`` start empty and are filled in by the caller.
    """

    kind: str
    src_range: SourceRange
    props: dict[str, str]
    children: list["AstNode"]  # quoted because the class is not bound yet
    text: Optional[str]

    def __init__(self, kind, src_range):
        self.kind, self.src_range = kind, src_range
        self.props, self.children, self.text = {}, [], None

    def as_dict(self) -> dict:
        """Return a dict with ``kind`` and the rendered ``range``.

        ``props``, ``children`` and ``text`` are added only when they are
        truthy, so empty containers and missing text are left out.
        """
        result = {"kind": self.kind, "range": self.src_range.render()}
        optional_entries = (
            ("props", self.props),
            ("children", self.children),
            ("text", self.text),
        )
        for key, value in optional_entries:
            if value:
                result[key] = value
        return result

    def children_filtered(self, kind):
        """Return a new list of the direct children whose ``kind`` equals *kind*."""
        return list(filter(lambda child: child.kind == kind, self.children))
77
78
79
class FunctionCollector(cst.CSTVisitor):
    """Visitor that records an ``AstNode`` for every function it visits.

    Results accumulate in ``function_asts``. ``visit_FunctionDef`` returns
    ``False``, so functions nested inside another function are not collected;
    methods are reached through their enclosing ``ClassDef``.
    """

    METADATA_DEPENDENCIES = (
        PositionProvider,
        WhitespaceInclusivePositionProvider,
    )

    def __init__(self, enclosing_module, copy_function_text=False):
        """Create a collector.

        Args:
            enclosing_module: object whose ``code_for_node(node)`` is used to
                render nodes back to source text.
            copy_function_text: when true, each function node's ``text`` is
                set to the full rendered source of the function.
        """
        super().__init__()
        # Names of the classes/functions currently being visited, outermost first.
        self.stack: list[str] = []
        # Raw signature pieces, keyed by the tuple of enclosing names.
        self.annotations: dict[
            tuple[str, ...],  # key: tuple of canonical class/function name
            tuple[cst.Parameters, Optional[cst.Annotation]],  # value: (params, returns)
        ] = {}
        self.enclosing_module = enclosing_module
        self.copy_function_text: bool = copy_function_text
        self.function_asts: list[AstNode] = []

    def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
        """Push the class name so nested functions get a qualified name."""
        self.stack.append(node.name.value)
        return None

    def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
        """Pop the class name pushed by ``visit_ClassDef``."""
        self.stack.pop()

    def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
        """Record *node* as a ``KIND_FN`` node with one ``KIND_SIGNATURE`` child.

        The signature child carries the reconstructed ``def ...`` text and one
        ``KIND_PARAM`` child per parameter, in source order.

        Returns:
            ``False`` so that the function body is not traversed.
        """
        name = node.name.value
        self.stack.append(name)
        qualified = tuple(self.stack)
        self.annotations[qualified] = (node.params, node.returns)

        src_range = self._src_range(node)
        fn_ast = AstNode(kind=KIND_FN, src_range=src_range)

        return_text = ""
        if node.returns is not None:
            # The signature ends where the return annotation ends.
            signature_end = self._src_range(node.returns).end
            return_type = self.enclosing_module.code_for_node(node.returns.annotation)
            return_text = " -> " + return_type
            fn_ast.props[PROP_RETURN_TYPE] = return_type
        else:
            # Include space and add 1 to get after the paren.
            signature_end = self._src_range(node.params, include_whitespace=True).end
            signature_end.col += 1
        fn_ast.props[PROP_NAME] = ".".join(qualified)

        signature_ast = AstNode(
            kind=KIND_SIGNATURE,
            src_range=SourceRange(src_range.start, signature_end),
        )
        param_text = self.enclosing_module.code_for_node(node.params)
        signature_ast.text = f"def {name}({param_text}){return_text}"
        signature_ast.children.extend(
            self._param_node_to_ast(param)
            for param in self._params_in_source_order(node.params)
        )
        fn_ast.children.append(signature_ast)

        if self.copy_function_text:
            fn_ast.text = self.enclosing_module.code_for_node(node)
        self.function_asts.append(fn_ast)
        # Skipping inner functions
        return False

    def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
        """Pop the function name pushed by ``visit_FunctionDef``."""
        self.stack.pop()

    @staticmethod
    def _params_in_source_order(params: cst.Parameters) -> list[cst.Param]:
        """Return every named parameter of *params* in the order it is written.

        Positional-only parameters come first, then regular parameters, a named
        ``*args``, keyword-only parameters and finally ``**kwargs``. A bare
        ``*`` separator is not a ``cst.Param`` and is skipped.
        """
        ordered = [*params.posonly_params, *params.params]
        if isinstance(params.star_arg, cst.Param):
            ordered.append(params.star_arg)
        ordered.extend(params.kwonly_params)
        if params.star_kwarg is not None:
            ordered.append(params.star_kwarg)
        return ordered

    def _src_range(self, node: cst.CSTNode, include_whitespace=False):
        """Return the ``SourceRange`` of *node* from libcst position metadata.

        Lines are copied as reported; 1 is added to both column values.

        Args:
            node: the node to locate.
            include_whitespace: use ``WhitespaceInclusivePositionProvider``
                instead of ``PositionProvider``.
        """
        provider = (
            WhitespaceInclusivePositionProvider
            if include_whitespace
            else PositionProvider
        )
        cst_range = self.get_metadata(provider, node)
        return SourceRange(
            start=SourcePosition(
                line=cst_range.start.line, col=cst_range.start.column + 1
            ),
            end=SourcePosition(line=cst_range.end.line, col=cst_range.end.column + 1),
        )

    def _param_node_to_ast(self, param_node: cst.Param) -> AstNode:
        """Convert one parameter into a ``KIND_PARAM`` node.

        ``PROP_NAME`` is always set; ``PROP_TYPE`` and ``PROP_DEFAULT`` are set
        only when the parameter has an annotation or a default value.
        """
        param_ast = AstNode(kind=KIND_PARAM, src_range=self._src_range(param_node))
        annotation = param_node.annotation
        if annotation is not None:
            param_ast.props[PROP_TYPE] = self.enclosing_module.code_for_node(
                annotation.annotation
            )
        param_ast.props[PROP_NAME] = param_node.name.value
        if param_node.default is not None:
            param_ast.props[PROP_DEFAULT] = self.enclosing_module.code_for_node(
                param_node.default
            )
        return param_ast
181
182
    # def visit_Name(self, node: cst.Name) -> None:
183
    #     # Only print out names that are parameters
184
    #     if self.get_metadata(IsParamProvider, node):
185
    #         pos = self.get_metadata(PositionProvider, node).start
186
    #         print(f"{node.value} found at line {pos.line}, column {pos.column}")
187
188
189
def collect_function_asts(code: str):
    """Parse *code* with libcst and return the ``AstNode`` list for its functions.

    The collector is run with ``copy_function_text=True``, so each function
    node also carries its full source text.
    """
    parsed = cst.parse_module(code)
    metadata_wrapper = cst.metadata.MetadataWrapper(parsed)
    collector = FunctionCollector(parsed, copy_function_text=True)
    metadata_wrapper.visit(collector)
    return collector.function_asts
195
196
197
def to_json(o):
    """Serialize *o* to a 2-space-indented JSON string using ``DataClassJsonEncoder``."""
    rendered = json.dumps(o, indent=2, cls=DataClassJsonEncoder)
    return rendered
199
200
201
def _main():
    """Print the JSON form of every function found in a Python source file.

    The file is taken from the first command-line argument. When no argument
    is given, the script analyses itself (``sys.argv[0]``).
    """
    path = sys.argv[1] if len(sys.argv) > 1 else sys.argv[0]
    print("Reading", path)
    # Python source defaults to UTF-8; do not depend on the locale encoding.
    with open(path, "r", encoding="utf-8") as file:
        code = file.read()
    for fn_ast in collect_function_asts(code):
        print(to_json(fn_ast))
207
208
209
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    _main()
211