Skip to content

Commit

Permalink
TC: Add context to failed constraint errors
Browse files Browse the repository at this point in the history
As part of this ensure that location information is correctly
propagated through type-synonym expansion
  • Loading branch information
Alasdair committed Jan 6, 2025
1 parent 1789ad8 commit c2fe0d4
Show file tree
Hide file tree
Showing 12 changed files with 116 additions and 57 deletions.
99 changes: 51 additions & 48 deletions src/lib/ast_util.ml
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ let exp_loc = function E_aux (_, (l, _)) -> l

let nexp_loc = function Nexp_aux (_, l) -> l

let constraint_loc = function NC_aux (_, l) -> l

let gen_loc = function Parse_ast.Generated l -> Parse_ast.Generated l | l -> Parse_ast.Generated l

let rec is_gen_loc = function
Expand Down Expand Up @@ -2089,63 +2091,64 @@ let extern_assoc backend ext =
(* 1. Substitutions *)
(**************************************************************************)

let rec nexp_subst sv subst = function
| Nexp_aux (Nexp_var kid, _) as nexp -> begin
match subst with A_aux (A_nexp n, _) when Kid.compare kid sv = 0 -> n | _ -> nexp
end
| Nexp_aux (nexp, l) -> Nexp_aux (nexp_subst_aux sv subst nexp, l)
let mk_subst_arg = function
| A_typ typ -> A_aux (A_typ typ, typ_loc typ)
| A_nexp n -> A_aux (A_nexp n, nexp_loc n)
| A_bool b -> A_aux (A_bool b, constraint_loc b)

and nexp_subst_aux sv subst = function
let rec nexp_subst sv subst (Nexp_aux (n, l)) =
let wrap aux = Nexp_aux (aux, l) in
match n with
| Nexp_var kid -> begin
match subst with A_aux (A_nexp n, _) when Kid.compare kid sv = 0 -> unaux_nexp n | _ -> Nexp_var kid
match subst with A_aux (A_nexp n, _) when Kid.compare kid sv = 0 -> n | _ -> wrap (Nexp_var kid)
end
| Nexp_id id -> Nexp_id id
| Nexp_constant c -> Nexp_constant c
| Nexp_times (nexp1, nexp2) -> Nexp_times (nexp_subst sv subst nexp1, nexp_subst sv subst nexp2)
| Nexp_sum (nexp1, nexp2) -> Nexp_sum (nexp_subst sv subst nexp1, nexp_subst sv subst nexp2)
| Nexp_minus (nexp1, nexp2) -> Nexp_minus (nexp_subst sv subst nexp1, nexp_subst sv subst nexp2)
| Nexp_app (id, nexps) -> Nexp_app (id, List.map (nexp_subst sv subst) nexps)
| Nexp_exp nexp -> Nexp_exp (nexp_subst sv subst nexp)
| Nexp_neg nexp -> Nexp_neg (nexp_subst sv subst nexp)
| Nexp_if (i, t, e) -> Nexp_if (constraint_subst sv subst i, nexp_subst sv subst t, nexp_subst sv subst e)

and constraint_subst sv subst (NC_aux (nc, l)) = NC_aux (constraint_subst_aux l sv subst nc, l)

and constraint_subst_aux l sv subst = function
| NC_id id -> NC_id id
| NC_equal (arg1, arg2) -> NC_equal (typ_arg_subst sv subst arg1, typ_arg_subst sv subst arg2)
| NC_not_equal (arg1, arg2) -> NC_not_equal (typ_arg_subst sv subst arg1, typ_arg_subst sv subst arg2)
| NC_ge (n1, n2) -> NC_ge (nexp_subst sv subst n1, nexp_subst sv subst n2)
| NC_gt (n1, n2) -> NC_gt (nexp_subst sv subst n1, nexp_subst sv subst n2)
| NC_le (n1, n2) -> NC_le (nexp_subst sv subst n1, nexp_subst sv subst n2)
| NC_lt (n1, n2) -> NC_lt (nexp_subst sv subst n1, nexp_subst sv subst n2)
| NC_set (n, ints) -> NC_set (nexp_subst sv subst n, ints)
| NC_or (nc1, nc2) -> NC_or (constraint_subst sv subst nc1, constraint_subst sv subst nc2)
| NC_and (nc1, nc2) -> NC_and (constraint_subst sv subst nc1, constraint_subst sv subst nc2)
| NC_app (id, args) -> NC_app (id, List.map (typ_arg_subst sv subst) args)
| Nexp_id id -> wrap (Nexp_id id)
| Nexp_constant c -> wrap (Nexp_constant c)
| Nexp_times (nexp1, nexp2) -> wrap (Nexp_times (nexp_subst sv subst nexp1, nexp_subst sv subst nexp2))
| Nexp_sum (nexp1, nexp2) -> wrap (Nexp_sum (nexp_subst sv subst nexp1, nexp_subst sv subst nexp2))
| Nexp_minus (nexp1, nexp2) -> wrap (Nexp_minus (nexp_subst sv subst nexp1, nexp_subst sv subst nexp2))
| Nexp_app (id, nexps) -> wrap (Nexp_app (id, List.map (nexp_subst sv subst) nexps))
| Nexp_exp nexp -> wrap (Nexp_exp (nexp_subst sv subst nexp))
| Nexp_neg nexp -> wrap (Nexp_neg (nexp_subst sv subst nexp))
| Nexp_if (i, t, e) -> wrap (Nexp_if (constraint_subst sv subst i, nexp_subst sv subst t, nexp_subst sv subst e))

