Skip to content

Commit

Permalink
feat: add BitVec.(msb, getMsbD)_(rotateLeft, rotateRight) (#6120)
Browse files Browse the repository at this point in the history
This PR adds theorems `BitVec.(getMsbD, msb)_(rotateLeft, rotateRight)`.

We follow the same strategy taken for `getLsbD`, constructing the
necessary auxilliary theorems first (relying on different hypotheses)
and then generalizing.

---------

Co-authored-by: Siddharth <[email protected]>
Co-authored-by: Tobias Grosser <[email protected]>
  • Loading branch information
3 people authored Nov 19, 2024
1 parent 5eef3d2 commit 3c75551
Showing 1 changed file with 103 additions and 3 deletions.
106 changes: 103 additions & 3 deletions src/Init/Data/BitVec/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2611,7 +2611,7 @@ theorem getLsbD_rotateLeftAux_of_geq {x : BitVec w} {r : Nat} {i : Nat} (hi : i
apply getLsbD_ge
omega

/-- When `r < w`, we give a formula for `(x.rotateRight r).getLsbD i`. -/
/-- When `r < w`, we give a formula for `(x.rotateLeft r).getLsbD i`. -/
theorem getLsbD_rotateLeft_of_le {x : BitVec w} {r i : Nat} (hr: r < w) :
(x.rotateLeft r).getLsbD i =
cond (i < r)
Expand All @@ -2638,6 +2638,56 @@ theorem getElem_rotateLeft {x : BitVec w} {r i : Nat} (h : i < w) :
if h' : i < r % w then x[(w - (r % w) + i)] else x[i - (r % w)] := by
simp [← BitVec.getLsbD_eq_getElem, h]

/-- If `w ≤ x < 2 * w`, then `x % w = x - w` -/
theorem mod_eq_sub_of_le_of_lt {x w : Nat} (x_le : w ≤ x) (x_lt : x < 2 * w) :
x % w = x - w := by
rw [Nat.mod_eq_sub_mod, Nat.mod_eq_of_lt (by omega)]
omega

theorem getMsbD_rotateLeftAux_of_lt {x : BitVec w} {r : Nat} {i : Nat} (hi : i < w - r) :
(x.rotateLeftAux r).getMsbD i = x.getMsbD (r + i) := by
rw [rotateLeftAux, getMsbD_or]
simp [show i < w - r by omega, Nat.add_comm]

theorem getMsbD_rotateLeftAux_of_ge {x : BitVec w} {r : Nat} {i : Nat} (hi : i ≥ w - r) :
(x.rotateLeftAux r).getMsbD i = (decide (i < w) && x.getMsbD (i - (w - r))) := by
simp [rotateLeftAux, getMsbD_or, show i + r ≥ w by omega, show ¬i < w - r by omega]

/-- When `r < w`, we give a formula for `(x.rotateLeft r).getMsbD i`. -/
theorem getMsbD_rotateLeft_of_lt {n w : Nat} {x : BitVec w} (hi : r < w):
(x.rotateLeft r).getMsbD n = (decide (n < w) && x.getMsbD ((r + n) % w)) := by
rcases w with rfl | w
· simp
· rw [BitVec.rotateLeft_eq_rotateLeftAux_of_lt (by omega)]
by_cases h : n < (w + 1) - r
· simp [getMsbD_rotateLeftAux_of_lt h, Nat.mod_eq_of_lt, show r + n < (w + 1) by omega, show n < w + 1 by omega]
· simp [getMsbD_rotateLeftAux_of_ge <| Nat.ge_of_not_lt h]
by_cases h₁ : n < w + 1
· simp only [h₁, decide_true, Bool.true_and]
have h₂ : (r + n) < 2 * (w + 1) := by omega
rw [mod_eq_sub_of_le_of_lt (by omega) (by omega)]
congr 1
omega
· simp [h₁]

theorem getMsbD_rotateLeft {r n w : Nat} {x : BitVec w} :
(x.rotateLeft r).getMsbD n = (decide (n < w) && x.getMsbD ((r + n) % w)) := by
rcases w with rfl | w
· simp
· by_cases h : r < w
· rw [getMsbD_rotateLeft_of_lt (by omega)]
· rw [← rotateLeft_mod_eq_rotateLeft, getMsbD_rotateLeft_of_lt (by apply Nat.mod_lt; simp)]
simp

@[simp]
theorem msb_rotateLeft {m w : Nat} {x : BitVec w} :
(x.rotateLeft m).msb = x.getMsbD (m % w) := by
simp only [BitVec.msb, getMsbD_rotateLeft]
by_cases h : w = 0
· simp [h]
· simp
omega

/-! ## Rotate Right -/

/--
Expand Down Expand Up @@ -2699,7 +2749,7 @@ theorem rotateRight_mod_eq_rotateRight {x : BitVec w} {r : Nat} :
simp only [rotateRight, Nat.mod_mod]

/-- When `r < w`, we give a formula for `(x.rotateRight r).getLsb i`. -/
theorem getLsbD_rotateRight_of_le {x : BitVec w} {r i : Nat} (hr: r < w) :
theorem getLsbD_rotateRight_of_lt {x : BitVec w} {r i : Nat} (hr: r < w) :
(x.rotateRight r).getLsbD i =
cond (i < w - r)
(x.getLsbD (r + i))
Expand All @@ -2717,14 +2767,64 @@ theorem getLsbD_rotateRight {x : BitVec w} {r i : Nat} :
(decide (i < w) && x.getLsbD (i - (w - (r % w)))) := by
rcases w with ⟨rfl, w⟩
· simp
· rw [← rotateRight_mod_eq_rotateRight, getLsbD_rotateRight_of_le (Nat.mod_lt _ (by omega))]
· rw [← rotateRight_mod_eq_rotateRight, getLsbD_rotateRight_of_lt (Nat.mod_lt _ (by omega))]

@[simp]
theorem getElem_rotateRight {x : BitVec w} {r i : Nat} (h : i < w) :
(x.rotateRight r)[i] = if h' : i < w - (r % w) then x[(r % w) + i] else x[(i - (w - (r % w)))] := by
simp only [← BitVec.getLsbD_eq_getElem]
simp [getLsbD_rotateRight, h]

theorem getMsbD_rotateRightAux_of_lt {x : BitVec w} {r : Nat} {i : Nat} (hi : i < r) :
(x.rotateRightAux r).getMsbD i = x.getMsbD (i + (w - r)) := by
rw [rotateRightAux, getMsbD_or, getMsbD_ushiftRight]
simp [show i < r by omega]

theorem getMsbD_rotateRightAux_of_ge {x : BitVec w} {r : Nat} {i : Nat} (hi : i ≥ r) :
(x.rotateRightAux r).getMsbD i = (decide (i < w) && x.getMsbD (i - r)) := by
simp [rotateRightAux, show ¬ i < r by omega, show i + (w - r) ≥ w by omega]

/-- When `m < w`, we give a formula for `(x.rotateLeft m).getMsbD i`. -/
@[simp]
theorem getMsbD_rotateRight_of_lt {w n m : Nat} {x : BitVec w} (hr : m < w):
(x.rotateRight m).getMsbD n = (decide (n < w) && (if (n < m % w)
then x.getMsbD ((w + n - m % w) % w) else x.getMsbD (n - m % w))):= by
rcases w with rfl | w
· simp
· rw [rotateRight_eq_rotateRightAux_of_lt (by omega)]
by_cases h : n < m
· simp only [getMsbD_rotateRightAux_of_lt h, show n < w + 1 by omega, decide_true,
show m % (w + 1) = m by rw [Nat.mod_eq_of_lt hr], h, ↓reduceIte,
show (w + 1 + n - m) < (w + 1) by omega, Nat.mod_eq_of_lt, Bool.true_and]
congr 1
omega
· simp [h, getMsbD_rotateRightAux_of_ge <| Nat.ge_of_not_lt h]
by_cases h₁ : n < w + 1
· simp [h, h₁, decide_true, Bool.true_and, Nat.mod_eq_of_lt hr]
· simp [h₁]

@[simp]
theorem getMsbD_rotateRight {w n m : Nat} {x : BitVec w} :
(x.rotateRight m).getMsbD n = (decide (n < w) && (if (n < m % w)
then x.getMsbD ((w + n - m % w) % w) else x.getMsbD (n - m % w))):= by
rcases w with rfl | w
· simp
· by_cases h₀ : m < w
· rw [getMsbD_rotateRight_of_lt (by omega)]
· rw [← rotateRight_mod_eq_rotateRight, getMsbD_rotateRight_of_lt (by apply Nat.mod_lt; simp)]
simp

@[simp]
theorem msb_rotateRight {r w : Nat} {x : BitVec w} :
(x.rotateRight r).msb = x.getMsbD ((w - r % w) % w) := by
simp only [BitVec.msb, getMsbD_rotateRight]
by_cases h₀ : 0 < w
· simp only [h₀, decide_true, Nat.add_zero, Nat.zero_le, Nat.sub_eq_zero_of_le, Bool.true_and,
ite_eq_left_iff, Nat.not_lt, Nat.le_zero_eq]
intro h₁
simp [h₁]
· simp [show w = 0 by omega]

/- ## twoPow -/

theorem twoPow_eq (w : Nat) (i : Nat) : twoPow w i = 1#w <<< i := by
Expand Down

0 comments on commit 3c75551

Please sign in to comment.