remove Uvar constructor from detCheck by FissoreD · Pull Request #332 · LPCIC/elpi · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

remove Uvar constructor from detCheck #332

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 29, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 15 additions & 35 deletions src/compiler/determinacy_checker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ type dtype =
| Exp of dtype list (** -> for kinds like list, int, string *)
| BVar of F.t (** -> in predicates like: std.exists or in parametric type abbreviations. *)
| Arrow of Mode.t * Structured.variadic * dtype * dtype (** -> abstractions *)
| UVar of dtype
| Any
[@@deriving show, ord]

Expand All @@ -36,33 +35,31 @@ module Good_call : sig
For example, see test90.elpi, test91.elpi, test92.elpi,
test93.elpi, and test94.elpi.
*)
type offending_term = { exp : dtype; found : dtype; term : ScopedTerm.t; p : bool }
type offending_term = { exp : dtype; found : dtype; term : ScopedTerm.t }
type t [@@deriving show]

val init : unit -> t
val make : exp:dtype -> found:dtype -> p:bool -> ScopedTerm.t -> t
val make : exp:dtype -> found:dtype -> ScopedTerm.t -> t
val is_good : t -> bool
val is_wrong : t -> bool
val get : t -> offending_term
val set : t -> t -> unit
val set_wrong : t -> exp:dtype -> found:dtype -> p:bool -> ScopedTerm.t -> unit
val set_wrong : t -> exp:dtype -> found:dtype -> ScopedTerm.t -> unit
val set_good : t -> unit
val is_polymorphic : t -> bool
end = struct
type offending_term = { exp : dtype; found : dtype; term : ScopedTerm.t; p : bool }
type offending_term = { exp : dtype; found : dtype; term : ScopedTerm.t }
type t = offending_term option ref

let init () : t = ref None
let make ~exp ~found ~p term : t = ref @@ Some { exp; found; term; p }
let make ~exp ~found term : t = ref @@ Some { exp; found; term }
let is_good (x : t) = Option.is_none !x
let is_wrong (x : t) = Option.is_some !x
let get (x : t) = Option.get !x
let set (t1 : t) (t2 : t) = t1 := !t2
let set_wrong (t1 : t) ~exp ~found ~p term = t1 := Some { exp; found; term; p }
let set_wrong (t1 : t) ~exp ~found term = t1 := Some { exp; found; term }
let set_good (t : t) = t := None
let show (x : t) = match !x with None -> "true" | Some e -> Format.asprintf "false (%a)" Loc.pp e.term.loc
let pp fmt x = Format.fprintf fmt "%s" (show x)
let is_polymorphic = function {contents = Some {p}} -> p | _ -> false
end

exception DetError of (Scope.t ScopedTerm.ty_name option * Good_call.t)
Expand All @@ -73,7 +70,6 @@ exception KError of (Scope.t ScopedTerm.ty_name option * Good_call.t)
exception LoadFlexClause of ScopedTerm.t

let rec pp_dtype fmt = function
| UVar t -> Format.fprintf fmt "·%a" pp_dtype t
| Det -> Format.fprintf fmt "Det"
| Rel -> Format.fprintf fmt "Rel"
| BVar b -> Format.fprintf fmt "BVar %a" F.pp b
Expand All @@ -85,8 +81,7 @@ let rec pp_dtype fmt = function
type t = (TypeAssignment.skema * Loc.t) F.Map.t [@@deriving show, ord]

let arr m ~v a b = Arrow (m, v, a, b)
let rec is_exp = function Exp _ -> true | UVar e -> is_exp e | _ -> false
let is_uvar = function UVar _ -> true | _ -> false
let is_exp = function Exp _ -> true | _ -> false
let choose_variadic v full right = if v = Structured.Variadic then full else right

module Compilation = struct
Expand All @@ -112,7 +107,7 @@ module Compilation = struct
an example can be found in tests/sources/findall.elpi
*)
| Arr (MRef _, v, l, r) -> arr ~v Output (type_ass_2func ~loc l) (type_ass_2func ~loc r)
| UVar a -> if MutableOnce.is_set a then UVar (type_ass_2func ~loc (TypeAssignment.deref a)) else BVar (MutableOnce.get_name a)
| UVar a -> if MutableOnce.is_set a then (type_ass_2func ~loc (TypeAssignment.deref a)) else BVar (MutableOnce.get_name a)
in
type_ass_2func ~loc t

Expand All @@ -139,12 +134,8 @@ module Aux = struct
if f1 = d1 || f2 = d1 then d1
else
match (f1, f2) with
| UVar a, UVar b -> UVar (min_max ~positive ~loc ~d1 ~d2 a b)
| UVar a, b | a, UVar b -> min_max ~positive ~loc ~d1 ~d2 a b
| Det, Det -> Det
| Rel, Rel -> Rel
| Exp [UVar a], b -> min_max ~positive ~loc ~d1 ~d2 (Exp [a]) b
| a, Exp [UVar b] -> min_max ~positive ~loc ~d1 ~d2 a (Exp [b])
| a, (Any | BVar _) | (Any | BVar _), a -> a
| Exp [ ((Det | Rel | Exp _) as x) ], (Det | Rel) -> min_max ~positive ~loc ~d1 ~d2 x f2
| (Det | Rel), Exp [ ((Det | Rel | Exp _) as x) ] -> min_max ~positive ~loc ~d1 ~d2 f1 x
Expand All @@ -166,7 +157,6 @@ module Aux = struct

