|
1
|
|
|
import sys |
|
2
|
|
|
from typing import Optional |
|
3
|
|
|
import dataclasses |
|
4
|
|
|
from dataclasses import dataclass |
|
5
|
|
|
import json |
|
6
|
|
|
import libcst as cst |
|
7
|
|
|
from libcst.metadata import PositionProvider, WhitespaceInclusivePositionProvider |
|
8
|
|
|
|
|
9
|
|
|
# Values for AstNode.kind produced by FunctionCollector.
KIND_FN = "fn"
KIND_PARAM = "param"
KIND_SIGNATURE = "sig"

# Keys used in AstNode.props.
PROP_NAME = "name"
PROP_TYPE = "type"
PROP_RETURN_TYPE = "return_type"
PROP_DEFAULT = "default"
|
17
|
|
|
|
|
18
|
|
|
class DataClassJsonEncoder(json.JSONEncoder):
    """JSON encoder that also handles ``as_dict()`` objects and dataclasses.

    Objects exposing a callable ``as_dict`` attribute are serialized through
    it; dataclass instances are serialized through ``dataclasses.asdict``;
    everything else is deferred to ``json.JSONEncoder.default``, which raises
    ``TypeError``.
    """

    def default(self, o):
        """Return a JSON-serializable replacement for *o*.

        Raises:
            TypeError: if *o* is neither an ``as_dict()`` provider nor a
                dataclass instance (raised by the base class).
        """
        as_dict = getattr(o, "as_dict", None)
        if callable(as_dict):
            return as_dict()
        # is_dataclass() is also true for dataclass *types*, which asdict()
        # rejects; only instances are converted here.
        if dataclasses.is_dataclass(o) and not isinstance(o, type):
            return dataclasses.asdict(o)
        return super().default(o)
|
25
|
|
|
|
|
26
|
|
|
|
|
27
|
|
|
@dataclass
class SourcePosition:
    """A line/column pair locating a point in source text."""

    line: int
    col: int

    def render(self) -> str:
        """Return the position formatted as ``line:col``."""
        return '{}:{}'.format(self.line, self.col)
|
34
|
|
|
|
|
35
|
|
|
|
|
36
|
|
|
@dataclass
class SourceRange:
    """A span of source text delimited by a start and an end position."""

    start: SourcePosition
    end: SourcePosition

    def render(self) -> str:
        """Return the range formatted as ``<start>-<end>``."""
        first = self.start.render()
        last = self.end.render()
        return '{}-{}'.format(first, last)
|
43
|
|
|
|
|
44
|
|
|
|
|
45
|
|
|
@dataclass
class AstNode:
    """Simplified syntax-tree node: a kind, a source range, props, children, text.

    The class defines its own ``__init__``, so only ``kind`` and ``src_range``
    are constructor arguments; the other fields start out empty.
    """

    kind: str
    src_range: SourceRange
    props: dict[str, str]
    children: list["AstNode"]  # quoted forward reference to this class
    text: Optional[str]

    def __init__(self, kind, src_range):
        """Create a node with no props, no children and no text."""
        self.kind, self.src_range = kind, src_range
        self.props = {}
        self.children = []
        self.text = None

    def as_dict(self) -> dict:
        """Return a dict with ``kind`` and the rendered ``range``.

        ``props``, ``children`` and ``text`` are added only when non-empty.
        Children are stored as ``AstNode`` objects; they are not converted
        here.
        """
        result = {'kind': self.kind, 'range': self.src_range.render()}
        optional_entries = (
            ('props', self.props),
            ('children', self.children),
            ('text', self.text),
        )
        for key, value in optional_entries:
            if value:
                result[key] = value
        return result

    def children_filtered(self, kind):
        """Return the direct children whose ``kind`` equals *kind*."""
        return list(filter(lambda child: child.kind == kind, self.children))
