diff --git a/src/ast.ml b/src/ast.ml index fe087cf..c2fb0da 100644 --- a/src/ast.ml +++ b/src/ast.ml @@ -73,3 +73,7 @@ and t_node = type t_nodelist = t_node list +type t_ck = base_ck list +and base_ck = + | Base + | On of base_ck * t_expression diff --git a/src/main.ml b/src/main.ml index 3652564..e4e38b8 100644 --- a/src/main.ml +++ b/src/main.ml @@ -25,7 +25,7 @@ let exec_passes ast main_fn verbose debug passes f = let _ = (** Usage and argument parsing. *) - let default_passes = ["automata_validity" ;"automata_translation"; "linearization"; "pre2vars"; "equations_ordering"] in + let default_passes = ["automata_validity" ;"automata_translation"; "linearization"; "pre2vars"; "equations_ordering"; "clock_unification"] in let sanity_passes = ["chkvar_init_unicity"; "check_typing"] in let usage_msg = "Usage: main [-passes p1,...,pn] [-ast] [-verbose] [-debug] \ @@ -68,10 +68,11 @@ let _ = ("pre2vars", Passes.pre2vars); ("chkvar_init_unicity", Passes.chkvar_init_unicity); ("automata_translation", Passes.automata_translation_pass); - ("automata_validity", Passes.check_automata_validity); + ("automata_validity", Passes.check_automata_validity); ("linearization", Passes.pass_linearization); ("equations_ordering", Passes.pass_eq_reordering); ("check_typing", Passes.pass_typing); + ("clock_unification", Passes.clock_unification_pass); ]; (** Main functionality below *) diff --git a/src/passes.ml b/src/passes.ml index e8879d1..02b1501 100644 --- a/src/passes.ml +++ b/src/passes.ml @@ -89,7 +89,7 @@ let chkvar_init_unicity verbose debug main_fn : t_nodelist -> t_nodelist option let aux (node: t_node) : t_node option = let incr_aux h n = match Hashtbl.find_opt h n with - | None -> failwith "todo, should not happened." + | None -> raise (PassExn "todo, should not happened.") | Some num -> Hashtbl.replace h n (num + 1) in let incr_eq h (((_, patt), _): t_equation) = @@ -403,7 +403,7 @@ let check_automata_validity verbos debug main_fn = match init with | State(name, eqs, cond, next) -> init_left_side eqs; let validity = List.for_all (fun s -> (check_state s)) states in if not validity then - failwith "Automaton branch has different variables assignment in different branches" + raise (PassExn "Automaton branch has different variables assignment in different branches") end in let aux node = @@ -479,8 +479,21 @@ let automaton_translation debug automaton = ) in - let rec translate_var s v explist = match explist with - | [] -> EConst([TInt], CInt(0)) (* TODO *) + let default_constant ty = + let defaults ty = match ty with + | TInt -> EConst([ty], CInt(0)) + | TBool -> EConst([ty], CBool(false)) + | TReal -> EConst([ty], CReal(0.0)) + in + match ty with + | [TInt] -> EConst(ty, CInt(0)) + | [TBool] -> EConst(ty, CBool(false)) + | [TReal] -> EConst(ty, CReal(0.0)) + | _ -> ETuple(ty, List.map defaults ty) + in + + let rec translate_var s v explist ty = match explist with + | [] -> default_constant ty (* TODO *) | (state, exp)::q -> ETriOp(Utils.type_exp exp, TOp_if, EComp([TBool], COp_eq, @@ -488,7 +501,7 @@ let automaton_translation debug automaton = EConst([TInt], CInt(Hashtbl.find state_to_int state)) ), exp, - translate_var s v q + translate_var s v q ty ) in @@ -501,7 +514,7 @@ let automaton_translation debug automaton = init_state_translation states 1; let exp_transition = EBinOp([TInt], BOp_arrow, EConst([TInt], CInt(1)), EMonOp([TInt], MOp_pre, transition_eq states s)) in let new_equations = [(([TInt], [IVar(s)]), exp_transition)] in - Hashtbl.fold (fun var explist acc -> (var, translate_var s var explist)::acc) gathered new_equations, IVar(s) + Hashtbl.fold (fun var explist acc -> (var, translate_var s var explist (fst var))::acc) gathered new_equations, IVar(s) let automata_trans_pass debug (node:t_node) : t_node option= @@ -529,4 +542,106 @@ let automata_trans_pass debug (node:t_node) : t_node option= let automata_translation_pass verbose debug main_fn = node_pass (automata_trans_pass debug) +let clock_unification_pass verbose debug main_fn ast = + let failure str = raise (PassExn ("Failed to unify clocks: "^str)) in + + let known_clocks = Hashtbl.create 100 in + + let find_clock_var var = + match Hashtbl.find_opt known_clocks var with + | None -> + begin + match var with + | BVar(name) + | IVar(name) + | RVar(name) -> raise (PassExn ("Cannot find clock of variable "^name) ) + end + | Some c -> c + in + + let rec compute_clock_exp exp = match exp with + | EConst(_, _) -> [Base] + | EVar(_, var) -> find_clock_var var + | EMonOp(_, MOp_pre, _) -> [Base] + | EMonOp(_, _, e) -> compute_clock_exp e + + | EComp(_, _, e1, e2) + | EReset(_, e1, e2) + | EBinOp(_, _, e1, e2) -> + begin + let c1 = compute_clock_exp e1 + and c2 = compute_clock_exp e2 in + if c1 <> c2 then + failure "Binop" + else + c1 + end + | EWhen(_, e1, e2) -> + begin + match compute_clock_exp e1 with + | [c1] -> [On (c1, e2)] + | _ -> failure "When" + end + | ETriOp(_, TOp_merge, e1, e2, e3) -> + begin + let c1 = compute_clock_exp e1 + and c2 = compute_clock_exp e2 + and c3 = compute_clock_exp e3 in + match c1, c2, c3 with + | [c1], [On(cl2, e2)], [On(cl3, e3)] -> + begin + if cl2 <> c1 || cl3 <> c1 then + failure "Triop clocks" + else match e2, e3 with + | EMonOp(_, MOp_not, e), _ when e = e3 -> [c1] + | _, EMonOp(_, MOp_not, e) when e = e2 -> [c1] + | _ -> failure "Triop condition" + end + | _ -> failure ("Merge format") + end + | ETriOp(_, TOp_if, e1, e2, e3) -> + let c1 = compute_clock_exp e1 + and c2 = compute_clock_exp e2 + and c3 = compute_clock_exp e3 in + if c2 <> c3 then + failure "If clocks" + else c2 + + | ETuple(_, explist) -> List.concat_map compute_clock_exp explist + | EApp(_, node, e) -> + let rec aux_app clk_list = match clk_list with + | [] -> raise (PassExn "Node called with no argument provided") + | [cl] -> cl + | t::q -> if t = (aux_app q) then t else failure "App diff clocks" + and mult_clk cl out_list = match out_list with + | [] -> [] + | t::q -> cl::(mult_clk cl q) + in + mult_clk (aux_app (compute_clock_exp e)) (snd node.n_outputs) + in + + let rec compute_eq_clock eq = + let rec step vars clks = match vars, clks with + | [], [] -> () + | [], c::q -> raise (PassExn "Mismatch between clock size") + | v::t, c::q -> Hashtbl.replace known_clocks v [c]; step t q + | l, [] -> raise (PassExn "Mismatch between clock size") + in + let (_, vars), exp = eq in + let clk = compute_clock_exp exp in + step vars clk + in + + let compute_clock_node n = + begin + Hashtbl.clear known_clocks; + List.iter (fun v -> Hashtbl.replace known_clocks v [Base]) ( + snd n.n_inputs); (* Initializing inputs to base clock *) + List.iter compute_eq_clock n.n_equations; + if not (List.for_all (fun v -> (Hashtbl.find known_clocks v) = [Base]) ( + snd n.n_outputs)) then failure "Outputs" (*Checking that the node's output are on base clock *) + else + Some n + end + in node_pass compute_clock_node ast