[passes] removal of constructs: seems ok

This commit is contained in:
Arnaud DABY-SEESARAM 2022-12-19 11:22:16 +01:00
parent c52dce6c02
commit 4ff193759b
6 changed files with 184 additions and 25 deletions

View File

@ -163,7 +163,7 @@ and cp_expression fmt (expr, hloc) =
Format.fprintf fmt "%sstate->%s[%d] = %a;\n"
prefix arr idx cp_value (value, hloc)
end
| CAssign (CVInput _, _) -> failwith "should not happened."
| CAssign (CVInput _, _) -> failwith "[cprint.ml] never assign an input."
| CSeq (e, e') ->
Format.fprintf fmt "%a%a"
cp_expression (e, hloc)

View File

@ -67,7 +67,8 @@ let rec equation_to_expression (node_st, node_sts, (vl, expr)) =
vl
in
CApplication (node.n_name,i , al, vl, node_sts)
| IETuple _ -> failwith "[ctranslation.ml] linearization should have transformed you."
| IETuple _ -> failwith "[ctranslation.ml] linearization should have \
transformed the tuples of the right members."
| IEWhen (expr, cond) ->
begin
CIf (iexpression_to_cvalue cond,

View File

@ -25,7 +25,7 @@ let exec_passes ast verbose debug passes f =
let _ =
(** Usage and argument parsing. *)
let default_passes = ["linearization_pre"; "linearization_tuples"; "linearization_app";
let default_passes = ["remove_if"; "linearization_pre"; "linearization_tuples"; "linearization_app";
"equations_ordering"] in
let sanity_passes = ["chkvar_init_unicity"; "check_typing"] in
let usage_msg =
@ -63,6 +63,7 @@ let _ =
let passes_table = Hashtbl.create 100 in
List.iter (fun (s, k) -> Hashtbl.add passes_table s k)
[
("remove_if", Passes.pass_if_removal);
("linearization_tuples", Passes.pass_linearization_tuples);
("linearization_app", Passes.pass_linearization_app);
("linearization_pre", Passes.pass_linearization_pre);

View File

@ -4,21 +4,156 @@ open Ast
open Passes_utils
open Utils
let rec split_tuple (eq: t_equation): t_eqlist =
let patt, expr = eq in
match expr with
| ETuple (_, expr_h :: expr_t) ->
begin
let t_l = type_exp expr_h in
let patt_l, patt_r = list_select (List.length t_l) (snd patt) in
let t_r = List.flatten (List.map type_var patt_r) in
((t_l, patt_l), expr_h) ::
split_tuple ((t_r, patt_r), ETuple (t_r, expr_t))
end
| ETuple (_, []) -> []
| _ -> [eq]
(** [pass_if_removal] replaces the `if` construct with `when` and `merge` ones.
*
* [x1, ..., xn = if c then e_l else e_r;]
* is replaced by:
* (t1, ..., tn) = e_l;
* (u1, ..., un) = e_r;
* (v1, ..., vn) = (t1, ..., tn) when c;
* (w1, ..., wn) = (u1, ..., un) when (not c);
* (x1, ..., xn) = merge c (v1, ..., vn) (w1, ..., wn);
*
* Note that the first two equations (before the use of when) is required in
* order to have the expressions active at each step.
*)
let pass_if_removal verbose debug =
let varcount = ref 0 in
let make_patt t: t_varlist =
(t, List.fold_right
(fun ty acc ->
let nvar: ident = Format.sprintf "_ifrem%d" !varcount in
let nvar =
match ty with
| TInt -> IVar nvar
| TReal -> RVar nvar
| TBool -> BVar nvar
in
incr varcount;
nvar :: acc)
t [])
in
let simplify_tuple t =
match t with
| ETuple (t, [elt]) -> elt
| _ -> t
in
let rec aux_eq vars eq: t_eqlist * t_varlist * t_equation =
let patt, expr = eq in
match expr with
| EConst _ | EVar _ -> [], vars, eq
| EMonOp (t, op, e) ->
let eqs, vars, (patt, e) = aux_eq vars (patt, e) in
eqs, vars, (patt, EMonOp (t, op, e))
| EBinOp (t, op, e, e') ->
let eqs, vars, (_, e) = aux_eq vars (patt, e) in
let eqs', vars, (_, e') = aux_eq vars (patt, e') in
eqs @ eqs', vars, (patt, EBinOp (t, op, e, e'))
| ETriOp (t, TOp_if, e, e', e'') ->
let eqs, vars, (_, e) = aux_eq vars (patt, e) in
let eqs', vars, (_, e') = aux_eq vars (patt, e') in
let eqs'', vars, (_, e'') = aux_eq vars (patt, e'') in
let patt_l: t_varlist = make_patt t in
let patt_r: t_varlist = make_patt t in
let patt_l_when: t_varlist = make_patt t in
let patt_r_when: t_varlist = make_patt t in
let expr_l: t_expression =
simplify_tuple
(ETuple
(fst patt_l, List.map (fun v -> EVar (type_var v, v)) (snd patt_l)))
in
let expr_r: t_expression =
simplify_tuple
(ETuple
(fst patt_r, List.map (fun v -> EVar (type_var v, v)) (snd patt_r)))
in
let expr_l_when: t_expression =
simplify_tuple
(ETuple
(fst patt_l_when, List.map (fun v -> EVar (type_var v, v))
(snd patt_l_when)))
in
let expr_r_when: t_expression =
simplify_tuple
(ETuple
(fst patt_r_when, List.map (fun v -> EVar (type_var v, v))
(snd patt_r_when)))
in
let equations: t_eqlist =
[(patt_l, e');
(patt_r, e'');
(patt_l_when,
EWhen (t, expr_l, e));
(patt_r_when,
EWhen (t,
expr_r,
(EMonOp (type_exp e, MOp_not, e))))]
@ eqs @ eqs' @eqs'' in
let vars: t_varlist =
varlist_concat
vars
(varlist_concat patt_l_when (varlist_concat patt_r_when
(varlist_concat patt_r patt_l))) in
let expr =
ETriOp (t, TOp_merge, e, expr_l_when, expr_r_when) in
equations, vars, (patt, expr)
| ETriOp (t, op, e, e', e'') ->
let eqs, vars, (_, e) = aux_eq vars (patt, e) in
let eqs', vars, (_, e') = aux_eq vars (patt, e') in
let eqs'', vars, (_, e'') = aux_eq vars (patt, e'') in
eqs @ eqs' @ eqs'', vars, (patt, ETriOp (t, op, e, e', e''))
| EComp (t, op, e, e') ->
let eqs, vars, (_, e) = aux_eq vars (patt, e) in
let eqs', vars, (_, e') = aux_eq vars (patt, e') in
eqs @ eqs', vars, (patt, EComp (t, op, e, e'))
| EWhen (t, e, e') ->
let eqs, vars, (_, e) = aux_eq vars (patt, e) in
let eqs', vars, (_, e') = aux_eq vars (patt, e') in
eqs @ eqs', vars, (patt, EWhen (t, e, e'))
| EReset (t, e, e') ->
let eqs, vars, (_, e) = aux_eq vars (patt, e) in
let eqs', vars, (_, e') = aux_eq vars (patt, e') in
eqs @ eqs', vars, (patt, EReset (t, e, e'))
| ETuple (t, l) ->
let eqs, vars, l, _ =
List.fold_right
(fun e (eqs, vars, l, remaining_patt) ->
let patt_l, patt_r = split_patt remaining_patt e in
let eqs', vars, (_, e) = aux_eq vars (patt_l, e) in
eqs' @ eqs, vars, (e :: l), patt_r)
l ([], vars, [], patt) in
eqs, vars, (patt, ETuple (t, l))
| EApp (t, n, e) ->
let eqs, vars, (_, e) = aux_eq vars (patt, e) in
eqs, vars, (patt, EApp (t, n, e))
in
let aux_if_removal node =
let new_equations, new_locvars =
List.fold_left
(fun (eqs, vars) eq ->
let eqs', vars, eq = aux_eq vars eq in
eq :: eqs' @ eqs, vars)
([], node.n_local_vars) node.n_equations
in
Some { node with n_equations = new_equations; n_local_vars = new_locvars }
in
node_pass aux_if_removal
let pass_linearization_tuples verbose debug =
let pass_linearization_tuples verbose debug ast =
let rec split_tuple (eq: t_equation): t_eqlist =
let patt, expr = eq in
match expr with
| ETuple (_, expr_h :: expr_t) ->
begin
let t_l = type_exp expr_h in
let patt_l, patt_r = list_select (List.length t_l) (snd patt) in
let t_r = List.flatten (List.map type_var patt_r) in
((t_l, patt_l), expr_h) ::
split_tuple ((t_r, patt_r), ETuple (t_r, expr_t))
end
| ETuple (_, []) -> []
| _ -> [eq]
in
let aux_linearization_tuples node =
let new_equations = List.flatten
(List.map
@ -29,11 +164,27 @@ let pass_linearization_tuples verbose debug =
List.map
(fun (patt, expr) -> (patt, EWhen (type_exp expr, expr, e')))
(split_tuple (fst eq, ETuple (t, l)))
| ETriOp (t, TOp_merge, c, ETuple (_, l), ETuple (_, l')) ->
begin
if List.length l <> List.length l'
|| List.length t <> List.length (snd (fst eq))
then raise (PassExn "Error while merging tuples.")
else
fst
(List.fold_left2
(fun (eqs, remaining_patt) el er ->
let patt, remaining_patt = split_patt remaining_patt el in
let t = type_exp el in
(patt, ETriOp (t, TOp_merge, c, el, er))
:: eqs, remaining_patt)
([], fst eq) l l')
end
| _ -> [eq])
node.n_equations) in
Some { node with n_equations = new_equations }
in
node_pass aux_linearization_tuples
try node_pass aux_linearization_tuples ast with
| PassExn err -> (debug err; None)
let pass_linearization_app verbose debug =
let applin_count = ref 0 in

View File

@ -1,10 +1,10 @@
node my_and (a, b: bool) returns (o: bool);
node test_merge_tuples (a, b: bool) returns (o: bool);
var t: bool;
let
o = a and b;
(o, t) = if a and b then (true, false) else (false, true);
tel
node n (i: int) returns (o: int);
var v: bool;
let
(o, v) = (1, my_and (pre o = 8, pre v));
tel
--node my_and (a, b: bool) returns (o: bool);
--let
-- o = if a then b else false;
--tel

View File

@ -104,3 +104,9 @@ let rec vars_of_expr (expr: t_expression) : ident list =
let rec varlist_concat (l1: t_varlist) (l2: t_varlist): t_varlist =
(fst l1 @ fst l2, snd l1 @ snd l2)
let split_patt (patt: t_varlist) (e: t_expression): t_varlist * t_varlist =
let pl, pr = list_select (List.length (type_exp e)) (snd patt) in
let tl = List.flatten (List.map type_var pl) in
let tr = List.flatten (List.map type_var pr) in
(tl, pl), (tr, pr)