and constraint_subst sv subst (NC_aux (nc, l)) =
let wrap aux = NC_aux (aux, l) in
match nc with
| NC_id id -> wrap (NC_id id)
| NC_equal (arg1, arg2) -> wrap (NC_equal (typ_arg_subst sv subst arg1, typ_arg_subst sv subst arg2))
| NC_not_equal (arg1, arg2) -> wrap (NC_not_equal (typ_arg_subst sv subst arg1, typ_arg_subst sv subst arg2))
| NC_ge (n1, n2) -> wrap (NC_ge (nexp_subst sv subst n1, nexp_subst sv subst n2))
| NC_gt (n1, n2) -> wrap (NC_gt (nexp_subst sv subst n1, nexp_subst sv subst n2))
| NC_le (n1, n2) -> wrap (NC_le (nexp_subst sv subst n1, nexp_subst sv subst n2))
| NC_lt (n1, n2) -> wrap (NC_lt (nexp_subst sv subst n1, nexp_subst sv subst n2))
| NC_set (n, ints) -> wrap (NC_set (nexp_subst sv subst n, ints))
| NC_or (nc1, nc2) -> wrap (NC_or (constraint_subst sv subst nc1, constraint_subst sv subst nc2))
| NC_and (nc1, nc2) -> wrap (NC_and (constraint_subst sv subst nc1, constraint_subst sv subst nc2))
| NC_app (id, args) -> wrap (NC_app (id, List.map (typ_arg_subst sv subst) args))
| NC_var kid -> begin
match subst with A_aux (A_bool nc, _) when Kid.compare kid sv = 0 -> unaux_constraint nc | _ -> NC_var kid
match subst with A_aux (A_bool nc, _) when Kid.compare kid sv = 0 -> nc | _ -> wrap (NC_var kid)
end
| NC_false -> NC_false
| NC_true -> NC_true

and typ_subst sv subst (Typ_aux (typ, l)) = Typ_aux (typ_subst_aux sv subst typ, l)

and typ_subst_aux sv subst = function
| Typ_internal_unknown -> Typ_internal_unknown
| Typ_id v -> Typ_id v
| NC_false -> wrap NC_false
| NC_true -> wrap NC_true

and typ_subst sv subst (Typ_aux (typ, l)) =
let wrap aux = Typ_aux (aux, l) in
match typ with
| Typ_internal_unknown -> wrap Typ_internal_unknown
| Typ_id v -> wrap (Typ_id v)
| Typ_var kid -> begin
match subst with A_aux (A_typ typ, _) when Kid.compare kid sv = 0 -> unaux_typ typ | _ -> Typ_var kid
match subst with A_aux (A_typ typ, _) when Kid.compare kid sv = 0 -> typ | _ -> wrap (Typ_var kid)
end
| Typ_fn (arg_typs, ret_typ) -> Typ_fn (List.map (typ_subst sv subst) arg_typs, typ_subst sv subst ret_typ)
| Typ_bidir (typ1, typ2) -> Typ_bidir (typ_subst sv subst typ1, typ_subst sv subst typ2)
| Typ_tuple typs -> Typ_tuple (List.map (typ_subst sv subst) typs)
| Typ_app (f, args) -> Typ_app (f, List.map (typ_arg_subst sv subst) args)
| Typ_fn (arg_typs, ret_typ) -> wrap (Typ_fn (List.map (typ_subst sv subst) arg_typs, typ_subst sv subst ret_typ))
| Typ_bidir (typ1, typ2) -> wrap (Typ_bidir (typ_subst sv subst typ1, typ_subst sv subst typ2))
| Typ_tuple typs -> wrap (Typ_tuple (List.map (typ_subst sv subst) typs))
| Typ_app (f, args) -> wrap (Typ_app (f, List.map (typ_arg_subst sv subst) args))
| Typ_exist (kopts, nc, typ) when KidSet.mem sv (KidSet.of_list (List.map kopt_kid kopts)) ->
Typ_exist (kopts, nc, typ)
| Typ_exist (kopts, nc, typ) -> Typ_exist (kopts, constraint_subst sv subst nc, typ_subst sv subst typ)
wrap (Typ_exist (kopts, nc, typ))
| Typ_exist (kopts, nc, typ) -> wrap (Typ_exist (kopts, constraint_subst sv subst nc, typ_subst sv subst typ))