let rec minimize_maximize ~loc ~d1 ~d2 d =
match d with
| UVar v -> UVar (minimize_maximize ~loc ~d1 ~d2 v)
| Det | Rel -> d1
| Exp l -> Exp (List.map (minimize_maximize ~loc ~d1 ~d2) l)
| Arrow (Input, v, l, r) ->
Expand All @@ -190,14 +180,11 @@ module Aux = struct
let rec choose_dir ~loc t1 t2 = function Mode.Input -> aux ~loc t2 t1 | Mode.Output -> aux ~loc t1 t2
and aux ~loc a b =
match (a, b) with
| UVar a, b | a, UVar b -> aux a b ~loc
| _, Any -> true
| Any, _ -> b = maximize ~loc b (* TC may accept A = any, so we do too *)
| BVar v1, BVar v2 -> F.equal v1 v2 || wrong_bvars ~loc v1 v2
| BVar _, _ | _, BVar _ -> wrong_type ~loc a b
| Exp l1, Exp l2 -> ( try List.for_all2 (aux ~loc) l1 l2 with Invalid_argument _ -> wrong_type ~loc a b)
| Exp [UVar a], b -> aux ~loc (Exp [a]) b
| a, Exp [UVar b] -> aux ~loc a (Exp [b])
| Exp [ ((Det | Rel | Exp _) as x) ], (Det | Rel) -> aux ~loc x b
| (Det | Rel), Exp [ ((Det | Rel | Exp _) as x) ] -> aux ~loc a x
| Arrow (m1, NotVariadic, l1, r1), Arrow (_, NotVariadic, l2, r2) -> choose_dir ~loc l1 l2 m1 && aux r1 r2 ~loc
Expand Down Expand Up @@ -326,7 +313,7 @@ let check_clause ~type_abbrevs:ta ~types ~unknown (t : ScopedTerm.t) : unit =
incr cnt;
F.from_string ("*dummy" ^ string_of_int !cnt)
in
let rec get_tl = function Arrow (_, _, _, r) -> get_tl r | UVar e -> get_tl e | e -> e in
let rec get_tl = function Arrow (_, _, _, r) -> get_tl r | e -> e in

let is_cut ScopedTerm.{ it } = match it with Const b -> is_global b S.cut | _ -> false in

Expand All @@ -344,7 +331,6 @@ let check_clause ~type_abbrevs:ta ~types ~unknown (t : ScopedTerm.t) : unit =
Format.eprintf "@[<hov 2>In recursive call for infer.aux with head-term@ @[%a@], good_call is %a -- and dtype@ @[%a@] and user_dytpe is @[%a@]@]@." (Format.pp_print_option ScopedTerm.pretty) (List.nth_opt tl 0)
Good_call.pp b pp_dtype d (Format.pp_print_option pp_dtype) user_dtype;
match (d, tl) with
| UVar v, tl -> aux ~user_dtype v tl
| Arrow (_, Variadic, _, t), [] -> (t, b)
| t, [] -> (t, b)
| Arrow (Input, v, l, r), h :: tl ->
Expand All @@ -366,10 +352,10 @@ let check_clause ~type_abbrevs:ta ~types ~unknown (t : ScopedTerm.t) : unit =
else if not ((max_exp <<= l) ~loc) then Good_call.set b b'
end
else if not ((dy <<= l_user) ~loc) then
raise (KError (Some hd, (Good_call.set_wrong ~p:(is_uvar l) b ~exp:l_user ~found:dy h; b)))
raise (KError (Some hd, (Good_call.set_wrong b ~exp:l_user ~found:dy h; b)))
else if not ((dy <<= l) ~loc) then (
(* If preconditions are not satisfied, we stop and return bottom *)
Good_call.set_wrong ~p:(is_uvar l) b ~exp:l ~found:dy h;
Good_call.set_wrong b ~exp:l ~found:dy h;
Format.eprintf "Invalid determinacy set b to wrong (%a)@." Good_call.pp b))

