Passed
Push — main ( 46d75e...737583 )
by Ray
01:49
created

DataClassJsonEncoder.default()   A

Complexity

Conditions 3

Size

Total Lines 6
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 6
nop 2
dl 0
loc 6
rs 10
c 0
b 0
f 0
1
import sys
2
from typing import Optional
3
import dataclasses
4
from dataclasses import dataclass
5
import json
6
import libcst as cst
7
from libcst.metadata import PositionProvider, WhitespaceInclusivePositionProvider
8
9
# Node kinds, stored in AstNode.kind.
KIND_FN: str = 'fn'
KIND_PARAM: str = 'param'
KIND_SIGNATURE: str = 'sig'

# Property keys, used in AstNode.props.
PROP_NAME: str = 'name'
PROP_TYPE: str = 'type'
PROP_RETURN_TYPE: str = 'return_type'
PROP_DEFAULT: str = 'default'
17
18
class DataClassJsonEncoder(json.JSONEncoder):
    """JSON encoder that understands ``as_dict()`` objects and dataclass instances.

    Objects exposing a callable ``as_dict`` method are serialized through it,
    so they control their own JSON shape. Other dataclass instances are
    converted with ``dataclasses.asdict``. Anything else is handed to
    ``json.JSONEncoder.default``, which raises ``TypeError``.
    """

    def default(self, o):
        # Classes are not data. ``dataclasses.is_dataclass`` is also true for a
        # dataclass *type*, and ``asdict`` on a type (or an unbound ``as_dict``)
        # fails with a confusing error. Let the base class report types as
        # unserializable instead.
        if not isinstance(o, type):
            as_dict = getattr(o, "as_dict", None)
            if callable(as_dict):
                return as_dict()
            if dataclasses.is_dataclass(o):
                return dataclasses.asdict(o)
        return super().default(o)
25
26
27
@dataclass
class SourcePosition:
    """One location in source text: a line number and a column number."""

    line: int
    col: int

    def render(self):
        """Return the position formatted as ``"<line>:<col>"``."""
        return "{}:{}".format(self.line, self.col)
34
35
36
@dataclass
class SourceRange:
    """A span of source text delimited by a start and an end position."""

    start: SourcePosition
    end: SourcePosition

    def render(self):
        """Return the span formatted as ``"<start>-<end>"``, e.g. ``"1:1-2:5"``."""
        return "{}-{}".format(self.start.render(), self.end.render())
43
44
45
@dataclass
class AstNode:
    """A small, JSON-friendly summary node.

    In this module ``kind`` is one of the ``KIND_*`` constants, ``props`` maps
    ``PROP_*`` keys to source-text strings, and ``children`` holds nested
    nodes.
    """

    kind: str
    src_range: SourceRange
    props: dict[str, str]
    children: list["AstNode"]  # forward reference to this class
    text: Optional[str]

    # An explicit __init__ takes precedence over the one @dataclass would
    # generate. Nodes are created from a kind and a range only; the other
    # fields start empty and are filled in by the caller.
    def __init__(self, kind, src_range):
        self.kind, self.src_range = kind, src_range
        self.props = {}
        self.children = []
        self.text = None

    def as_dict(self) -> dict:
        """Return a dict for JSON encoding.

        ``kind`` and the rendered ``range`` are always present. ``props``,
        ``children`` and ``text`` are included only when non-empty. Children
        stay ``AstNode`` objects; the JSON encoder converts them in turn.
        """
        result = {'kind': self.kind, 'range': self.src_range.render()}
        optional_entries = (
            ('props', self.props),
            ('children', self.children),
            ('text', self.text),
        )
        result.update((key, value) for key, value in optional_entries if value)
        return result

    def children_filtered(self, kind):
        """Return the direct children whose ``kind`` equals *kind*, in order."""
        return list(filter(lambda child: child.kind == kind, self.children))
