diff --git a/src/ast_to_c.ml b/src/ast_to_c.ml index 0390814..77017a9 100644 --- a/src/ast_to_c.ml +++ b/src/ast_to_c.ml @@ -84,6 +84,8 @@ let rec pp_asnprevarlist node_name fmt : t_varlist -> unit = function let reset_expressions_counter = ref 0;; +let outputs = ref [];; + let pp_expression node_name = let rec pp_expression_aux fmt expression = let rec pp_expression_list fmt exprs = @@ -189,8 +191,13 @@ let pp_expression node_name = let rec pp_equations node_name fmt: t_eqlist -> unit = function | [] -> () + | ((l_types, vars), (EApp (r_types, node, exprs))) :: eqs when l_types <> [] -> Format.fprintf fmt "%a" (pp_equations node_name) ((([], []), (EApp (r_types, node, exprs))) :: eqs) | (([], []), (ETuple ([], []))) :: eqs -> Format.fprintf fmt "%a" (pp_equations node_name) eqs | ((l_type :: l_types, var :: vars), (ETuple (r_type :: r_types, expr :: exprs))) :: eqs -> Format.fprintf fmt "%a" (pp_equations node_name) ((([l_type], [var]), expr) :: ((l_types, vars), (ETuple (r_types, exprs))) :: eqs) + | (([], []), expr) :: eqs -> + Format.fprintf fmt "\t%a;\n%a" + (pp_expression node_name) expr + (pp_equations node_name) eqs | (patt, expr) :: eqs -> Format.fprintf fmt "\t%a = %a;\n%a" (pp_varlist Base) patt @@ -212,19 +219,19 @@ let pp_node fmt node = reset_expressions_counter := 0; let _ = (pp_equations node.n_name) Format.str_formatter node.n_equations in reset_expressions_counter := 0; - Format.fprintf fmt "bool init_%s = true;\n\n%a\n\n%a\n\n%a\n\n%s\n\n%a %s(%a)\n{\n\t%a\n\n\t%a\n\n%a\n\tinit_%s = false;\n\n%a\n\n%a\n\n%a\n\n\treturn %a;\n}\n" + (* could remove the `return` for functions other than the `main` one *) + Format.fprintf fmt "bool init_%s = true;\n\n%a\n\n%a\n\n%a\n\n%s\n\n%s %s(%a)\n{\n\t%a\n\n%a\n\n\tinit_%s = false;\n\n%a\n\n%a\n\n%a\n\n\treturn %a;\n}\n" node.n_name (* could avoid declaring unused variables *) (pp_prevarlist node.n_name) node.n_inputs (pp_prevarlist node.n_name) node.n_local_vars (pp_prevarlist node.n_name) node.n_outputs (pp_resvars !reset_expressions_counter) - pp_retvarlist node.n_outputs + (if node.n_name = "main" then "int" else "void") node.n_name (* could avoid newlines if they aren't used to seperate statements *) (pp_varlist Arg) node.n_inputs (pp_varlist Dec) node.n_local_vars - (pp_varlist Dec) node.n_outputs (pp_equations node.n_name) node.n_equations node.n_name (pp_asnprevarlist node.n_name) node.n_inputs @@ -238,9 +245,24 @@ let rec pp_nodes fmt nodes = | node :: nodes -> Format.fprintf fmt "%a\n%a" pp_node node pp_nodes nodes +let rec load_outputs_from_vars n_outputs = + match n_outputs with + | [] -> () + | BVar n_output :: n_outputs + | IVar n_output :: n_outputs + | RVar n_output :: n_outputs -> + (if (not (List.mem n_output !outputs)) then outputs := n_output :: !outputs;); load_outputs_from_vars n_outputs + +let rec load_outputs_from_nodes nodes = + match nodes with + | [] -> () + | node :: nodes -> + (load_outputs_from_vars (snd node.n_outputs)); load_outputs_from_nodes nodes + let ast_to_c fmt prog = + load_outputs_from_nodes prog; Format.fprintf fmt (* could verify that uses, possibly indirectly (cf `->` implementation), a boolean in the ast before including `` *) - "#include \n\n%a" - pp_nodes prog + "#include \n\n%s\n\n%a" + ("float " ^ (String.concat ", " !outputs) ^ ";") pp_nodes prog