and typ_arg_subst sv subst (A_aux (arg, l)) = A_aux (typ_arg_subst_aux sv subst arg, l)
and typ_arg_subst sv subst (A_aux (arg, _)) = mk_subst_arg (typ_arg_subst_aux sv subst arg)

and typ_arg_subst_aux sv subst = function
| A_nexp nexp -> A_nexp (nexp_subst sv subst nexp)
Expand Down
1 change: 1 addition & 0 deletions src/lib/ast_util.mli
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ val pat_loc : 'a pat -> Parse_ast.l
val mpat_loc : 'a mpat -> Parse_ast.l
val exp_loc : 'a exp -> Parse_ast.l
val nexp_loc : nexp -> Parse_ast.l
val constraint_loc : n_constraint -> Parse_ast.l
val def_loc : ('a, 'b) def -> Parse_ast.l

(** {1 Printing utilities}
Expand Down
5 changes: 5 additions & 0 deletions src/lib/reporting.ml
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ let rec simp_loc = function
| Parse_ast.Hint (_, l1, l2) -> begin match simp_loc l1 with None -> simp_loc l2 | pos -> pos end
| Parse_ast.Range (p1, p2) -> Some (p1, p2)

let rec is_unknown_loc = function
| Parse_ast.Unknown -> true
| Parse_ast.Range _ -> false
| Parse_ast.Generated l | Parse_ast.Unique (_, l) | Parse_ast.Hint (_, _, l) -> is_unknown_loc l

let loc_range_to_src (p1 : Lexing.position) (p2 : Lexing.position) =
(fun contents -> String.sub contents p1.pos_cnum (p2.pos_cnum - p1.pos_cnum)) (Util.read_whole_file p1.pos_fname)

Expand Down
4 changes: 4 additions & 0 deletions src/lib/reporting.mli
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ val loc_file : Parse_ast.l -> string option
(** Reduce a location to a pair of positions if possible *)
val simp_loc : Ast.l -> (Lexing.position * Lexing.position) option

(** Returns true if a loc is [Unknown]. In the case of [Hint] location
only checks the base location. *)
val is_unknown_loc : Ast.l -> bool

(** [loc_range_to_src] returns the source code text of a range **)
val loc_range_to_src : Lexing.position -> Lexing.position -> string

Expand Down
18 changes: 13 additions & 5 deletions src/lib/type_check.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1079,11 +1079,17 @@ let rec subtyp l env typ1 typ2 =
)
)

