From 012131e035a021ff07faecc4238661c2f2a42e0b Mon Sep 17 00:00:00 2001 From: Benjamin Loison Date: Fri, 16 Dec 2022 04:45:30 +0100 Subject: [PATCH] Solve C warnings and support renaming outputs of functions --- src/ast_to_c.ml | 37 ++++++++++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/src/ast_to_c.ml b/src/ast_to_c.ml index fb4384c..344a308 100644 --- a/src/ast_to_c.ml +++ b/src/ast_to_c.ml @@ -189,9 +189,21 @@ let pp_expression node_name = in pp_expression_aux +(* deterministic *) +let nodes_outputs = Hashtbl.create Config.maxvar;; + +let prepend_output_aux node_name name = + "output_" ^ node_name ^ "_" ^ name + +let prepend_output output node_name = + match output with + | BVar name -> BVar (prepend_output_aux node_name name) + | IVar name -> IVar (prepend_output_aux node_name name) + | RVar name -> RVar (prepend_output_aux node_name name) + let rec pp_equations node_name fmt: t_eqlist -> unit = function | [] -> () - | ((l_types, vars), (EApp (r_types, node, exprs))) :: eqs when l_types <> [] -> Format.fprintf fmt "%a" (pp_equations node_name) ((([], []), (EApp (r_types, node, exprs))) :: eqs) + | ((l_types, vars), (EApp (r_types, node, exprs))) :: eqs when l_types <> [] -> Format.fprintf fmt "%a" (pp_equations node_name) ((([], []), (EApp (r_types, node, exprs))) :: ((l_types, vars), ((*Hashtbl.find nodes_outputs node*) ETuple (fst node.n_outputs, List.map (fun output -> EVar (fst node.n_outputs, prepend_output output node.n_name)) (snd node.n_outputs)))) :: eqs) | (([], []), (ETuple ([], []))) :: eqs -> Format.fprintf fmt "%a" (pp_equations node_name) eqs | ((l_type :: l_types, var :: vars), (ETuple (r_type :: r_types, expr :: exprs))) :: eqs -> Format.fprintf fmt "%a" (pp_equations node_name) ((([l_type], [var]), expr) :: ((l_types, vars), (ETuple (r_types, exprs))) :: eqs) | (([], []), expr) :: eqs -> @@ -209,7 +221,12 @@ let pp_resvars reset_expressions_counter = (* use the fact that any boolean and any integer can be encoded as a float, concerning integers [-2^(23+1) + 1; 2^(23+1) + 1] are correctly encoded (cf https://stackoverflow.com/a/53254438) *) Format.sprintf "float tmp_reset[%i], init[%i];" reset_expressions_counter reset_expressions_counter -let pp_return fmt +let pp_return node_name fmt outputs = + if node_name = "main" then + (Format.fprintf fmt "return %a;" + (pp_varlist Base) outputs) + else + Format.fprintf fmt "%s" (String.concat "\n\t" (List.map (fun output -> match output with | BVar name | IVar name | RVar name -> "output_" ^ node_name ^ "_" ^ name ^ " = " ^ name ^ ";") (snd outputs))) (* TODO: manage general outputs *) let pp_node fmt node = @@ -217,12 +234,13 @@ let pp_node fmt node = - `init_{NODE_NAME}` - `tmp_reset_{int}` - `init_{int}` - - `pre_{NODE_MAIN}_{VARIABLE}` *) + - `pre_{NODE_NAME}_{VARIABLE}` + - `output_{NODE_NAME}_{VARIABLE}` *) reset_expressions_counter := 0; let _ = (pp_equations node.n_name) Format.str_formatter node.n_equations in reset_expressions_counter := 0; (* could remove the `return` for functions other than the `main` one *) - Format.fprintf fmt "bool init_%s = true;\n\n%a\n\n%a\n\n%a\n\n%s\n\n%s %s(%a)\n{\n\t%a\n\n%a\n\n\tinit_%s = false;\n\n%a\n\n%a\n\n%a\n\n\treturn %a;\n}\n" + Format.fprintf fmt "bool init_%s = true;\n\n%a\n\n%a\n\n%a\n\n%s\n\n%s %s(%a)\n{\n\t%a\n\n\t%a\n\n%a\n\n\tinit_%s = false;\n\n%a\n\n%a\n\n%a\n\n\t%a\n}\n" node.n_name (* could avoid declaring unused variables *) (pp_prevarlist node.n_name) node.n_inputs @@ -234,12 +252,13 @@ let pp_node fmt node = (* could avoid newlines if they aren't used to seperate statements *) (pp_varlist Arg) node.n_inputs (pp_varlist Dec) node.n_local_vars + (pp_varlist Dec) node.n_outputs (pp_equations node.n_name) node.n_equations node.n_name (pp_asnprevarlist node.n_name) node.n_inputs (pp_asnprevarlist node.n_name) node.n_local_vars (pp_asnprevarlist node.n_name) node.n_outputs - (Format.fprintf fmt "%a" (pp_varlist Base) node.n_outputs) + (pp_return node.n_name) node.n_outputs let rec pp_nodes fmt nodes = match nodes with @@ -247,24 +266,24 @@ let rec pp_nodes fmt nodes = | node :: nodes -> Format.fprintf fmt "%a\n%a" pp_node node pp_nodes nodes -let rec load_outputs_from_vars n_outputs = +let rec load_outputs_from_vars node_name n_outputs = match n_outputs with | [] -> () | BVar n_output :: n_outputs | IVar n_output :: n_outputs | RVar n_output :: n_outputs -> - (if (not (List.mem n_output !outputs)) then outputs := n_output :: !outputs;); load_outputs_from_vars n_outputs + (if (not (List.mem n_output !outputs)) then outputs := (node_name ^ "_" ^ n_output) :: !outputs;); load_outputs_from_vars node_name n_outputs let rec load_outputs_from_nodes nodes = match nodes with | [] -> () | node :: nodes -> - (load_outputs_from_vars (snd node.n_outputs)); load_outputs_from_nodes nodes + (if node.n_name <> "main" then (load_outputs_from_vars node.n_name (snd node.n_outputs)); Hashtbl.add nodes_outputs node.n_name (snd node.n_outputs)); load_outputs_from_nodes nodes let ast_to_c fmt prog = load_outputs_from_nodes prog; Format.fprintf fmt (* could verify that uses, possibly indirectly (cf `->` implementation), a boolean in the ast before including `` *) "#include \n\n%s\n\n%a" - ("float " ^ (String.concat ", " !outputs) ^ ";") pp_nodes prog + ("float " ^ (String.concat ", " (List.map (fun output -> "output_" ^ output) !outputs)) ^ ";") pp_nodes prog