(* if Good_call.is_wrong b' then(
Expand Down Expand Up @@ -455,7 +441,6 @@ let check_clause ~type_abbrevs:ta ~types ~unknown (t : ScopedTerm.t) : unit =
Format.eprintf "Calling deduce output on %a@." ScopedTerm.pretty_ it;
let rec aux d args =
match (d, args) with
| UVar v, _ -> aux v args
| Arrow (Input, v, _, r), _ :: tl -> aux (choose_variadic v d r) tl
| Arrow (Output, v, l, r), hd :: tl ->

Expand All @@ -464,7 +449,7 @@ let check_clause ~type_abbrevs:ta ~types ~unknown (t : ScopedTerm.t) : unit =

if Good_call.is_wrong gc && Aux.is_maximized ~loc l then aux (choose_variadic v d r) tl
else if Good_call.is_good gc && (det <<= l) ~loc then aux (choose_variadic v d r) tl
else if Good_call.is_good gc then raise (DetError (Some pred_name, Good_call.make ~p:false ~exp:l ~found:det hd))
else if Good_call.is_good gc then raise (DetError (Some pred_name, Good_call.make ~exp:l ~found:det hd))
else raise (DetError (Some pred_name, gc))
| _ -> ()
in
Expand All @@ -484,7 +469,6 @@ let check_clause ~type_abbrevs:ta ~types ~unknown (t : ScopedTerm.t) : unit =
in
let rec assume_fold ~was_input ~was_data ~loc ctx d (tl : ScopedTerm.t list) : unit =
match (d, tl) with
| UVar v, tl -> assume_fold ~was_input ~was_data ~loc ctx v tl
| _, [] -> ()
| Arrow (Input, v, l, r), h :: tl ->
assume ~was_input:true ctx l h;
Expand All @@ -510,8 +494,7 @@ let check_clause ~type_abbrevs:ta ~types ~unknown (t : ScopedTerm.t) : unit =
and assume_var ~is_var ~ctx ~loc d ((_,name,_) as s) tl =
let rec replace_signature_tgt ~with_ d' = function
| [] -> with_
| (_::xs as l) -> match d' with
| UVar v -> UVar (replace_signature_tgt ~with_ v l)
| _::xs -> match d' with
| Arrow (_, Variadic, _, _) -> replace_signature_tgt ~with_ d' xs
| Arrow (m, NotVariadic, l, r) -> Arrow (m, NotVariadic, l, replace_signature_tgt ~with_ r xs)
| _ -> error ~loc @@ Format.asprintf "replace_signature_tgt: Type error: found %a " pp_dtype d' in
Expand Down Expand Up @@ -547,7 +530,6 @@ let check_clause ~type_abbrevs:ta ~types ~unknown (t : ScopedTerm.t) : unit =
| Impl (R2L, _, _B) -> ()
| Lam (oname, _, c) -> (
match d with
| UVar b -> assume ~was_input ctx b c
| Arrow (Input, NotVariadic, l, r) ->
let ctx = BVar.add_oname ~new_:true ~loc oname (fun _ -> l) ctx in
assume ~was_input ctx r c
Expand All @@ -565,7 +547,6 @@ let check_clause ~type_abbrevs:ta ~types ~unknown (t : ScopedTerm.t) : unit =
and assume_output ~ctx ~var d tl : Uvar.t =
let rec assume_output d args var =
match (d, args) with
| UVar v, _ -> assume_output v args var
| Arrow (Input, v, _, r), _ :: tl -> assume_output (choose_variadic v d r) tl var
| Arrow (Output, v, l, r), hd :: tl ->
Format.eprintf "Call assume of %a with dtype:%a@." ScopedTerm.pretty hd pp_dtype l;
Expand All @@ -577,7 +558,6 @@ let check_clause ~type_abbrevs:ta ~types ~unknown (t : ScopedTerm.t) : unit =
and assume_input ~ctx ~var d tl : Uvar.t =
let rec assume_input d args var =
match (d, args) with
| UVar d, _ -> assume_input d args var
| Arrow (Input, v, l, r), hd :: tl ->
Format.eprintf "Call assume of %a with dtype:%a@." ScopedTerm.pretty hd pp_dtype l;
let var = assume ~was_input:true ~ctx ~var l hd in
Expand Down Expand Up @@ -649,7 +629,7 @@ let check_clause ~type_abbrevs:ta ~types ~unknown (t : ScopedTerm.t) : unit =
try
let d, _ = check ~ctx d b in
let d' = Compilation.type_ass_2func_mut ~loc ta b.ty in
if not ((d <<= d') ~loc) then raise (CastError (None, Good_call.make ~p:false ~exp:d' ~found:d b));
if not ((d <<= d') ~loc) then raise (CastError (None, Good_call.make ~exp:d' ~found:d b));
(d, t)
with DetError x -> raise (FatalDetError x))
| Spill _ -> spill_err ~loc
Expand Down Expand Up @@ -761,7 +741,7 @@ let check_clause ~type_abbrevs:ta ~types ~unknown (t : ScopedTerm.t) : unit =

let det_pred = get_tl det_hd in
if not @@ (det_body <<= det_pred) ~loc then
raise (RelationalBody (Some pred_name, Good_call.make ~p:false ~exp:det_pred ~found:det_body err_atom));
raise (RelationalBody (Some pred_name, Good_call.make ~exp:det_pred ~found:det_body err_atom));
Format.eprintf "** Start checking outputs@.";
infer_output ~pred_name ~ctx:!ctx ~var hd;
det_pred
Expand Down
Loading
0