Completed
Pull Request — master (#1)
by Valentin
02:24
created

CppTranspiler.visit_UnaryOp()   A

Complexity

Conditions 3

Size

Total Lines 9

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 3
dl 0
loc 9
rs 9.6666
1
import sys
2
import ast
3
from .clike import CLikeTranspiler
4
from .scope import add_scope_context
5
from .context import add_variable_context, add_list_calls
6
from .analysis import add_imports, is_void_function, get_id
7
from .tracer import decltype, is_list, is_builtin_import, defined_before
8
9
10
def transpile(source, headers=False, testing=False):
    """
    Transpile a single python translation unit (a python script) into
    C++ 14 code.

    :param source: Python source text of the translation unit.
    :param headers: when true, prepend the transpiler's ``#include`` lines
        and ``using`` declarations to the generated code.
    :param testing: when true, prepend ``#include "catch.hpp"`` and switch
        the transpiler into Catch test-case mode.
    :return: the generated C++ source as one string.
    """
    tree = ast.parse(source)
    # Pre-processing passes from sibling modules. Their return values are
    # not used, so they presumably annotate ``tree`` in place.
    add_variable_context(tree)
    add_scope_context(tree)
    add_list_calls(tree)
    add_imports(tree)

    transpiler = CppTranspiler()

    buf = []
    if testing:
        buf += ['#include "catch.hpp"']
        transpiler.use_catch_test_cases = True

    if headers:
        # ``headers`` and ``usings`` are sets, whose iteration order is not
        # stable between interpreter runs; sort them so the same input
        # always produces the same prelude.
        buf += sorted(transpiler.headers)
        buf += sorted(transpiler.usings)

    if testing or headers:
        buf.append('')  # Force empty line

    cpp = transpiler.visit(tree)
    return "\n".join(buf) + cpp
37
38
39
def generate_catch_test_case(node, body):
40
    funcdef = 'TEST_CASE("{0}")'.format(node.name)
41
    return funcdef + " {\n" + body + "\n}"
42
43
44
def generate_template_fun(node, body):
    """
    Render the function definition ``node`` as a C++ function around the
    already transpiled ``body``.

    Each parameter gets its own template type ``T1``, ``T2``, ... A
    function without parameters is emitted as a plain ``inline`` function
    instead of a template. The return type is ``void`` when
    ``is_void_function`` says so and ``auto`` otherwise.
    """
    typed_args = [("T" + str(position), get_id(arg))
                  for position, arg in enumerate(node.args.args, 1)]

    if typed_args:
        type_list = ", ".join("typename " + type_name
                              for type_name, _ in typed_args)
        prefix = "template <{0}>\n".format(type_list)
    else:
        prefix = "inline "

    signature = ", ".join("{0} {1}".format(type_name, arg_name)
                          for type_name, arg_name in typed_args)
    return_type = "void" if is_void_function(node) else "auto"

    header = "{0}{1} {2}({3})".format(prefix, return_type, node.name,
                                      signature)
    return header + " {\n" + body + "\n}"
62
63
64
def generate_lambda_fun(node, body):
    """
    Render the function definition ``node`` as a generic C++ 14 lambda
    that is bound to a variable carrying the function's name.

    :param node: function definition; ``node.name`` and ``node.args.args``
        are read.
    :param body: already transpiled C++ statements of the function body.
    :return: the lambda definition, terminated with ``};``.
    """
    params = []
    for param in node.args.args:
        # Python 3 parses parameters as ``ast.arg`` (name stored in
        # ``.arg``); Python 2 used ``ast.Name`` (name stored in ``.id``).
        name = param.arg if hasattr(param, "arg") else param.id
        params.append("auto {0}".format(name))
    funcdef = "auto {0} = []({1})".format(node.name, ", ".join(params))
    return funcdef + " {\n" + body + "\n};"
68
69
70
class CppTranspiler(CLikeTranspiler):
71
    def __init__(self):
72
        self.headers = set(['#include "sys.h"', '#include "builtins.h"',
73
                            '#include <iostream>', '#include <string>',
74
                            '#include <algorithm>', '#include <cmath>',
75
                            '#include <vector>', '#include <tuple>',
76
                            '#include <utility>', '#include "range.hpp"'])
77
        self.usings = set([])
78
        self.use_catch_test_cases = False
79
80
    def visit_FunctionDef(self, node):
        """
        Transpile a function definition.

        In Catch test mode a void function whose name starts with ``test``
        becomes a ``TEST_CASE``. Every other function is emitted through
        ``generate_template_fun``; ``generate_lambda_fun`` is not used
        here.
        """
        statements = [self.visit(stmt) for stmt in node.body]
        body = "\n".join(statements)

        is_test_case = (self.use_catch_test_cases and
                        is_void_function(node) and
                        node.name.startswith("test"))
        if not is_test_case:
            return generate_template_fun(node, body)
        return generate_catch_test_case(node, body)
91
92
    def visit_Attribute(self, node):
        """
        Transpile an attribute access ``value.attr``.

        Attributes of builtin imports are placed in the ``py14::``
        namespace, ``math.asin``/``atan``/``acos`` map to their ``std::``
        counterparts, and ``append`` on a list becomes ``push_back``.
        Anything else is emitted as ``value.attr``.
        """
        owner = get_id(node.value)
        if is_builtin_import(owner):
            return "py14::" + owner + "::" + node.attr

        if owner == "math":
            std_functions = {"asin": "std::asin",
                             "atan": "std::atan",
                             "acos": "std::acos"}
            if node.attr in std_functions:
                return std_functions[node.attr]

        member = node.attr
        if is_list(node.value) and member == "append":
            member = "push_back"
        return owner + "." + member
109
110
    def visit_Call(self, node):
111
        fname = self.visit(node.func)
112
        if node.args:
113
            args = [self.visit(a) for a in node.args]
114
            args = ", ".join(args)
115
        else:
116
            args = ''
117
118
        if fname == "int":
119
            return "py14::to_int({0})".format(args)
120
        elif fname == "str":
121
            return "std::to_string({0})".format(args)
122
        elif fname == "max":
123
            return "std::max({0})".format(args)
124
        elif fname == "range":
125
            if sys.version_info[0] >= 3:
126
                return "rangepp::xrange({0})".format(args)
127
            else:
128
                return "rangepp::range({0})".format(args)
129
        elif fname == "xrange":
130
            return "rangepp::xrange({0})".format(args)
131
        elif fname == "len":
132
            return "{0}.size()".format(self.visit(node.args[0]))
133
        elif fname == "print":
134
            buf = []
135
            for n in node.args:
136
                value = self.visit(n)
137
                if isinstance(n, ast.List) or isinstance(n, ast.Tuple):
138
                    buf.append("std::cout << {0} << std::endl;".format(
139
                               " << ".join([self.visit(el) for el in n.elts])))
140
                else:
141
                    buf.append('std::cout << {0} << std::endl;'.format(value))
142
            return '\n'.join(buf)
143
144
        return '{0}({1})'.format(fname, args)
145
146
    def visit_For(self, node):
147
        target = self.visit(node.target)
148
        it = self.visit(node.iter)
149
        buf = []
150
        buf.append('for(auto {0} : {1}) {{'.format(target, it))
151
        buf.extend([self.visit(c) for c in node.body])
152
        buf.append("}")
153
        return "\n".join(buf)
154
155
    def visit_Expr(self, node):
156
        s = self.visit(node.value)
157
        if s.strip() and not s.endswith(';'):
158
            s += ';'
159
        if s == ';':
160
            return ''
161
        else:
162
            return s
163
164
    def visit_Str(self, node):
        """
        Wrap the literal produced by the base transpiler in a
        ``std::string {...}`` initializer instead of emitting it raw.
        """
        literal = super(CppTranspiler, self).visit_Str(node)
        return "std::string {" + literal + "}"
168
169
    def visit_Name(self, node):
        """Emit ``nullptr`` for the name ``None``; defer to the base class otherwise."""
        if node.id != 'None':
            return super(CppTranspiler, self).visit_Name(node)
        return 'nullptr'
174
175
    def visit_NameConstant(self, node):
        """
        Emit the C++ spelling of ``True`` and ``False``; every other
        constant is left to the base class.
        """
        for constant, spelling in ((True, "true"), (False, "false")):
            if node.value is constant:
                return spelling
        return super(CppTranspiler, self).visit_NameConstant(node)
182
183
    def visit_If(self, node):
        """
        Transpile an ``if`` statement.

        Ids that occur in both ``body_vars`` and ``orelse_vars`` of the
        innermost scope are stored on ``node.common_vars`` (read again by
        ``visit_Assign``) and declared in front of the statement. An
        ``if __name__ == "__main__"`` guard is emitted as the C++ ``main``
        function instead.
        """
        scope = node.scopes[-1]
        in_body = {get_id(v) for v in scope.body_vars}
        in_orelse = {get_id(v) for v in scope.orelse_vars}
        node.common_vars = in_body & in_orelse

        declarations = [
            "{0} {1};\n".format(decltype(node.scopes.find(name)), name)
            for name in node.common_vars
        ]

        condition = self.visit(node.test)
        if condition != '__name__ == std::string {"__main__"}':
            return ("".join(declarations) +
                    super(CppTranspiler, self).visit_If(node))

        lines = ["int main(int argc, char ** argv) {",
                 "py14::sys::argv = "
                 "std::vector<std::string>(argv, argv + argc);"]
        lines += [self.visit(child) for child in node.body]
        lines.append("}")
        return "\n".join(lines)
204
205
    def visit_UnaryOp(self, node):
206
        if isinstance(node.op, ast.USub):
207
            if isinstance(node.operand, (ast.Call, ast.Num)):
208
                # Shortcut if parenthesis are not needed
209
                return "-{0}".format(self.visit(node.operand))
210
            else:
211
                return "-({0})".format(self.visit(node.operand))
212
        else:
213
            return super(CppTranspiler, self).visit_UnaryOp(node)
214
215
    def visit_BinOp(self, node):
216
        if (isinstance(node.left, ast.List)
217
                and isinstance(node.op, ast.Mult)
218
                and isinstance(node.right, ast.Num)):
219
            return "std::vector ({0},{1})".format(self.visit(node.right),
220
                                                  self.visit(node.left.elts[0]))
221
        else:
222
            return super(CppTranspiler, self).visit_BinOp(node)
223
224
    def visit_Module(self, node):
225
        buf = [self.visit(b) for b in node.body]
226
        return "\n".join(buf)
227
228
    def visit_alias(self, node):
        """Emit an ``#include`` of the header named after the imported module."""
        return '#include "{module}.h"'.format(module=node.name)
230
231
    def visit_Import(self, node):
232
        imports = [self.visit(n) for n in node.names]
233
        return "\n".join(i for i in imports if i)
234
235
    def visit_List(self, node):
        """
        Transpile a list literal into a ``std::vector`` whose element type
        is taken from the first element.

        :raises ValueError: for an empty literal, which has no first
            element to take the type from.
        """
        if not len(node.elts) > 0:
            raise ValueError("Cannot create vector without elements")

        rendered = ", ".join([self.visit(e) for e in node.elts])
        value_type = decltype(node.elts[0])
        return "std::vector<{0}>{{{1}}}".format(value_type, rendered)
244
245
    def visit_Subscript(self, node):
246
        if isinstance(node.slice, ast.Ellipsis):
247
            raise NotImplementedError('Ellipsis not supported')
248
249
        if not isinstance(node.slice, ast.Index):
250
            raise NotImplementedError("Advanced Slicing not supported")
251
252
        value = self.visit(node.value)
253
        return "{0}[{1}]".format(value, self.visit(node.slice.value))
254
255
    def visit_Tuple(self, node):
256
        elts = [self.visit(e) for e in node.elts]
257
        return "std::make_tuple({0})".format(", ".join(elts))
258
259
    def visit_TryExcept(self, node, finallybody=None):
260
        buf = ['try {']
261
        buf += [self.visit(n) for n in node.body]
262
        buf.append('} catch (const std::exception& e) {')
263
264
        buf += [self.visit(h) for h in node.handlers]
265
266
        if finallybody:
267
            buf.append('try { // finally')
268
            buf += [self.visit(b) for b in finallybody]
269
            buf.append('} throw e;')
270
271
        buf.append('}')
272
        buf.append('catch (const std::overflow_error& e) '
273
                   '{ std::cout << "OVERFLOW ERROR" << std::endl; }')
274
        buf.append('catch (const std::runtime_error& e) '
275
                   '{ std::cout << "RUNTIME ERROR" << std::endl; }')
276
        buf.append('catch (...) '
277
                   '{ std::cout << "UNKNOWN ERROR" << std::endl; 0}')
278
279
        return '\n'.join(buf)
280
281
    def visit_Assert(self, node):
282
        return "REQUIRE({0});".format(self.visit(node.test))
283
284
    def visit_Assign(self, node):
285
        target = node.targets[0]
286
287
        if isinstance(target, ast.Tuple):
288
            elts = [self.visit(e) for e in target.elts]
289
            value = self.visit(node.value)
290
            return "std::tie({0}) = {1};".format(", ".join(elts), value)
291
292
        if isinstance(node.scopes[-1], ast.If):
293
            outer_if = node.scopes[-1]
294
            if target.id in outer_if.common_vars:
295
                value = self.visit(node.value)
296
                return "{0} = {1};".format(target.id, value)
297
298
        if isinstance(target, ast.Subscript):
299
            target = self.visit(target)
300
            value = self.visit(node.value)
301
            return "{0} = {1};".format(target, value)
302
303
        definition = node.scopes.find(target.id)
304
        if (isinstance(target, ast.Name) and
305
              defined_before(definition, node)):
306
            target = self.visit(target)
307
            value = self.visit(node.value)
308
            return "{0} = {1};".format(target, value)
309
        elif isinstance(node.value, ast.List):
310
            elements = [self.visit(e) for e in node.value.elts]
311
            return "{0} {1} {{{2}}};".format(decltype(node),
312
                                             self.visit(target),
313
                                             ", ".join(elements))
314
        else:
315
            target = self.visit(target)
316
            value = self.visit(node.value)
317
            return "auto {0} = {1};".format(target, value)
318
319
    def visit_Print(self, node):
320
        buf = []
321
        for n in node.values:
322
            value = self.visit(n)
323
            if isinstance(n, ast.List) or isinstance(n, ast.Tuple):
324
                buf.append("std::cout << {0} << std::endl;".format(
325
                           " << ".join([self.visit(el) for el in n.elts])))
326
            else:
327
                buf.append('std::cout << {0} << std::endl;'.format(value))
328
        return '\n'.join(buf)
329