transpile()   B
last analyzed

Complexity

Conditions 5

Size

Total Lines 27

Duplication

Lines 0
Ratio 0 %

Importance

Changes 5
Bugs 0 Features 0
Metric Value
cc 5
c 5
b 0
f 0
dl 0
loc 27
rs 8.0894
1
import ast
2
from clike import CLikeTranspiler
3
from scope import add_scope_context
4
from analysis import add_imports, is_void_function
5
from context import add_variable_context, add_list_calls
6
from tracer import decltype, is_list, is_builtin_import, defined_before
7
8
9
def transpile(source, headers=False, testing=False):
    """
    Transpile a single python translation unit (a python script) into
    C++ 14 code.

    :param source: Python source text to translate.
    :param headers: when true, prefix the output with the transpiler's
        ``#include`` lines and ``using`` declarations.
    :param testing: when true, prefix the output with
        ``#include "catch.hpp"`` and switch the transpiler into Catch
        test-case mode.
    :returns: the generated C++ source as one string.
    """
    tree = ast.parse(source)
    # The return values of these passes are ignored; they presumably
    # annotate ``tree`` in place (e.g. the ``scopes`` attribute the
    # visitors read) -- TODO confirm against the analysis modules.
    add_variable_context(tree)
    add_scope_context(tree)
    add_list_calls(tree)
    add_imports(tree)

    transpiler = CppTranspiler()

    buf = []
    if testing:
        buf.append('#include "catch.hpp"')
        transpiler.use_catch_test_cases = True

    if headers:
        # ``headers`` and ``usings`` are sets whose iteration order is
        # arbitrary (and hash-seed dependent for str on Python 3); sort
        # them so the generated file is identical from run to run.
        buf.extend(sorted(transpiler.headers))
        buf.extend(sorted(transpiler.usings))

    if testing or headers:
        buf.append('')  # Force empty line

    cpp = transpiler.visit(tree)
    return "\n".join(buf) + cpp
36
37
38
def generate_catch_test_case(node, body):
    """
    Wrap an already-translated function body in a Catch ``TEST_CASE``.

    :param node: function definition node; only ``node.name`` is used, as
        the test case's name.
    :param body: C++ statements forming the test body.
    :returns: ``TEST_CASE("<name>") {`` + newline + body + newline + ``}``.
    """
    template = 'TEST_CASE("{name}") {{\n{body}\n}}'
    return template.format(name=node.name, body=body)
41
42
43
def generate_template_fun(node, body):
    """
    Emit a function definition as a C++ function with one template type
    per parameter.

    The i-th parameter (1-based) is declared with type ``Ti`` and the
    function gets a ``template <typename T1, ...>`` prefix.  A function
    without parameters is emitted ``inline`` instead.  The return type is
    ``void`` when ``is_void_function(node)`` is true, ``auto`` otherwise.

    :param node: function definition node (uses ``node.name`` and
        ``node.args.args``).
    :param body: already-translated C++ body.
    :returns: the complete C++ function definition.
    """
    # Python 2 represents parameters as ``ast.Name`` (name in ``.id``);
    # Python 3 uses ``ast.arg`` (name in ``.arg``).  Accept either.
    names = [getattr(arg, "arg", None) or arg.id for arg in node.args.args]
    type_names = ["T" + str(idx) for idx in range(1, len(names) + 1)]

    if type_names:
        template = "template <{0}>\n".format(
            ", ".join("typename " + t for t in type_names))
    else:
        template = "inline "
    params = ["{0} {1}".format(t, n) for t, n in zip(type_names, names)]

    return_type = "void" if is_void_function(node) else "auto"

    funcdef = "{0}{1} {2}({3})".format(template, return_type, node.name,
                                      ", ".join(params))
    return funcdef + " {\n" + body + "\n}"
