From 3c7555168d39ea31041c4f06fe9bffe389b9622b Mon Sep 17 00:00:00 2001 From: Luisa Cicolini <48860705+luisacicolini@users.noreply.github.com> Date: Tue, 19 Nov 2024 23:04:14 +0000 Subject: [PATCH] feat: add `BitVec.(msb, getMsbD)_(rotateLeft, rotateRight)` (#6120) This PR adds theorems `BitVec.(getMsbD, msb)_(rotateLeft, rotateRight)`. We follow the same strategy taken for `getLsbD`, constructing the necessary auxilliary theorems first (relying on different hypotheses) and then generalizing. --------- Co-authored-by: Siddharth Co-authored-by: Tobias Grosser --- src/Init/Data/BitVec/Lemmas.lean | 106 ++++++++++++++++++++++++++++++- 1 file changed, 103 insertions(+), 3 deletions(-) diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 954c13a7e942..9cf09579bc98 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -2611,7 +2611,7 @@ theorem getLsbD_rotateLeftAux_of_geq {x : BitVec w} {r : Nat} {i : Nat} (hi : i apply getLsbD_ge omega -/-- When `r < w`, we give a formula for `(x.rotateRight r).getLsbD i`. -/ +/-- When `r < w`, we give a formula for `(x.rotateLeft r).getLsbD i`. -/ theorem getLsbD_rotateLeft_of_le {x : BitVec w} {r i : Nat} (hr: r < w) : (x.rotateLeft r).getLsbD i = cond (i < r) @@ -2638,6 +2638,56 @@ theorem getElem_rotateLeft {x : BitVec w} {r i : Nat} (h : i < w) : if h' : i < r % w then x[(w - (r % w) + i)] else x[i - (r % w)] := by simp [← BitVec.getLsbD_eq_getElem, h] +/-- If `w ≤ x < 2 * w`, then `x % w = x - w` -/ +theorem mod_eq_sub_of_le_of_lt {x w : Nat} (x_le : w ≤ x) (x_lt : x < 2 * w) : + x % w = x - w := by + rw [Nat.mod_eq_sub_mod, Nat.mod_eq_of_lt (by omega)] + omega + +theorem getMsbD_rotateLeftAux_of_lt {x : BitVec w} {r : Nat} {i : Nat} (hi : i < w - r) : + (x.rotateLeftAux r).getMsbD i = x.getMsbD (r + i) := by + rw [rotateLeftAux, getMsbD_or] + simp [show i < w - r by omega, Nat.add_comm] + +theorem getMsbD_rotateLeftAux_of_ge {x : BitVec w} {r : Nat} {i : Nat} (hi : i ≥ w - r) : + (x.rotateLeftAux r).getMsbD i = (decide (i < w) && x.getMsbD (i - (w - r))) := by + simp [rotateLeftAux, getMsbD_or, show i + r ≥ w by omega, show ¬i < w - r by omega] + +/-- When `r < w`, we give a formula for `(x.rotateLeft r).getMsbD i`. -/ +theorem getMsbD_rotateLeft_of_lt {n w : Nat} {x : BitVec w} (hi : r < w): + (x.rotateLeft r).getMsbD n = (decide (n < w) && x.getMsbD ((r + n) % w)) := by + rcases w with rfl | w + · simp + · rw [BitVec.rotateLeft_eq_rotateLeftAux_of_lt (by omega)] + by_cases h : n < (w + 1) - r + · simp [getMsbD_rotateLeftAux_of_lt h, Nat.mod_eq_of_lt, show r + n < (w + 1) by omega, show n < w + 1 by omega] + · simp [getMsbD_rotateLeftAux_of_ge <| Nat.ge_of_not_lt h] + by_cases h₁ : n < w + 1 + · simp only [h₁, decide_true, Bool.true_and] + have h₂ : (r + n) < 2 * (w + 1) := by omega + rw [mod_eq_sub_of_le_of_lt (by omega) (by omega)] + congr 1 + omega + · simp [h₁] + +theorem getMsbD_rotateLeft {r n w : Nat} {x : BitVec w} : + (x.rotateLeft r).getMsbD n = (decide (n < w) && x.getMsbD ((r + n) % w)) := by + rcases w with rfl | w + · simp + · by_cases h : r < w + · rw [getMsbD_rotateLeft_of_lt (by omega)] + · rw [← rotateLeft_mod_eq_rotateLeft, getMsbD_rotateLeft_of_lt (by apply Nat.mod_lt; simp)] + simp + +@[simp] +theorem msb_rotateLeft {m w : Nat} {x : BitVec w} : + (x.rotateLeft m).msb = x.getMsbD (m % w) := by + simp only [BitVec.msb, getMsbD_rotateLeft] + by_cases h : w = 0 + · simp [h] + · simp + omega + /-! ## Rotate Right -/ /-- @@ -2699,7 +2749,7 @@ theorem rotateRight_mod_eq_rotateRight {x : BitVec w} {r : Nat} : simp only [rotateRight, Nat.mod_mod] /-- When `r < w`, we give a formula for `(x.rotateRight r).getLsb i`. -/ -theorem getLsbD_rotateRight_of_le {x : BitVec w} {r i : Nat} (hr: r < w) : +theorem getLsbD_rotateRight_of_lt {x : BitVec w} {r i : Nat} (hr: r < w) : (x.rotateRight r).getLsbD i = cond (i < w - r) (x.getLsbD (r + i)) @@ -2717,7 +2767,7 @@ theorem getLsbD_rotateRight {x : BitVec w} {r i : Nat} : (decide (i < w) && x.getLsbD (i - (w - (r % w)))) := by rcases w with ⟨rfl, w⟩ · simp - · rw [← rotateRight_mod_eq_rotateRight, getLsbD_rotateRight_of_le (Nat.mod_lt _ (by omega))] + · rw [← rotateRight_mod_eq_rotateRight, getLsbD_rotateRight_of_lt (Nat.mod_lt _ (by omega))] @[simp] theorem getElem_rotateRight {x : BitVec w} {r i : Nat} (h : i < w) : @@ -2725,6 +2775,56 @@ theorem getElem_rotateRight {x : BitVec w} {r i : Nat} (h : i < w) : simp only [← BitVec.getLsbD_eq_getElem] simp [getLsbD_rotateRight, h] +theorem getMsbD_rotateRightAux_of_lt {x : BitVec w} {r : Nat} {i : Nat} (hi : i < r) : + (x.rotateRightAux r).getMsbD i = x.getMsbD (i + (w - r)) := by + rw [rotateRightAux, getMsbD_or, getMsbD_ushiftRight] + simp [show i < r by omega] + +theorem getMsbD_rotateRightAux_of_ge {x : BitVec w} {r : Nat} {i : Nat} (hi : i ≥ r) : + (x.rotateRightAux r).getMsbD i = (decide (i < w) && x.getMsbD (i - r)) := by + simp [rotateRightAux, show ¬ i < r by omega, show i + (w - r) ≥ w by omega] + +/-- When `m < w`, we give a formula for `(x.rotateLeft m).getMsbD i`. -/ +@[simp] +theorem getMsbD_rotateRight_of_lt {w n m : Nat} {x : BitVec w} (hr : m < w): + (x.rotateRight m).getMsbD n = (decide (n < w) && (if (n < m % w) + then x.getMsbD ((w + n - m % w) % w) else x.getMsbD (n - m % w))):= by + rcases w with rfl | w + · simp + · rw [rotateRight_eq_rotateRightAux_of_lt (by omega)] + by_cases h : n < m + · simp only [getMsbD_rotateRightAux_of_lt h, show n < w + 1 by omega, decide_true, + show m % (w + 1) = m by rw [Nat.mod_eq_of_lt hr], h, ↓reduceIte, + show (w + 1 + n - m) < (w + 1) by omega, Nat.mod_eq_of_lt, Bool.true_and] + congr 1 + omega + · simp [h, getMsbD_rotateRightAux_of_ge <| Nat.ge_of_not_lt h] + by_cases h₁ : n < w + 1 + · simp [h, h₁, decide_true, Bool.true_and, Nat.mod_eq_of_lt hr] + · simp [h₁] + +@[simp] +theorem getMsbD_rotateRight {w n m : Nat} {x : BitVec w} : + (x.rotateRight m).getMsbD n = (decide (n < w) && (if (n < m % w) + then x.getMsbD ((w + n - m % w) % w) else x.getMsbD (n - m % w))):= by + rcases w with rfl | w + · simp + · by_cases h₀ : m < w + · rw [getMsbD_rotateRight_of_lt (by omega)] + · rw [← rotateRight_mod_eq_rotateRight, getMsbD_rotateRight_of_lt (by apply Nat.mod_lt; simp)] + simp + +@[simp] +theorem msb_rotateRight {r w : Nat} {x : BitVec w} : + (x.rotateRight r).msb = x.getMsbD ((w - r % w) % w) := by + simp only [BitVec.msb, getMsbD_rotateRight] + by_cases h₀ : 0 < w + · simp only [h₀, decide_true, Nat.add_zero, Nat.zero_le, Nat.sub_eq_zero_of_le, Bool.true_and, + ite_eq_left_iff, Nat.not_lt, Nat.le_zero_eq] + intro h₁ + simp [h₁] + · simp [show w = 0 by omega] + /- ## twoPow -/ theorem twoPow_eq (w : Nat) (i : Nat) : twoPow w i = 1#w <<< i := by