|
75
|
|
|
|
|
76
|
|
|
|
|
77
|
|
|
class FunctionCollector(cst.CSTVisitor):
    """Collect an ``AstNode`` description of the function definitions visited.

    For each visited ``FunctionDef`` a ``KIND_FN`` node is appended to
    ``function_asts``. It carries the dotted qualified name (enclosing class
    names plus the function name), the return type when annotated, and one
    ``KIND_SIGNATURE`` child whose children are the ``KIND_PARAM`` nodes.
    """

    METADATA_DEPENDENCIES = (PositionProvider, WhitespaceInclusivePositionProvider)

    def __init__(self, enclosing_module, copy_function_text=False):
        """Set up an empty collector.

        Args:
            enclosing_module: object whose ``code_for_node`` is used to render
                nodes back to source text.
            copy_function_text: when true, each function node's ``text`` is
                set to the rendered source of the whole function.
        """
        super().__init__()
        # Names of the classes/functions currently being visited, outermost first.
        self.stack: list[str] = []
        # Raw annotations per function.
        self.annotations: dict[
            tuple[str, ...],  # key: tuple of canonical class/function name
            tuple[cst.Parameters, Optional[cst.Annotation]],  # value: (params, returns)
        ] = {}
        self.enclosing_module = enclosing_module
        self.copy_function_text: bool = copy_function_text
        self.function_asts: list[AstNode] = []

    def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
        """Push the class name so functions inside it get a qualified name."""
        self.stack.append(node.name.value)
        return None

    def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
        """Pop the class name pushed by ``visit_ClassDef``."""
        self.stack.pop()

    def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
        """Record *node* as a ``KIND_FN`` AstNode with its signature and params.

        Returns:
            False, so the function's children are not visited and inner
            functions are skipped.
        """
        name = node.name.value
        self.stack.append(name)
        self.annotations[tuple(self.stack)] = (node.params, node.returns)

        src_range = self._src_range(node)
        signature_start = src_range.start
        fn_ast = AstNode(kind=KIND_FN, src_range=src_range)

        return_text = ""
        if node.returns is not None:
            signature_end = self._src_range(node.returns).end
            return_type = self.enclosing_module.code_for_node(node.returns.annotation)
            return_text = " -> " + return_type
            fn_ast.props[PROP_RETURN_TYPE] = return_type
        else:
            # Include space and add 1 to get after the paren.
            signature_end = self._src_range(node.params, include_whitespace=True).end
            signature_end.col += 1
        fn_ast.props[PROP_NAME] = '.'.join(self.stack)

        signature_ast = AstNode(
            kind=KIND_SIGNATURE,
            src_range=SourceRange(signature_start, signature_end),
        )
        param_text = self.enclosing_module.code_for_node(node.params)
        signature_ast.text = f'def {name}({param_text}){return_text}'
        fn_ast.children.append(signature_ast)

        if self.copy_function_text:
            fn_ast.text = self.enclosing_module.code_for_node(node)

        signature_ast.children.extend(
            self._param_node_to_ast(param)
            for param in self._params_in_source_order(node.params)
        )
        self.function_asts.append(fn_ast)
        # Skipping inner functions
        return False

    def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
        """Pop the function name pushed by ``visit_FunctionDef``."""
        self.stack.pop()

    @staticmethod
    def _params_in_source_order(params: cst.Parameters) -> list[cst.Param]:
        """Return the named parameters in the order a signature lists them.

        Positional-only parameters come first, then regular parameters,
        ``*args``, keyword-only parameters and finally ``**kwargs``. A
        ``star_arg`` that is not a ``cst.Param`` (e.g. a bare ``*``) is
        skipped.
        """
        ordered = [*params.posonly_params, *params.params]
        if isinstance(params.star_arg, cst.Param):
            ordered.append(params.star_arg)
        ordered.extend(params.kwonly_params)
        if params.star_kwarg is not None:
            ordered.append(params.star_kwarg)
        return ordered

    def _src_range(self, node: cst.CSTNode, include_whitespace=False):
        """Return the ``SourceRange`` of *node* from libcst position metadata.

        Args:
            node: node to locate.
            include_whitespace: use ``WhitespaceInclusivePositionProvider``
                instead of ``PositionProvider``.
        """
        provider = WhitespaceInclusivePositionProvider if include_whitespace else PositionProvider
        cst_range = self.get_metadata(provider, node)
        # Columns are shifted by one relative to what libcst reports.
        return SourceRange(
            start=SourcePosition(line=cst_range.start.line, col=cst_range.start.column + 1),
            end=SourcePosition(line=cst_range.end.line, col=cst_range.end.column + 1),
        )

    def _param_node_to_ast(self, param_node: cst.Param) -> AstNode:
        """Build a ``KIND_PARAM`` AstNode holding the name, type and default.

        ``PROP_TYPE`` and ``PROP_DEFAULT`` are set only when the parameter has
        an annotation or a default value, respectively.
        """
        ast = AstNode(kind=KIND_PARAM, src_range=self._src_range(param_node))
        annotation = param_node.annotation
        if annotation is not None:
            ast.props[PROP_TYPE] = self.enclosing_module.code_for_node(annotation.annotation)
        ast.props[PROP_NAME] = param_node.name.value
        if param_node.default is not None:
            ast.props[PROP_DEFAULT] = self.enclosing_module.code_for_node(param_node.default)
        return ast
|
167
|
|
|
|
|
168
|
|
|
# def visit_Name(self, node: cst.Name) -> None: |
|
169
|
|
|
# # Only print out names that are parameters |
|
170
|
|
|
# if self.get_metadata(IsParamProvider, node): |
|
171
|
|
|
# pos = self.get_metadata(PositionProvider, node).start |
|
172
|
|
|
# print(f"{node.value} found at line {pos.line}, column {pos.column}") |
|
173
|
|
|
|
|
174
|
|
|
|
|
175
|
|
|
def collect_function_asts(code: str):
    """Parse *code* and return the AstNode list gathered by FunctionCollector.

    The collector is created with ``copy_function_text=True``.
    """
    parsed = cst.parse_module(code)
    metadata_wrapper = cst.metadata.MetadataWrapper(parsed)
    collector = FunctionCollector(parsed, copy_function_text=True)
    metadata_wrapper.visit(collector)
    return collector.function_asts
|
181
|
|
|
|
|
182
|
|
|
|
|
183
|
|
|
def to_json(o):
    """Serialize *o* to JSON text with DataClassJsonEncoder and indent 2."""
    dump_options = {'cls': DataClassJsonEncoder, 'indent': 2}
    return json.dumps(o, **dump_options)
|
185
|
|
|
|
|
186
|
|
|
|
|
187
|
|
|
def _main():
    """Print the JSON form of every function collected from a source file.

    The file path is taken from the first command-line argument. When no
    argument is given, ``sys.argv[0]`` is read, which was the only behaviour
    before.
    """
    path = sys.argv[1] if len(sys.argv) > 1 else sys.argv[0]
    print("Reading", path)
    # Read as UTF-8 (the default encoding of Python source) rather than the
    # platform's locale encoding.
    with open(path, 'r', encoding='utf-8') as file:
        code = file.read()
    for fn_ast in collect_function_asts(code):
        print(to_json(fn_ast))
|
193
|
|
|
|
|
194
|
|
|
|
|
195
|
|
|
# Run _main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    _main()
|
197
|
|
|
|