61
62
63
def generate_lambda_fun(node, body):
    """
    Emit a function definition as a C++14 generic lambda bound to an
    ``auto`` variable named after the function.

    Every parameter is declared ``auto``.

    :param node: function definition node (uses ``node.name`` and
        ``node.args.args``).
    :param body: already-translated C++ body.
    :returns: ``auto name = [](auto a, ...) {`` + body + ``};``.
    """
    # Python 2 represents parameters as ``ast.Name`` (name in ``.id``);
    # Python 3 uses ``ast.arg`` (name in ``.arg``).  Accept either.
    names = [getattr(param, "arg", None) or param.id
             for param in node.args.args]
    params = ["auto {0}".format(name) for name in names]
    funcdef = "auto {0} = []({1})".format(node.name, ", ".join(params))
    return funcdef + " {\n" + body + "\n};"
67
68
69
class CppTranspiler(CLikeTranspiler):
    """
    AST visitor that renders (Python 2 style) syntax trees as C++14 code.

    Several visitors read a ``scopes`` attribute on nodes, which is not set
    here -- presumably attached by the analysis passes run beforehand.
    ``visit_If`` stores ``common_vars`` on ``ast.If`` nodes, which
    ``visit_Assign`` later reads.
    """

    def __init__(self):
        # Lines a caller may emit ahead of the translation.
        self.headers = set(['#include "sys.h"', '#include "builtins.h"',
                            '#include <iostream>', '#include <string>',
                            '#include <algorithm>', '#include <cmath>',
                            '#include <vector>', '#include <tuple>',
                            '#include <utility>', '#include "range.hpp"'])
        self.usings = set([])
        # When true, void functions named ``test*`` become Catch TEST_CASEs.
        self.use_catch_test_cases = False

    def visit_FunctionDef(self, node):
        """Translate a function definition.

        In Catch mode, void functions whose name starts with ``test`` become
        ``TEST_CASE`` blocks; every other function becomes a template
        function.
        """
        body = "\n".join([self.visit(n) for n in node.body])

        if (self.use_catch_test_cases and
                is_void_function(node) and
                node.name.startswith("test")):
            return generate_catch_test_case(node, body)
        # NOTE: generate_lambda_fun is an alternative lowering that this
        # visitor never selects.
        return generate_template_fun(node, body)

    def visit_Attribute(self, node):
        """Translate ``name.attr``.

        Builtin-import modules map into the ``py14::`` namespace,
        ``math.asin/atan/acos`` map onto ``std::``, and ``append`` on a list
        becomes ``push_back``.  ``node.value`` must be an ``ast.Name``.
        """
        attr = node.attr
        owner = node.value.id
        if is_builtin_import(owner):
            return "py14::" + owner + "::" + attr
        if owner == "math":
            std_name = {"asin": "std::asin",
                        "atan": "std::atan",
                        "acos": "std::acos"}.get(attr)
            if std_name is not None:
                return std_name

        if is_list(node.value) and attr == "append":
            attr = "push_back"
        return owner + "." + attr

    def visit_Call(self, node):
        """Translate a call, mapping a few Python builtins onto C++ helpers."""
        fname = self.visit(node.func)
        arg_list = [self.visit(a) for a in node.args]
        args = ", ".join(arg_list)

        wrappers = {
            "int": "py14::to_int({0})",
            "str": "std::to_string({0})",
            "max": "std::max({0})",
            "range": "rangepp::range({0})",
            "xrange": "rangepp::xrange({0})",
        }
        if fname in wrappers:
            return wrappers[fname].format(args)
        if fname == "len":
            # len(x) -> x.size(); reuse the already translated argument.
            return "{0}.size()".format(arg_list[0])
        return '{0}({1})'.format(fname, args)

    def visit_For(self, node):
        """Translate ``for target in iter`` into a range-based for loop."""
        target = self.visit(node.target)
        it = self.visit(node.iter)
        buf = ['for(auto {0} : {1}) {{'.format(target, it)]
        buf.extend([self.visit(c) for c in node.body])
        buf.append("}")
        return "\n".join(buf)

    def visit_Expr(self, node):
        """Translate an expression statement, terminating it with ``;``."""
        s = self.visit(node.value)
        if s.strip() and not s.endswith(';'):
            s += ';'
        if s == ';':
            return ''
        else:
            return s

    def visit_Str(self, node):
        """Wrap the base-class string literal as ``std::string {...}`` so
        the expression has type std::string rather than a C string."""
        return ("std::string {" +
                super(CppTranspiler, self).visit_Str(node) + "}")

    def visit_Name(self, node):
        """Translate a name; Python ``None`` becomes ``nullptr``."""
        if node.id == 'None':
            return 'nullptr'
        else:
            return super(CppTranspiler, self).visit_Name(node)

    def visit_If(self, node):
        """Translate an ``if`` statement.

        Names assigned in both the body and the ``else`` branch are declared
        once before the ``if`` (type from ``decltype`` of their definition)
        and recorded in ``node.common_vars`` so ``visit_Assign`` emits plain
        assignments for them.  ``if __name__ == "__main__":`` becomes
        ``int main``.
        """
        scope = node.scopes[-1]
        body_vars = set([v.id for v in scope.body_vars])
        orelse_vars = set([v.id for v in scope.orelse_vars])
        node.common_vars = body_vars.intersection(orelse_vars)

        # Sorted so declarations appear in the same order on every run
        # (set iteration order is arbitrary).
        var_definitions = []
        for cv in sorted(node.common_vars):
            definition = node.scopes.find(cv)
            var_type = decltype(definition)
            var_definitions.append("{0} {1};\n".format(var_type, cv))

        if self.visit(node.test) == '__name__ == std::string {"__main__"}':
            buf = ["int main(int argc, char ** argv) {",
                   "py14::sys::argv = "
                   "std::vector<std::string>(argv, argv + argc);"]
            buf.extend([self.visit(child) for child in node.body])
            buf.append("}")
            return "\n".join(buf)
        return ("".join(var_definitions) +
                super(CppTranspiler, self).visit_If(node))

    def visit_BinOp(self, node):
        """Translate ``[x] * n`` into ``std::vector<T>(n, x)``; all other
        binary operations are delegated to the base class."""
        if (isinstance(node.left, ast.List)
                and len(node.left.elts) == 1
                and isinstance(node.op, ast.Mult)
                and isinstance(node.right, ast.Num)):
            element = node.left.elts[0]
            # Spell the element type out: deducing it from the constructor
            # arguments (``std::vector (n, x)``) requires C++17.
            return "std::vector<{0}>({1}, {2})".format(
                decltype(element), self.visit(node.right),
                self.visit(element))
        return super(CppTranspiler, self).visit_BinOp(node)

    def visit_Module(self, node):
        """Translate every top-level statement, one per line."""
        buf = [self.visit(b) for b in node.body]
        return "\n".join(buf)

    def visit_alias(self, node):
        """Map an imported module name onto an ``#include`` of its header."""
        return '#include "{0}.h"'.format(node.name)

    def visit_Import(self, node):
        """Translate ``import a, b`` into one ``#include`` line per module."""
        imports = [self.visit(n) for n in node.names]
        return "\n".join(filter(None, imports))

    def visit_List(self, node):
        """Translate a list literal into a brace-initialized std::vector
        whose element type is taken from the first element.

        :raises ValueError: for an empty list (no element to type it).
        """
        if node.elts:
            elements = [self.visit(e) for e in node.elts]
            value_type = decltype(node.elts[0])
            return "std::vector<{0}>{{{1}}}".format(value_type,
                                                    ", ".join(elements))
        raise ValueError("Cannot create vector without elements")

    def visit_Subscript(self, node):
        """Translate ``value[index]``; only simple indexing is supported.

        :raises NotImplementedError: for ellipsis or slice subscripts.
        """
        if isinstance(node.slice, ast.Ellipsis):
            raise NotImplementedError('Ellipsis not supported')

        if not isinstance(node.slice, ast.Index):
            raise NotImplementedError("Advanced Slicing not supported")

        value = self.visit(node.value)
        return "{0}[{1}]".format(value, self.visit(node.slice.value))

    def visit_Tuple(self, node):
        """Translate a tuple literal into ``std::make_tuple(...)``."""
        elts = [self.visit(e) for e in node.elts]
        return "std::make_tuple({0})".format(", ".join(elts))

    def visit_TryExcept(self, node, finallybody=None):
        """Translate ``try``/``except``.

        All Python handlers are emitted inside a single
        ``catch (const std::exception& e)`` clause, followed by fallback
        clauses that print an error message.
        """
        buf = ['try {']
        buf += [self.visit(n) for n in node.body]
        buf.append('} catch (const std::exception& e) {')

        buf += [self.visit(h) for h in node.handlers]

        if finallybody:
            # NOTE(review): this emits ``try { ... } throw e;`` -- a try
            # block with no handler, which is not valid C++.  The ``finally``
            # lowering needs a redesign.
            buf.append('try { // finally')
            buf += [self.visit(b) for b in finallybody]
            buf.append('} throw e;')

        buf.append('}')
        # NOTE(review): std::overflow_error and std::runtime_error derive
        # from std::exception, so the next two clauses are unreachable
        # behind the std::exception clause above.
        buf.append('catch (const std::overflow_error& e) '
                   '{ std::cout << "OVERFLOW ERROR" << std::endl; }')
        buf.append('catch (const std::runtime_error& e) '
                   '{ std::cout << "RUNTIME ERROR" << std::endl; }')
        buf.append('catch (...) '
                   '{ std::cout << "UNKNOWN ERROR" << std::endl; }')

        return '\n'.join(buf)

    def visit_Assert(self, node):
        """Translate ``assert cond`` into a Catch ``REQUIRE(cond);``."""
        return "REQUIRE({0});".format(self.visit(node.test))

    def visit_Assign(self, node):
        """Translate an assignment (only the first target is used).

        * tuple target            -> ``std::tie(a, b) = value;``
        * subscript/attribute     -> plain assignment
        * name already declared   -> plain assignment (an ``if``'s common
          variable, or a name ``defined_before`` this node)
        * new name = list literal -> ``<decltype> name {elements};``
        * any other new name      -> ``auto name = value;``
        """
        target = node.targets[0]

        if isinstance(target, ast.Tuple):
            elts = [self.visit(e) for e in target.elts]
            value = self.visit(node.value)
            return "std::tie({0}) = {1};".format(", ".join(elts), value)

        # Subscript/attribute targets have no ``.id`` and never introduce a
        # declaration, so handle them before the name-based checks below.
        if isinstance(target, (ast.Subscript, ast.Attribute)):
            target = self.visit(target)
            value = self.visit(node.value)
            return "{0} = {1};".format(target, value)

        if isinstance(node.scopes[-1], ast.If):
            outer_if = node.scopes[-1]
            if target.id in outer_if.common_vars:
                # Already declared ahead of the ``if`` by visit_If.
                value = self.visit(node.value)
                return "{0} = {1};".format(target.id, value)

        definition = node.scopes.find(target.id)
        if (isinstance(target, ast.Name) and
                defined_before(definition, node)):
            target = self.visit(target)
            value = self.visit(node.value)
            return "{0} = {1};".format(target, value)
        elif isinstance(node.value, ast.List):
            elements = [self.visit(e) for e in node.value.elts]
            return "{0} {1} {{{2}}};".format(decltype(node),
                                             self.visit(target),
                                             ", ".join(elements))
        else:
            target = self.visit(target)
            value = self.visit(node.value)
            return "auto {0} = {1};".format(target, value)

    def visit_Print(self, node):
        """Translate a Python 2 ``print`` statement.

        Each printed value goes to ``std::cout`` on its own line; list and
        tuple literals print their elements chained with ``<<``.
        """
        buf = []
        for n in node.values:
            value = self.visit(n)
            if isinstance(n, (ast.List, ast.Tuple)):
                buf.append("std::cout << {0} << std::endl;".format(
                           " << ".join([self.visit(el) for el in n.elts])))
            else:
                buf.append('std::cout << {0} << std::endl;'.format(value))
        return '\n'.join(buf)
296