diff --git a/src/main.ml b/src/main.ml index e7d4c44..5ab21c5 100644 --- a/src/main.ml +++ b/src/main.ml @@ -40,7 +40,8 @@ let exec_passes ast verbose debug passes f = let _ = (** Usage and argument parsing. *) let default_passes = - ["linearization_reset"; "automata_translation"; "remove_if"; + ["linearization_reset"; "automata_translation"; "remove_if"; + "linearization_merge"; "linearization_pre"; "linearization_tuples"; "linearization_app"; "ensure_assign_val"; "equations_ordering"; @@ -82,6 +83,7 @@ let _ = List.iter (fun (s, k) -> Hashtbl.add passes_table s k) [ ("remove_if", Passes.pass_if_removal); + ("linearization_merge", Passes.pass_merge_lin); ("linearization_tuples", Passes.pass_linearization_tuples); ("linearization_app", Passes.pass_linearization_app); ("linearization_pre", Passes.pass_linearization_pre); diff --git a/src/passes.ml b/src/passes.ml index d5a1100..bd05c18 100644 --- a/src/passes.ml +++ b/src/passes.ml @@ -6,6 +6,89 @@ open Utils +let pass_merge_lin verbose debug = + let varname_prefix = "_mergelin" in + let count = ref 0 in + let rec aux_expr vars expr toplevel = + match expr with + | EVar _ | EConst _ -> [], vars, expr + | EMonOp (t, op, e) -> + let eqs, vars, e = aux_expr vars e false in + eqs, vars, EMonOp (t, op, e) + | EBinOp (t, op, e, e') -> + let eqs, vars, e = aux_expr vars e false in + let eqs', vars, e' = aux_expr vars e' false in + eqs'@eqs, vars, EBinOp (t, op, e, e') + | EComp (t, op, e, e') -> + let eqs, vars, e = aux_expr vars e false in + let eqs', vars, e' = aux_expr vars e' false in + eqs'@eqs, vars, EComp (t, op, e, e') + | EReset (t, e, e') -> + let eqs, vars, e = aux_expr vars e false in + let eqs', vars, e' = aux_expr vars e' false in + eqs'@eqs, vars, EReset (t, e, e') + | ETuple (t, l) -> + let eqs, vars, l = List.fold_right + (fun e (eqs, vars, l) -> + let eqs', vars, e = aux_expr vars e false in + eqs' @ eqs, vars, (e :: l)) + l ([], vars, []) in + eqs, vars, ETuple (t, l) + | EApp (t, n, e) -> + let eqs, vars, e = aux_expr vars e false in + eqs, vars, EApp (t, n, e) + | ETriOp (_, TOp_if, _, _, _) -> + raise (PassExn "There should no longer be any condition.") + | EWhen (t, e, e') -> + let eqs, vars, e = aux_expr vars e false in + let eqs', vars, e' = aux_expr vars e' false in + eqs @ eqs', vars, EWhen (t, e, e') + | ETriOp (t, TOp_merge, c, e, e') -> + begin + if toplevel + then + begin + let eqs, vars, c = aux_expr vars c false in + let eqs', vars, e = aux_expr vars e false in + let eqs'', vars, e' = aux_expr vars e' false in + eqs@eqs'@eqs'', vars, ETriOp (t, TOp_merge, c, e, e') + end + else + begin + if List.length t = 1 + then + let newvar = Format.sprintf "%s%d" varname_prefix !count in + let newvar = + match List.hd t with + | TInt -> IVar newvar + | TBool -> BVar newvar + | TReal -> RVar newvar + in + let () = incr count in + let vars = (t @ (fst vars), newvar :: (snd vars)) in + let eqs, vars, c = aux_expr vars c false in + let eqs', vars, e = aux_expr vars e false in + let eqs'', vars, e' = aux_expr vars e' false in + ((t, [newvar]), ETriOp (t, TOp_merge, c, e, e')) :: eqs @ eqs' @ eqs'', vars, EVar (t, newvar) + else + raise (PassExn "Merges should only happened on unary expressions.") + end + end + in + let aux_merge_lin node = + let eqs, vars = + List.fold_left + (fun (eqs, vars) (patt, expr) -> + let eqs', vars, expr = aux_expr vars expr true in + (patt, expr) :: eqs' @ eqs, vars) + ([], node.n_local_vars) node.n_equations + in + Some { node with n_local_vars = vars; n_equations = eqs } + in + node_pass aux_merge_lin + + + (** [pass_if_removal] replaces the `if` construct with `when` and `merge` ones. * * [x1, ..., xn = if c then e_l else e_r;]