diff --git a/src/cprint.ml b/src/cprint.ml index d030c24..e1e25c6 100644 --- a/src/cprint.ml +++ b/src/cprint.ml @@ -163,7 +163,7 @@ and cp_expression fmt (expr, hloc) = Format.fprintf fmt "%sstate->%s[%d] = %a;\n" prefix arr idx cp_value (value, hloc) end - | CAssign (CVInput _, _) -> failwith "should not happened." + | CAssign (CVInput _, _) -> failwith "[cprint.ml] never assign an input." | CSeq (e, e') -> Format.fprintf fmt "%a%a" cp_expression (e, hloc) diff --git a/src/ctranslation.ml b/src/ctranslation.ml index 48af417..b9a7997 100644 --- a/src/ctranslation.ml +++ b/src/ctranslation.ml @@ -67,7 +67,8 @@ let rec equation_to_expression (node_st, node_sts, (vl, expr)) = vl in CApplication (node.n_name,i , al, vl, node_sts) - | IETuple _ -> failwith "[ctranslation.ml] linearization should have transformed you." + | IETuple _ -> failwith "[ctranslation.ml] linearization should have \ + transformed the tuples of the right members." | IEWhen (expr, cond) -> begin CIf (iexpression_to_cvalue cond, diff --git a/src/main.ml b/src/main.ml index 4a46502..83d407f 100644 --- a/src/main.ml +++ b/src/main.ml @@ -25,7 +25,7 @@ let exec_passes ast verbose debug passes f = let _ = (** Usage and argument parsing. *) - let default_passes = ["linearization_pre"; "linearization_tuples"; "linearization_app"; + let default_passes = ["remove_if"; "linearization_pre"; "linearization_tuples"; "linearization_app"; "equations_ordering"] in let sanity_passes = ["chkvar_init_unicity"; "check_typing"] in let usage_msg = @@ -63,6 +63,7 @@ let _ = let passes_table = Hashtbl.create 100 in List.iter (fun (s, k) -> Hashtbl.add passes_table s k) [ + ("remove_if", Passes.pass_if_removal); ("linearization_tuples", Passes.pass_linearization_tuples); ("linearization_app", Passes.pass_linearization_app); ("linearization_pre", Passes.pass_linearization_pre); diff --git a/src/passes.ml b/src/passes.ml index 4a47113..352bb65 100644 --- a/src/passes.ml +++ b/src/passes.ml @@ -4,21 +4,156 @@ open Ast open Passes_utils open Utils -let rec split_tuple (eq: t_equation): t_eqlist = - let patt, expr = eq in - match expr with - | ETuple (_, expr_h :: expr_t) -> - begin - let t_l = type_exp expr_h in - let patt_l, patt_r = list_select (List.length t_l) (snd patt) in - let t_r = List.flatten (List.map type_var patt_r) in - ((t_l, patt_l), expr_h) :: - split_tuple ((t_r, patt_r), ETuple (t_r, expr_t)) - end - | ETuple (_, []) -> [] - | _ -> [eq] +(** [pass_if_removal] replaces the `if` construct with `when` and `merge` ones. + * + * [x1, ..., xn = if c then e_l else e_r;] + * is replaced by: + * (t1, ..., tn) = e_l; + * (u1, ..., un) = e_r; + * (v1, ..., vn) = (t1, ..., tn) when c; + * (w1, ..., wn) = (u1, ..., un) when (not c); + * (x1, ..., xn) = merge c (v1, ..., vn) (w1, ..., wn); + * + * Note that the first two equations (before the use of when) is required in + * order to have the expressions active at each step. + *) +let pass_if_removal verbose debug = + let varcount = ref 0 in + let make_patt t: t_varlist = + (t, List.fold_right + (fun ty acc -> + let nvar: ident = Format.sprintf "_ifrem%d" !varcount in + let nvar = + match ty with + | TInt -> IVar nvar + | TReal -> RVar nvar + | TBool -> BVar nvar + in + incr varcount; + nvar :: acc) + t []) + in + let simplify_tuple t = + match t with + | ETuple (t, [elt]) -> elt + | _ -> t + in + let rec aux_eq vars eq: t_eqlist * t_varlist * t_equation = + let patt, expr = eq in + match expr with + | EConst _ | EVar _ -> [], vars, eq + | EMonOp (t, op, e) -> + let eqs, vars, (patt, e) = aux_eq vars (patt, e) in + eqs, vars, (patt, EMonOp (t, op, e)) + | EBinOp (t, op, e, e') -> + let eqs, vars, (_, e) = aux_eq vars (patt, e) in + let eqs', vars, (_, e') = aux_eq vars (patt, e') in + eqs @ eqs', vars, (patt, EBinOp (t, op, e, e')) + | ETriOp (t, TOp_if, e, e', e'') -> + let eqs, vars, (_, e) = aux_eq vars (patt, e) in + let eqs', vars, (_, e') = aux_eq vars (patt, e') in + let eqs'', vars, (_, e'') = aux_eq vars (patt, e'') in + let patt_l: t_varlist = make_patt t in + let patt_r: t_varlist = make_patt t in + let patt_l_when: t_varlist = make_patt t in + let patt_r_when: t_varlist = make_patt t in + let expr_l: t_expression = + simplify_tuple + (ETuple + (fst patt_l, List.map (fun v -> EVar (type_var v, v)) (snd patt_l))) + in + let expr_r: t_expression = + simplify_tuple + (ETuple + (fst patt_r, List.map (fun v -> EVar (type_var v, v)) (snd patt_r))) + in + let expr_l_when: t_expression = + simplify_tuple + (ETuple + (fst patt_l_when, List.map (fun v -> EVar (type_var v, v)) + (snd patt_l_when))) + in + let expr_r_when: t_expression = + simplify_tuple + (ETuple + (fst patt_r_when, List.map (fun v -> EVar (type_var v, v)) + (snd patt_r_when))) + in + let equations: t_eqlist = + [(patt_l, e'); + (patt_r, e''); + (patt_l_when, + EWhen (t, expr_l, e)); + (patt_r_when, + EWhen (t, + expr_r, + (EMonOp (type_exp e, MOp_not, e))))] + @ eqs @ eqs' @eqs'' in + let vars: t_varlist = + varlist_concat + vars + (varlist_concat patt_l_when (varlist_concat patt_r_when + (varlist_concat patt_r patt_l))) in + let expr = + ETriOp (t, TOp_merge, e, expr_l_when, expr_r_when) in + equations, vars, (patt, expr) + | ETriOp (t, op, e, e', e'') -> + let eqs, vars, (_, e) = aux_eq vars (patt, e) in + let eqs', vars, (_, e') = aux_eq vars (patt, e') in + let eqs'', vars, (_, e'') = aux_eq vars (patt, e'') in + eqs @ eqs' @ eqs'', vars, (patt, ETriOp (t, op, e, e', e'')) + | EComp (t, op, e, e') -> + let eqs, vars, (_, e) = aux_eq vars (patt, e) in + let eqs', vars, (_, e') = aux_eq vars (patt, e') in + eqs @ eqs', vars, (patt, EComp (t, op, e, e')) + | EWhen (t, e, e') -> + let eqs, vars, (_, e) = aux_eq vars (patt, e) in + let eqs', vars, (_, e') = aux_eq vars (patt, e') in + eqs @ eqs', vars, (patt, EWhen (t, e, e')) + | EReset (t, e, e') -> + let eqs, vars, (_, e) = aux_eq vars (patt, e) in + let eqs', vars, (_, e') = aux_eq vars (patt, e') in + eqs @ eqs', vars, (patt, EReset (t, e, e')) + | ETuple (t, l) -> + let eqs, vars, l, _ = + List.fold_right + (fun e (eqs, vars, l, remaining_patt) -> + let patt_l, patt_r = split_patt remaining_patt e in + let eqs', vars, (_, e) = aux_eq vars (patt_l, e) in + eqs' @ eqs, vars, (e :: l), patt_r) + l ([], vars, [], patt) in + eqs, vars, (patt, ETuple (t, l)) + | EApp (t, n, e) -> + let eqs, vars, (_, e) = aux_eq vars (patt, e) in + eqs, vars, (patt, EApp (t, n, e)) + in + let aux_if_removal node = + let new_equations, new_locvars = + List.fold_left + (fun (eqs, vars) eq -> + let eqs', vars, eq = aux_eq vars eq in + eq :: eqs' @ eqs, vars) + ([], node.n_local_vars) node.n_equations + in + Some { node with n_equations = new_equations; n_local_vars = new_locvars } + in + node_pass aux_if_removal -let pass_linearization_tuples verbose debug = +let pass_linearization_tuples verbose debug ast = + let rec split_tuple (eq: t_equation): t_eqlist = + let patt, expr = eq in + match expr with + | ETuple (_, expr_h :: expr_t) -> + begin + let t_l = type_exp expr_h in + let patt_l, patt_r = list_select (List.length t_l) (snd patt) in + let t_r = List.flatten (List.map type_var patt_r) in + ((t_l, patt_l), expr_h) :: + split_tuple ((t_r, patt_r), ETuple (t_r, expr_t)) + end + | ETuple (_, []) -> [] + | _ -> [eq] + in let aux_linearization_tuples node = let new_equations = List.flatten (List.map @@ -29,11 +164,27 @@ let pass_linearization_tuples verbose debug = List.map (fun (patt, expr) -> (patt, EWhen (type_exp expr, expr, e'))) (split_tuple (fst eq, ETuple (t, l))) + | ETriOp (t, TOp_merge, c, ETuple (_, l), ETuple (_, l')) -> + begin + if List.length l <> List.length l' + || List.length t <> List.length (snd (fst eq)) + then raise (PassExn "Error while merging tuples.") + else + fst + (List.fold_left2 + (fun (eqs, remaining_patt) el er -> + let patt, remaining_patt = split_patt remaining_patt el in + let t = type_exp el in + (patt, ETriOp (t, TOp_merge, c, el, er)) + :: eqs, remaining_patt) + ([], fst eq) l l') + end | _ -> [eq]) node.n_equations) in Some { node with n_equations = new_equations } in - node_pass aux_linearization_tuples + try node_pass aux_linearization_tuples ast with + | PassExn err -> (debug err; None) let pass_linearization_app verbose debug = let applin_count = ref 0 in diff --git a/src/test.node b/src/test.node index faa8355..6dd535f 100644 --- a/src/test.node +++ b/src/test.node @@ -1,10 +1,10 @@ -node my_and (a, b: bool) returns (o: bool); +node test_merge_tuples (a, b: bool) returns (o: bool); +var t: bool; let - o = a and b; + (o, t) = if a and b then (true, false) else (false, true); tel -node n (i: int) returns (o: int); -var v: bool; -let - (o, v) = (1, my_and (pre o = 8, pre v)); -tel +--node my_and (a, b: bool) returns (o: bool); +--let +-- o = if a then b else false; +--tel diff --git a/src/utils.ml b/src/utils.ml index 01312ef..8005e46 100644 --- a/src/utils.ml +++ b/src/utils.ml @@ -104,3 +104,9 @@ let rec vars_of_expr (expr: t_expression) : ident list = let rec varlist_concat (l1: t_varlist) (l2: t_varlist): t_varlist = (fst l1 @ fst l2, snd l1 @ snd l2) +let split_patt (patt: t_varlist) (e: t_expression): t_varlist * t_varlist = + let pl, pr = list_select (List.length (type_exp e)) (snd patt) in + let tl = List.flatten (List.map type_var pl) in + let tr = List.flatten (List.map type_var pr) in + (tl, pl), (tr, pr) +