From 918924c16b82b304857f07dfb784e7e25ac5c7a8 Mon Sep 17 00:00:00 2001 From: Alex Keizer Date: Fri, 10 Jan 2025 23:23:58 +0000 Subject: [PATCH] feat: `BitVec.{toFin, toInt, msb}_umod` (#6404) This PR adds a `toFin` and `msb` lemma for unsigned bitvector modulus. Similar to #6402, we don't provide a general `toInt_umod` lemmas, but instead choose to provide more specialized rewrites, with extra side-conditions. --------- Co-authored-by: Kim Morrison --- src/Init/Data/BitVec/Lemmas.lean | 60 +++++++++++++++++++++++++++++++ src/Init/Data/Nat/Div/Lemmas.lean | 5 +++ 2 files changed, 65 insertions(+) diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 8b49a5d8beb0..377ab5cc885f 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -11,6 +11,7 @@ import Init.Data.Fin.Lemmas import Init.Data.Nat.Lemmas import Init.Data.Nat.Div.Lemmas import Init.Data.Nat.Mod +import Init.Data.Nat.Div.Lemmas import Init.Data.Int.Bitwise.Lemmas import Init.Data.Int.Pow @@ -99,6 +100,12 @@ theorem ofFin_eq_ofNat : @BitVec.ofFin w (Fin.mk x lt) = BitVec.ofNat w x := by theorem eq_of_toNat_eq {n} : ∀ {x y : BitVec n}, x.toNat = y.toNat → x = y | ⟨_, _⟩, ⟨_, _⟩, rfl => rfl +/-- Prove nonequality of bitvectors in terms of nat operations. -/ +theorem toNat_ne_iff_ne {n} {x y : BitVec n} : x.toNat ≠ y.toNat ↔ x ≠ y := by + constructor + · rintro h rfl; apply h rfl + · intro h h_eq; apply h <| eq_of_toNat_eq h_eq + @[simp] theorem val_toFin (x : BitVec w) : x.toFin.val = x.toNat := rfl @[bv_toNat] theorem toNat_eq {x y : BitVec n} : x = y ↔ x.toNat = y.toNat := @@ -2693,6 +2700,10 @@ theorem umod_def {x y : BitVec n} : theorem toNat_umod {x y : BitVec n} : (x % y).toNat = x.toNat % y.toNat := rfl +@[simp] +theorem toFin_umod {x y : BitVec w} : + (x % y).toFin = x.toFin % y.toFin := rfl + @[simp] theorem umod_zero {x : BitVec n} : x % 0#n = x := by simp [umod_def] @@ -2720,6 +2731,55 @@ theorem umod_eq_and {x y : BitVec 1} : x % y = x &&& (~~~y) := by rcases hy with rfl | rfl <;> rfl +theorem umod_eq_of_lt {x y : BitVec w} (h : x < y) : + x % y = x := by + apply eq_of_toNat_eq + simp [Nat.mod_eq_of_lt h] + +@[simp] +theorem msb_umod {x y : BitVec w} : + (x % y).msb = (x.msb && (x < y || y == 0#w)) := by + rw [msb_eq_decide, toNat_umod] + cases msb_x : x.msb + · suffices x.toNat % y.toNat < 2 ^ (w - 1) by simpa + calc + x.toNat % y.toNat ≤ x.toNat := by apply Nat.mod_le + _ < 2 ^ (w - 1) := by simpa [msb_eq_decide] using msb_x + . by_cases hy : y = 0 + · simp_all [msb_eq_decide] + · suffices 2 ^ (w - 1) ≤ x.toNat % y.toNat ↔ x < y by simp_all + by_cases x_lt_y : x < y + . simp_all [Nat.mod_eq_of_lt x_lt_y, msb_eq_decide] + · suffices x.toNat % y.toNat < 2 ^ (w - 1) by + simpa [x_lt_y] + have y_le_x : y.toNat ≤ x.toNat := by + simpa using x_lt_y + replace hy : y.toNat ≠ 0 := + toNat_ne_iff_ne.mpr hy + by_cases msb_y : y.toNat < 2 ^ (w - 1) + · have : x.toNat % y.toNat < y.toNat := Nat.mod_lt _ (by omega) + omega + · rcases w with _|w + · contradiction + simp only [Nat.add_one_sub_one] + replace msb_y : 2 ^ w ≤ y.toNat := by + simpa using msb_y + have : y.toNat ≤ y.toNat * (x.toNat / y.toNat) := by + apply Nat.le_mul_of_pos_right + apply Nat.div_pos y_le_x + omega + have : x.toNat % y.toNat ≤ x.toNat - y.toNat := by + rw [Nat.mod_eq_sub]; omega + omega + +theorem toInt_umod {x y : BitVec w} : + (x % y).toInt = (x.toNat % y.toNat : Int).bmod (2 ^ w) := by + simp [toInt_eq_toNat_bmod] + +theorem toInt_umod_of_msb {x y : BitVec w} (h : x.msb = false) : + (x % y).toInt = x.toInt % y.toNat := by + simp [toInt_eq_msb_cond, h] + /-! ### smtUDiv -/ theorem smtUDiv_eq (x y : BitVec w) : smtUDiv x y = if y = 0#w then allOnes w else x / y := by diff --git a/src/Init/Data/Nat/Div/Lemmas.lean b/src/Init/Data/Nat/Div/Lemmas.lean index d5a26a0950ba..cf312c6a0fff 100644 --- a/src/Init/Data/Nat/Div/Lemmas.lean +++ b/src/Init/Data/Nat/Div/Lemmas.lean @@ -49,6 +49,11 @@ theorem lt_div_mul_self (h : 0 < k) (w : k ≤ x) : x - k < x / k * k := by have : x % k < k := mod_lt x h omega +theorem div_pos (hba : b ≤ a) (hb : 0 < b) : 0 < a / b := by + cases b + · contradiction + · simp [Nat.pos_iff_ne_zero, div_eq_zero_iff_lt, hba] + theorem div_le_div_left (hcb : c ≤ b) (hc : 0 < c) : a / b ≤ a / c := (Nat.le_div_iff_mul_le hc).2 <| Nat.le_trans (Nat.mul_le_mul_left _ hcb) (Nat.div_mul_le_self a b)