# ------------------------------------------------------------------------------ # pycparser: c_generator.py # # C code generator from pycparser AST nodes. # # Eli Bendersky [https://eli.thegreenplace.net/] # License: BSD # ------------------------------------------------------------------------------ from typing import Callable, List, Optional from . import c_ast class CGenerator: """Uses the same visitor pattern as c_ast.NodeVisitor, but modified to return a value from each visit method, using string accumulation in generic_visit. """ indent_level: int reduce_parentheses: bool def __init__(self, reduce_parentheses: bool = False) -> None: """Constructs C-code generator reduce_parentheses: if True, eliminates needless parentheses on binary operators """ # Statements start with indentation of self.indent_level spaces, using # the _make_indent method. self.indent_level = 0 self.reduce_parentheses = reduce_parentheses def _make_indent(self) -> str: return " " * self.indent_level def visit(self, node: c_ast.Node) -> str: method = "visit_" + node.__class__.__name__ return getattr(self, method, self.generic_visit)(node) def generic_visit(self, node: Optional[c_ast.Node]) -> str: if node is None: return "" else: return "".join(self.visit(c) for c_name, c in node.children()) def visit_Constant(self, n: c_ast.Constant) -> str: return n.value def visit_ID(self, n: c_ast.ID) -> str: return n.name def visit_Pragma(self, n: c_ast.Pragma) -> str: ret = "#pragma" if n.string: ret += " " + n.string return ret def visit_ArrayRef(self, n: c_ast.ArrayRef) -> str: arrref = self._parenthesize_unless_simple(n.name) return arrref + "[" + self.visit(n.subscript) + "]" def visit_StructRef(self, n: c_ast.StructRef) -> str: sref = self._parenthesize_unless_simple(n.name) return sref + n.type + self.visit(n.field) def visit_FuncCall(self, n: c_ast.FuncCall) -> str: fref = self._parenthesize_unless_simple(n.name) args = self.visit(n.args) if n.args is not None else "" return fref + "(" + args + ")" def visit_UnaryOp(self, n: c_ast.UnaryOp) -> str: match n.op: case "sizeof": # Always parenthesize the argument of sizeof since it can be # a name. return f"sizeof({self.visit(n.expr)})" case "p++": operand = self._parenthesize_unless_simple(n.expr) return f"{operand}++" case "p--": operand = self._parenthesize_unless_simple(n.expr) return f"{operand}--" case _: operand = self._parenthesize_unless_simple(n.expr) return f"{n.op}{operand}" # Precedence map of binary operators: precedence_map = { # Should be in sync with c_parser.CParser.precedence # Higher numbers are stronger binding "||": 0, # weakest binding "&&": 1, "|": 2, "^": 3, "&": 4, "==": 5, "!=": 5, ">": 6, ">=": 6, "<": 6, "<=": 6, ">>": 7, "<<": 7, "+": 8, "-": 8, "*": 9, "/": 9, "%": 9, # strongest binding } def visit_BinaryOp(self, n: c_ast.BinaryOp) -> str: # Note: all binary operators are left-to-right associative # # If `n.left.op` has a stronger or equally binding precedence in # comparison to `n.op`, no parenthesis are needed for the left: # e.g., `(a*b) + c` is equivalent to `a*b + c`, as well as # `(a+b) - c` is equivalent to `a+b - c` (same precedence). # If the left operator is weaker binding than the current, then # parentheses are necessary: # e.g., `(a+b) * c` is NOT equivalent to `a+b * c`. lval_str = self._parenthesize_if( n.left, lambda d: not ( self._is_simple_node(d) or self.reduce_parentheses and isinstance(d, c_ast.BinaryOp) and self.precedence_map[d.op] >= self.precedence_map[n.op] ), ) # If `n.right.op` has a stronger -but not equal- binding precedence, # parenthesis can be omitted on the right: # e.g., `a + (b*c)` is equivalent to `a + b*c`. # If the right operator is weaker or equally binding, then parentheses # are necessary: # e.g., `a * (b+c)` is NOT equivalent to `a * b+c` and # `a - (b+c)` is NOT equivalent to `a - b+c` (same precedence). rval_str = self._parenthesize_if( n.right, lambda d: not ( self._is_simple_node(d) or self.reduce_parentheses and isinstance(d, c_ast.BinaryOp) and self.precedence_map[d.op] > self.precedence_map[n.op] ), ) return f"{lval_str} {n.op} {rval_str}" def visit_Assignment(self, n: c_ast.Assignment) -> str: rval_str = self._parenthesize_if( n.rvalue, lambda n: isinstance(n, c_ast.Assignment) ) return f"{self.visit(n.lvalue)} {n.op} {rval_str}" def visit_IdentifierType(self, n: c_ast.IdentifierType) -> str: return " ".join(n.names) def _visit_expr(self, n: c_ast.Node) -> str: match n: case c_ast.InitList(): return "{" + self.visit(n) + "}" case c_ast.ExprList() | c_ast.Compound(): return "(" + self.visit(n) + ")" case _: return self.visit(n) def visit_Decl(self, n: c_ast.Decl, no_type: bool = False) -> str: # no_type is used when a Decl is part of a DeclList, where the type is # explicitly only for the first declaration in a list. # s = n.name if no_type else self._generate_decl(n) if n.bitsize: s += " : " + self.visit(n.bitsize) if n.init: s += " = " + self._visit_expr(n.init) return s def visit_DeclList(self, n: c_ast.DeclList) -> str: s = self.visit(n.decls[0]) if len(n.decls) > 1: s += ", " + ", ".join( self.visit_Decl(decl, no_type=True) for decl in n.decls[1:] ) return s def visit_Typedef(self, n: c_ast.Typedef) -> str: s = "" if n.storage: s += " ".join(n.storage) + " " s += self._generate_type(n.type) return s def visit_Cast(self, n: c_ast.Cast) -> str: s = "(" + self._generate_type(n.to_type, emit_declname=False) + ")" return s + " " + self._parenthesize_unless_simple(n.expr) def visit_ExprList(self, n: c_ast.ExprList) -> str: visited_subexprs = [] for expr in n.exprs: visited_subexprs.append(self._visit_expr(expr)) return ", ".join(visited_subexprs) def visit_InitList(self, n: c_ast.InitList) -> str: visited_subexprs = [] for expr in n.exprs: visited_subexprs.append(self._visit_expr(expr)) return ", ".join(visited_subexprs) def visit_Enum(self, n: c_ast.Enum) -> str: return self._generate_struct_union_enum(n, name="enum") def visit_Alignas(self, n: c_ast.Alignas) -> str: return "_Alignas({})".format(self.visit(n.alignment)) def visit_Enumerator(self, n: c_ast.Enumerator) -> str: if not n.value: return "{indent}{name},\n".format( indent=self._make_indent(), name=n.name, ) else: return "{indent}{name} = {value},\n".format( indent=self._make_indent(), name=n.name, value=self.visit(n.value), ) def visit_FuncDef(self, n: c_ast.FuncDef) -> str: decl = self.visit(n.decl) self.indent_level = 0 body = self.visit(n.body) if n.param_decls: knrdecls = ";\n".join(self.visit(p) for p in n.param_decls) return decl + "\n" + knrdecls + ";\n" + body + "\n" else: return decl + "\n" + body + "\n" def visit_FileAST(self, n: c_ast.FileAST) -> str: s = "" for ext in n.ext: match ext: case c_ast.FuncDef(): s += self.visit(ext) case c_ast.Pragma(): s += self.visit(ext) + "\n" case _: s += self.visit(ext) + ";\n" return s def visit_Compound(self, n: c_ast.Compound) -> str: s = self._make_indent() + "{\n" self.indent_level += 2 if n.block_items: s += "".join(self._generate_stmt(stmt) for stmt in n.block_items) self.indent_level -= 2 s += self._make_indent() + "}\n" return s def visit_CompoundLiteral(self, n: c_ast.CompoundLiteral) -> str: return "(" + self.visit(n.type) + "){" + self.visit(n.init) + "}" def visit_EmptyStatement(self, n: c_ast.EmptyStatement) -> str: return ";" def visit_ParamList(self, n: c_ast.ParamList) -> str: return ", ".join(self.visit(param) for param in n.params) def visit_Return(self, n: c_ast.Return) -> str: s = "return" if n.expr: s += " " + self.visit(n.expr) return s + ";" def visit_Break(self, n: c_ast.Break) -> str: return "break;" def visit_Continue(self, n: c_ast.Continue) -> str: return "continue;" def visit_TernaryOp(self, n: c_ast.TernaryOp) -> str: s = "(" + self._visit_expr(n.cond) + ") ? " s += "(" + self._visit_expr(n.iftrue) + ") : " s += "(" + self._visit_expr(n.iffalse) + ")" return s def visit_If(self, n: c_ast.If) -> str: s = "if (" if n.cond: s += self.visit(n.cond) s += ")\n" s += self._generate_stmt(n.iftrue, add_indent=True) if n.iffalse: s += self._make_indent() + "else\n" s += self._generate_stmt(n.iffalse, add_indent=True) return s def visit_For(self, n: c_ast.For) -> str: s = "for (" if n.init: s += self.visit(n.init) s += ";" if n.cond: s += " " + self.visit(n.cond) s += ";" if n.next: s += " " + self.visit(n.next) s += ")\n" s += self._generate_stmt(n.stmt, add_indent=True) return s def visit_While(self, n: c_ast.While) -> str: s = "while (" if n.cond: s += self.visit(n.cond) s += ")\n" s += self._generate_stmt(n.stmt, add_indent=True) return s def visit_DoWhile(self, n: c_ast.DoWhile) -> str: s = "do\n" s += self._generate_stmt(n.stmt, add_indent=True) s += self._make_indent() + "while (" if n.cond: s += self.visit(n.cond) s += ");" return s def visit_StaticAssert(self, n: c_ast.StaticAssert) -> str: s = "_Static_assert(" s += self.visit(n.cond) if n.message: s += "," s += self.visit(n.message) s += ")" return s def visit_Switch(self, n: c_ast.Switch) -> str: s = "switch (" + self.visit(n.cond) + ")\n" s += self._generate_stmt(n.stmt, add_indent=True) return s def visit_Case(self, n: c_ast.Case) -> str: s = "case " + self.visit(n.expr) + ":\n" for stmt in n.stmts: s += self._generate_stmt(stmt, add_indent=True) return s def visit_Default(self, n: c_ast.Default) -> str: s = "default:\n" for stmt in n.stmts: s += self._generate_stmt(stmt, add_indent=True) return s def visit_Label(self, n: c_ast.Label) -> str: return n.name + ":\n" + self._generate_stmt(n.stmt) def visit_Goto(self, n: c_ast.Goto) -> str: return "goto " + n.name + ";" def visit_EllipsisParam(self, n: c_ast.EllipsisParam) -> str: return "..." def visit_Struct(self, n: c_ast.Struct) -> str: return self._generate_struct_union_enum(n, "struct") def visit_Typename(self, n: c_ast.Typename) -> str: return self._generate_type(n.type) def visit_Union(self, n: c_ast.Union) -> str: return self._generate_struct_union_enum(n, "union") def visit_NamedInitializer(self, n: c_ast.NamedInitializer) -> str: s = "" for name in n.name: if isinstance(name, c_ast.ID): s += "." + name.name else: s += "[" + self.visit(name) + "]" s += " = " + self._visit_expr(n.expr) return s def visit_FuncDecl(self, n: c_ast.FuncDecl) -> str: return self._generate_type(n) def visit_ArrayDecl(self, n: c_ast.ArrayDecl) -> str: return self._generate_type(n, emit_declname=False) def visit_TypeDecl(self, n: c_ast.TypeDecl) -> str: return self._generate_type(n, emit_declname=False) def visit_PtrDecl(self, n: c_ast.PtrDecl) -> str: return self._generate_type(n, emit_declname=False) def _generate_struct_union_enum( self, n: c_ast.Struct | c_ast.Union | c_ast.Enum, name: str ) -> str: """Generates code for structs, unions, and enums. name should be 'struct', 'union', or 'enum'. """ if name in ("struct", "union"): assert isinstance(n, (c_ast.Struct, c_ast.Union)) members = n.decls body_function = self._generate_struct_union_body else: assert name == "enum" assert isinstance(n, c_ast.Enum) members = None if n.values is None else n.values.enumerators body_function = self._generate_enum_body s = name + " " + (n.name or "") if members is not None: # None means no members # Empty sequence means an empty list of members s += "\n" s += self._make_indent() self.indent_level += 2 s += "{\n" s += body_function(members) self.indent_level -= 2 s += self._make_indent() + "}" return s def _generate_struct_union_body(self, members: List[c_ast.Node]) -> str: return "".join(self._generate_stmt(decl) for decl in members) def _generate_enum_body(self, members: List[c_ast.Enumerator]) -> str: # `[:-2] + '\n'` removes the final `,` from the enumerator list return "".join(self.visit(value) for value in members)[:-2] + "\n" def _generate_stmt(self, n: c_ast.Node, add_indent: bool = False) -> str: """Generation from a statement node. This method exists as a wrapper for individual visit_* methods to handle different treatment of some statements in this context. """ if add_indent: self.indent_level += 2 indent = self._make_indent() if add_indent: self.indent_level -= 2 match n: case ( c_ast.Decl() | c_ast.Assignment() | c_ast.Cast() | c_ast.UnaryOp() | c_ast.BinaryOp() | c_ast.TernaryOp() | c_ast.FuncCall() | c_ast.ArrayRef() | c_ast.StructRef() | c_ast.Constant() | c_ast.ID() | c_ast.Typedef() | c_ast.ExprList() ): # These can also appear in an expression context so no semicolon # is added to them automatically # return indent + self.visit(n) + ";\n" case c_ast.Compound(): # No extra indentation required before the opening brace of a # compound - because it consists of multiple lines it has to # compute its own indentation. # return self.visit(n) case c_ast.If(): return indent + self.visit(n) case _: return indent + self.visit(n) + "\n" def _generate_decl(self, n: c_ast.Decl) -> str: """Generation from a Decl node.""" s = "" if n.funcspec: s = " ".join(n.funcspec) + " " if n.storage: s += " ".join(n.storage) + " " if n.align: s += self.visit(n.align[0]) + " " s += self._generate_type(n.type) return s def _generate_type( self, n: c_ast.Node, modifiers: List[c_ast.Node] = [], emit_declname: bool = True, ) -> str: """Recursive generation from a type node. n is the type node. modifiers collects the PtrDecl, ArrayDecl and FuncDecl modifiers encountered on the way down to a TypeDecl, to allow proper generation from it. """ # ~ print(n, modifiers) match n: case c_ast.TypeDecl(): s = "" if n.quals: s += " ".join(n.quals) + " " s += self.visit(n.type) nstr = n.declname if n.declname and emit_declname else "" # Resolve modifiers. # Wrap in parens to distinguish pointer to array and pointer to # function syntax. # for i, modifier in enumerate(modifiers): match modifier: case c_ast.ArrayDecl(): if i != 0 and isinstance(modifiers[i - 1], c_ast.PtrDecl): nstr = "(" + nstr + ")" nstr += "[" if modifier.dim_quals: nstr += " ".join(modifier.dim_quals) + " " if modifier.dim is not None: nstr += self.visit(modifier.dim) nstr += "]" case c_ast.FuncDecl(): if i != 0 and isinstance(modifiers[i - 1], c_ast.PtrDecl): nstr = "(" + nstr + ")" args = ( self.visit(modifier.args) if modifier.args is not None else "" ) nstr += "(" + args + ")" case c_ast.PtrDecl(): if modifier.quals: quals = " ".join(modifier.quals) suffix = f" {nstr}" if nstr else "" nstr = f"* {quals}{suffix}" else: nstr = "*" + nstr if nstr: s += " " + nstr return s case c_ast.Decl(): return self._generate_decl(n.type) case c_ast.Typename(): return self._generate_type(n.type, emit_declname=emit_declname) case c_ast.IdentifierType(): return " ".join(n.names) + " " case c_ast.ArrayDecl() | c_ast.PtrDecl() | c_ast.FuncDecl(): return self._generate_type( n.type, modifiers + [n], emit_declname=emit_declname ) case _: return self.visit(n) def _parenthesize_if( self, n: c_ast.Node, condition: Callable[[c_ast.Node], bool] ) -> str: """Visits 'n' and returns its string representation, parenthesized if the condition function applied to the node returns True. """ s = self._visit_expr(n) if condition(n): return "(" + s + ")" else: return s def _parenthesize_unless_simple(self, n: c_ast.Node) -> str: """Common use case for _parenthesize_if""" return self._parenthesize_if(n, lambda d: not self._is_simple_node(d)) def _is_simple_node(self, n: c_ast.Node) -> bool: """Returns True for nodes that are "simple" - i.e. nodes that always have higher precedence than operators. """ return isinstance( n, (c_ast.Constant, c_ast.ID, c_ast.ArrayRef, c_ast.StructRef, c_ast.FuncCall), )