and subtyp_arg l env (A_aux (aux1, _) as arg1) (A_aux (aux2, _) as arg2) =
and subtyp_arg l env (A_aux (aux1, arg_l1) as arg1) (A_aux (aux2, arg_l2) as arg2) =
typ_print
(lazy (("Subtype arg " |> Util.green |> Util.clear) ^ string_of_typ_arg arg1 ^ " and " ^ string_of_typ_arg arg2));
let raise_failed_constraint nc =
typ_raise l (Err_failed_constraint (nc, Env.get_locals env, Env.get_typ_vars_info env, Env.get_constraints env))
(* If we don't have precise locations for both arguments, then
don't try to use an argument location as the base location for
the type error, as there are a few confusing corner cases. *)
let l = if Reporting.is_unknown_loc arg_l1 || Reporting.is_unknown_loc arg_l2 then l else arg_l2 in
let derived_from = if Reporting.is_unknown_loc arg_l1 then None else Some arg_l1 in
typ_raise l
(Err_failed_constraint (nc, derived_from, Env.get_locals env, Env.get_typ_vars_info env, Env.get_constraints env))
in
match (aux1, aux2) with
| A_nexp n1, A_nexp n2 ->
Expand Down Expand Up @@ -3339,7 +3345,9 @@ and infer_lexp env (LE_aux (lexp_aux, (l, uannot)) as lexp) =
annot_lexp (LE_vector_range (inferred_v_lexp, inferred_exp1, inferred_exp2)) (bitvector_typ slice_len)
else
typ_raise l
(Err_failed_constraint (check, Env.get_locals env, Env.get_typ_vars_info env, Env.get_constraints env))
(Err_failed_constraint
(check, None, Env.get_locals env, Env.get_typ_vars_info env, Env.get_constraints env)
)
| _ -> typ_error l "Cannot assign slice of non vector type"
end
| LE_vector (v_lexp, exp) -> begin
Expand All @@ -3355,7 +3363,7 @@ and infer_lexp env (LE_aux (lexp_aux, (l, uannot)) as lexp) =
else
typ_raise l
(Err_failed_constraint
(bounds_check, Env.get_locals env, Env.get_typ_vars_info env, Env.get_constraints env)
(bounds_check, None, Env.get_locals env, Env.get_typ_vars_info env, Env.get_constraints env)
)
| Typ_app (id, [A_aux (A_nexp len, _)]) when Id.compare id (mk_id "bitvector") = 0 ->
let inferred_exp = infer_exp env exp in
Expand All @@ -3366,7 +3374,7 @@ and infer_lexp env (LE_aux (lexp_aux, (l, uannot)) as lexp) =
else
typ_raise l
(Err_failed_constraint
(bounds_check, Env.get_locals env, Env.get_typ_vars_info env, Env.get_constraints env)
(bounds_check, None, Env.get_locals env, Env.get_typ_vars_info env, Env.get_constraints env)
)
| Typ_id id -> begin
match exp with
Expand Down
12 changes: 10 additions & 2 deletions src/lib/type_error.ml
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,13 @@ let message_of_type_error type_error =
],
None
)
| Err_failed_constraint (check, locals, _, ncs) ->
| Err_failed_constraint (check, derived_from, locals, _, ncs) ->
let simplified = constraint_simp check in
let add_derivation msg =
match derived_from with
| Some l -> Seq [msg; Line ""; Location ("constraint from ", Some "This type argument", l, Seq [])]
| None -> msg
in
begin
match simplified with
| NC_aux (NC_false, _) ->
Expand All @@ -381,7 +386,10 @@ let message_of_type_error type_error =
),
None
)
| _ -> (Line ("Failed to prove constraint: " ^ string_of_n_constraint (constraint_simp check)), None)
| _ ->
( Line ("Failed to prove constraint: " ^ string_of_n_constraint (constraint_simp check)) |> add_derivation,
None
)
end
| Err_subtype (typ1, typ2, nc, all_constraints, tyvars) ->
let nc = Option.map constraint_simp nc in
Expand Down
3 changes: 2 additions & 1 deletion src/lib/type_error.mli
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ type constraint_reason = (l * string) option
type type_error =
| Err_no_overloading of id * (id * Parse_ast.l * type_error) list
| Err_unresolved_quants of id * quant_item list * (mut * typ) Bindings.t * type_variables * n_constraint list
| Err_failed_constraint of n_constraint * (mut * typ) Bindings.t * type_variables * n_constraint list
| Err_failed_constraint of
n_constraint * Parse_ast.l option * (mut * typ) Bindings.t * type_variables * n_constraint list
| Err_subtype of typ * typ * n_constraint option * (constraint_reason * n_constraint) list * type_variables
| Err_no_num_ident of id
| Err_other of string
Expand Down
3 changes: 2 additions & 1 deletion src/lib/type_internal.ml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ type type_variables = { vars : (Ast.l * kind_aux) KBindings.t; shadows : int KBi
type type_error =
| Err_no_overloading of id * (id * Parse_ast.l * type_error) list
| Err_unresolved_quants of id * quant_item list * (mut * typ) Bindings.t * type_variables * n_constraint list
| Err_failed_constraint of n_constraint * (mut * typ) Bindings.t * type_variables * n_constraint list
| Err_failed_constraint of
n_constraint * Parse_ast.l option * (mut * typ) Bindings.t * type_variables * n_constraint list
| Err_subtype of typ * typ * n_constraint option * (constraint_reason * n_constraint) list * type_variables
| Err_no_num_ident of id
| Err_other of string
Expand Down
9 changes: 9 additions & 0 deletions test/typecheck/fail/issue853.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Type error:
fail/issue853.sail:4.30-32:
4 |function foo(addr : bitvector(64)) -> unit = ()
 | ^^
 | Failed to prove constraint: 32 == 64
 |
 | constraint from fail/issue853.sail:2.20-22:
 | 2 |val foo : bitvector(32) -> unit
 |  | ^^ This type argument
4 changes: 4 additions & 0 deletions test/typecheck/fail/issue853.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

val foo : bitvector(32) -> unit

function foo(addr : bitvector(64)) -> unit = ()
9 changes: 9 additions & 0 deletions test/typecheck/fail/issue853_2.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Type error:
fail/issue853_2.sail:6.25-27:
6 |function foo(addr : bits(64)) -> unit = ()
 | ^^
 | Failed to prove constraint: 32 == 64
 |
 | constraint from fail/issue853_2.sail:4.20-22:
 | 4 |val foo : bitvector(32) -> unit
 |  | ^^ This type argument
6 changes: 6 additions & 0 deletions test/typecheck/fail/issue853_2.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

type bits('n) = bitvector('n)

val foo : bitvector(32) -> unit

function foo(addr : bits(64)) -> unit = ()

0 comments on commit c2fe0d4

Please sign in to comment.