From ad97c6b627174109a1fcd9b8af3c26b148d58e77 Mon Sep 17 00:00:00 2001 From: Antoine Grimod Date: Mon, 9 Jan 2023 20:57:19 +0100 Subject: [PATCH] first version of clock unification --- src/ast.ml | 1 + src/main.ml | 6 +- src/passes.ml | 196 +++++++++++++++++++++++++++++--------------------- src/test.node | 6 ++ 4 files changed, 124 insertions(+), 85 deletions(-) diff --git a/src/ast.ml b/src/ast.ml index 039a0cf..01183a8 100644 --- a/src/ast.ml +++ b/src/ast.ml @@ -77,3 +77,4 @@ type t_ck = base_ck list and base_ck = | Base | On of base_ck * t_expression + | Unknown diff --git a/src/main.ml b/src/main.ml index c8c3fdb..2c93b78 100644 --- a/src/main.ml +++ b/src/main.ml @@ -25,8 +25,6 @@ let exec_passes ast verbose debug passes f = | [] -> f ast | (n, p) :: passes -> verbose (Format.asprintf "Executing pass %s:\n" n); - try - begin match p verbose debug ast with | None -> (exit_error ("Error while in the pass "^n^".\n"); exit 0) | Some ast -> ( @@ -34,8 +32,6 @@ let exec_passes ast verbose debug passes f = (Format.asprintf "Current AST (after %s):\n%a\n" n Lustre_pp.pp_ast ast); aux ast passes) - end with - | _ -> failwith ("The pass "^n^" should have caught me!") in aux ast passes @@ -47,7 +43,7 @@ let _ = ["linearization_reset"; "automata_translation"; "remove_if"; "linearization_pre"; "linearization_tuples"; "linearization_app"; "ensure_assign_val"; - "equations_ordering"] in + "equations_ordering"; "clock_unification"] in let sanity_passes = ["sanity_pass_assignment_unicity"; "check_typing"] in let usage_msg = "Usage: main [-passes p1,...,pn] [-ast] [-verbose] [-debug] \ diff --git a/src/passes.ml b/src/passes.ml index d43c7bf..ce36558 100644 --- a/src/passes.ml +++ b/src/passes.ml @@ -933,101 +933,137 @@ let clock_unification_pass verbose debug ast = let failure str = raise (PassExn ("Failed to unify clocks: "^str)) in let known_clocks = Hashtbl.create 100 in + let used = Hashtbl.create 100 in (*keep track of variables that appear on right side of equation*) + let changed = ref false in - let find_clock_var var = - match Hashtbl.find_opt known_clocks var with - | None -> - begin - match var with - | BVar(name) - | IVar(name) - | RVar(name) -> raise (PassExn ("Cannot find clock of variable "^name) ) - end - | Some c -> c + let rec count_not e acc = match e with + | EVar([TBool], var) -> acc, e + | EConst([TBool], cons) -> acc, e + | EMonOp([TBool], MOp_not, e) -> count_not e (acc + 1) + | _ -> raise (PassExn "verify_when failure") in - let rec compute_clock_exp exp = match exp with - | EConst(_, _) -> [Base] - | EVar(_, var) -> find_clock_var var - | EMonOp(_, MOp_pre, _) -> [Base] - | EMonOp(_, _, e) -> compute_clock_exp e + let verify_when e1 e2 = + let n1, var1 = count_not e1 0 + and n2, var2 = count_not e2 0 in + if n1 mod 2 <> n2 mod 2 || var1 <> var2 then + raise (PassExn "clock unification failure") + in + + let get_var_name var = match var with + | RVar(name) + | BVar(name) + | IVar(name) -> name + in + + let rec clk_to_string clk = match clk with + | Base -> "Base" + | Unknown -> "Unknown" + | On(clk, exp) -> + let n, var = count_not exp 0 in + let s = if n mod 2 = 1 then "not " else "" in + let v = match var with |EVar(_, var) -> get_var_name var | EConst(_, CBool(false)) -> "false" |_ -> "true" in + (clk_to_string clk) ^ " on " ^ s ^ v + in + + let add_clock var clock = + match Hashtbl.find known_clocks var with + | Unknown -> changed := true; (debug ("Found clock for "^(get_var_name var)^": "^(clk_to_string clock))); Hashtbl.replace known_clocks var clock + | c when c = clock -> () + | c -> raise (PassExn ("Clock conflict "^(get_var_name var) ^" "^(clk_to_string c) ^ " " ^ (clk_to_string clock))) + in + + let rec update_clock exp clk = match exp with + | EConst(_, _) -> () + | EVar(_, var) -> add_clock var clk; Hashtbl.replace used var var + | EMonOp(_, _, e) -> update_clock e clk | EComp(_, _, e1, e2) | EReset(_, e1, e2) - | EBinOp(_, _, e1, e2) -> - begin - let c1 = compute_clock_exp e1 - and c2 = compute_clock_exp e2 in - if c1 <> c2 then - failure "Binop" - else - c1 - end - | EWhen(_, e1, e2) -> - begin - match compute_clock_exp e1 with - | [c1] -> [On (c1, e2)] - | _ -> failure "When" - end + | EBinOp(_, _, e1, e2) -> update_clock e1 clk; update_clock e2 clk | ETriOp(_, TOp_merge, e1, e2, e3) -> - begin - let c1 = compute_clock_exp e1 - and c2 = compute_clock_exp e2 - and c3 = compute_clock_exp e3 in - match c1, c2, c3 with - | [c1], [On(cl2, e2)], [On(cl3, e3)] -> - begin - if cl2 <> c1 || cl3 <> c1 then - failure "Triop clocks" - else match e2, e3 with - | EMonOp(_, MOp_not, e), _ when e = e3 -> [c1] - | _, EMonOp(_, MOp_not, e) when e = e2 -> [c1] - | _ -> failure "Triop condition" - end - | _ -> failure ("Merge format") - end + update_clock e1 clk; + update_clock e2 (On(clk, e1)); + update_clock e3 (On(clk, EMonOp([TBool], MOp_not, e1))) | ETriOp(_, TOp_if, e1, e2, e3) -> - let (* Unused: c1 = compute_clock_exp e1 - and*) c2 = compute_clock_exp e2 - and c3 = compute_clock_exp e3 in - if c2 <> c3 then - failure "If clocks" - else c2 - - | ETuple(_, explist) -> List.concat_map compute_clock_exp explist - | EApp(_, node, e) -> - let rec aux_app clk_list = match clk_list with - | [] -> raise (PassExn "Node called with no argument provided") - | [cl] -> cl - | t::q -> if t = (aux_app q) then t else failure "App diff clocks" - and mult_clk cl out_list = match out_list with - | [] -> [] - | t::q -> cl::(mult_clk cl q) - in - mult_clk (aux_app (compute_clock_exp e)) (snd node.n_outputs) + (* The 3 expressions should have the same clock *) + begin + update_clock e1 clk; + update_clock e2 clk; + update_clock e3 clk + end + | ETuple(_, explist) -> List.iter (fun e -> update_clock e clk) explist + | EApp(_, node, e) -> update_clock e clk + | EWhen(_, e1, e2) -> + match clk with + | On(clk2, e) -> verify_when e e2; update_clock e1 clk2 + | _ -> raise (PassExn "Clock unification failure: when") in - let rec compute_eq_clock eq = - let rec step vars clks = match vars, clks with - | [], [] -> () - | [], c::q -> raise (PassExn "Mismatch between clock size") - | v::t, c::q -> Hashtbl.replace known_clocks v [c]; step t q - | l, [] -> raise (PassExn "Mismatch between clock size") + let rec propagate_clock eqs = + let rec step ((ty, vars), exp)= match vars with + | [] -> () + | v::t -> let clk = Hashtbl.find known_clocks v in + begin + if clk <> Unknown then update_clock exp clk + else (); + step ((ty, t), exp) + end in - let (_, vars), exp = eq in - let clk = compute_clock_exp exp in - step vars clk + List.iter step eqs in + + let rec iter_til_stable eqs = + changed := false; + propagate_clock eqs; + if !changed then + iter_til_stable eqs + in + + let check_unification node = + let (_, node_inputs) = node.n_inputs in + let rec check_vars_aux acc = match acc with + | [(v, c)] -> if c = Unknown && (Hashtbl.mem used v) then raise (PassExn ("Clock unification failure: Unkwown clock for "^(get_var_name v))) else c + | (v, t)::q -> let c = check_vars_aux q in + if c <> t then raise (PassExn "Clock unification failure: Non homogeneous equation") else c + | [] -> raise (PassExn "Clock unification failure: empty equation") + in + let rec check_vars ((ty, vars), exp) acc = match vars with + | [] -> let _ = check_vars_aux acc in () + | v::t -> check_vars ((ty, t), exp) ((v, Hashtbl.find known_clocks v)::acc) + in + let rec check_inputs inputs = match inputs with + | [] -> () + | i::q -> let c = Hashtbl.find known_clocks i in + match c with + | On(_, e) -> let _, var = count_not e 0 in + begin + match var with + | EConst(_, _) -> () + | EVar(_, var) -> if not (List.mem var node_inputs) then raise (PassExn "Clock unification failure: input clock depends on non input clock") + else check_inputs q + end + | _ -> check_inputs q + in + (*Check that all variables used have a clock + and that inputs clocks do not depend on local vars or outputs*) + List.iter (fun eq -> check_vars eq []) node.n_equations; + check_inputs node_inputs; + in + let compute_clock_node n = begin Hashtbl.clear known_clocks; - List.iter (fun v -> Hashtbl.replace known_clocks v [Base]) ( - snd n.n_inputs); (* Initializing inputs to base clock *) - List.iter compute_eq_clock n.n_equations; - if not (List.for_all (fun v -> (Hashtbl.find known_clocks v) = [Base]) ( - snd n.n_outputs)) then failure "Outputs" (*Checking that the node's output are on base clock *) - else - Some n + List.iter (fun v -> Hashtbl.replace known_clocks v Unknown) ( + snd n.n_inputs); (* Initializing inputs to Unknown clock *) + List.iter (fun v -> Hashtbl.replace known_clocks v Unknown) ( + snd n.n_local_vars); (* Initializing local variables to Unknown clock *) + List.iter (fun v -> Hashtbl.replace known_clocks v Base) ( + snd n.n_outputs); (* Initializing outputs to base clock *) + iter_til_stable n.n_equations; + (* catch potential errors and test for unification *) + check_unification n; + Some n end in node_pass compute_clock_node ast diff --git a/src/test.node b/src/test.node index 6f83475..eb2608d 100644 --- a/src/test.node +++ b/src/test.node @@ -16,3 +16,9 @@ let tmp = aux (a+b, i); tel +node test (u, v: int; c: bool) returns (o: int); +var x, y: int; b: bool; +let + o = 2 * (merge c u v); +tel +