75
76
77
class FunctionCollector(cst.CSTVisitor):
    """libcst visitor that summarizes function definitions as ``AstNode`` trees.

    For every ``def`` not nested inside another function (module-level
    functions and methods of classes), one ``KIND_FN`` node is appended to
    ``function_asts``. That node's ``props`` carry the dotted qualified name
    and, if annotated, the return type. Its single ``KIND_SIGNATURE`` child
    holds the normalized ``def`` line as ``text`` and one ``KIND_PARAM`` child
    per named parameter, in declaration order.

    ``get_metadata`` needs the providers in ``METADATA_DEPENDENCIES``.
    ``collect_function_asts`` supplies them by visiting through
    ``cst.metadata.MetadataWrapper``.
    """

    METADATA_DEPENDENCIES = (PositionProvider, WhitespaceInclusivePositionProvider)

    def __init__(self, enclosing_module, copy_function_text=False):
        """
        Args:
            enclosing_module: module whose ``code_for_node`` renders nodes
                back to source text.
            copy_function_text: when true, each ``KIND_FN`` node also stores
                the function's full source text.
        """
        super().__init__()
        # Stack of enclosing class/function names. Joined with '.', it gives
        # the qualified name of the definition currently being visited.
        self.stack: list[str] = []
        # Raw libcst parameters and return annotation, keyed by the tuple of
        # enclosing class/function names.
        self.annotations: dict[
            tuple[str, ...],
            tuple[cst.Parameters, Optional[cst.Annotation]],  # (params, returns)
        ] = {}
        self.enclosing_module = enclosing_module
        self.copy_function_text: bool = copy_function_text
        self.function_asts: list[AstNode] = []

    def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
        self.stack.append(node.name.value)
        return None  # keep descending so methods are collected

    def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
        self.stack.pop()

    def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
        name = node.name.value
        self.stack.append(name)
        # `tuple` is the builtin. It freezes the current name path so it can
        # serve as a dict key; some static analysers wrongly report it as
        # undefined here.
        qualified = tuple(self.stack)
        self.annotations[qualified] = (node.params, node.returns)

        src_range = self._src_range(node)
        fn_ast = AstNode(kind=KIND_FN, src_range=src_range)

        return_text = ""
        if node.returns is not None:
            signature_end = self._src_range(node.returns).end
            return_type = self.enclosing_module.code_for_node(node.returns.annotation)
            return_text = " -> " + return_type
            fn_ast.props[PROP_RETURN_TYPE] = return_type
        else:
            # Include trailing whitespace, then add 1 so the range ends after
            # the closing parenthesis.
            signature_end = self._src_range(node.params, include_whitespace=True).end
            signature_end.col += 1
        fn_ast.props[PROP_NAME] = '.'.join(qualified)

        signature_ast = AstNode(
            kind=KIND_SIGNATURE,
            src_range=SourceRange(src_range.start, signature_end),
        )
        # Keep the `async` keyword so the text matches the real signature.
        keyword = "async def" if node.asynchronous is not None else "def"
        param_text = self.enclosing_module.code_for_node(node.params)
        signature_ast.text = f'{keyword} {name}({param_text}){return_text}'
        signature_ast.children.extend(self._signature_params(node.params))
        fn_ast.children.append(signature_ast)

        if self.copy_function_text:
            fn_ast.text = self.enclosing_module.code_for_node(node)

        self.function_asts.append(fn_ast)
        # Returning False skips the body, so inner functions are not
        # collected. libcst still calls leave_FunctionDef, which pops the stack.
        return False

    def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
        self.stack.pop()

    def _src_range(self, node: cst.CSTNode, include_whitespace=False):
        """Return *node*'s position as a ``SourceRange``.

        Columns are shifted by +1 relative to libcst's column values. Lines
        are copied unchanged.
        """
        provider = (
            WhitespaceInclusivePositionProvider if include_whitespace else PositionProvider
        )
        cst_range = self.get_metadata(provider, node)
        return SourceRange(
            start=SourcePosition(line=cst_range.start.line, col=cst_range.start.column + 1),
            end=SourcePosition(line=cst_range.end.line, col=cst_range.end.column + 1),
        )

    def _signature_params(self, params: cst.Parameters) -> list[AstNode]:
        """Return ``KIND_PARAM`` nodes for *params* in source declaration order.

        The order is: positional-only, regular, ``*args``, keyword-only,
        ``**kwargs``.
        """
        ordered = [*params.posonly_params, *params.params]
        # star_arg may be a bare `*` marker rather than a named Param; only a
        # real `*args` parameter is reported.
        if isinstance(params.star_arg, cst.Param):
            ordered.append(params.star_arg)
        ordered.extend(params.kwonly_params)
        if params.star_kwarg is not None:
            ordered.append(params.star_kwarg)
        return [self._param_node_to_ast(param) for param in ordered]

    def _param_node_to_ast(self, param_node: cst.Param) -> AstNode:
        """Build a ``KIND_PARAM`` node with name, and type/default when present."""
        param_ast = AstNode(kind=KIND_PARAM, src_range=self._src_range(param_node))
        if param_node.annotation is not None:
            param_ast.props[PROP_TYPE] = self.enclosing_module.code_for_node(
                param_node.annotation.annotation
            )
        param_ast.props[PROP_NAME] = param_node.name.value
        if param_node.default is not None:
            param_ast.props[PROP_DEFAULT] = self.enclosing_module.code_for_node(
                param_node.default
            )
        return param_ast
173
174
175
def collect_function_asts(code: str, *, copy_function_text: bool = True):
    """Parse *code* and return a ``KIND_FN`` ``AstNode`` per collected function.

    Args:
        code: Python source text.
        copy_function_text: when true (the default, as before), each function
            node also carries the function's full source text.

    Returns:
        The list of function nodes, in the order the visitor reached them.
    """
    wrapper = cst.metadata.MetadataWrapper(cst.parse_module(code))
    # By default MetadataWrapper visits its own deep copy of the module, so
    # render source text against that same tree.
    visitor = FunctionCollector(wrapper.module, copy_function_text=copy_function_text)
    wrapper.visit(visitor)
    return visitor.function_asts
181
182
183
def to_json(o):
    """Serialize *o* to a JSON string indented by two spaces.

    ``DataClassJsonEncoder`` handles ``AstNode`` objects (via ``as_dict``)
    and other dataclass instances.
    """
    encoder_options = {'cls': DataClassJsonEncoder, 'indent': 2}
    return json.dumps(o, **encoder_options)
185
186
187
def _main():
    """Print the collected function summaries of a Python file as JSON.

    Reads the file named by the first command-line argument. With no
    argument, it falls back to this script itself (``sys.argv[0]``).
    """
    path = sys.argv[1] if len(sys.argv) > 1 else sys.argv[0]
    print("Reading", path)
    # Python source is UTF-8 by default; do not depend on the platform locale.
    with open(path, 'r', encoding='utf-8') as file:
        code = file.read()
    for fn_ast in collect_function_asts(code):
        print(to_json(fn_ast))
193
194
195
# Run the command-line entry point only when executed as a script, not on import.
if "__main__" == __name__:
    _main()
197