diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 648e6421a38e..6d4a816bb8e7 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -2520,6 +2520,17 @@ theorem mul_eq_and {a b : BitVec 1} : a * b = a &&& b := by have hb : b = 0 ∨ b = 1 := eq_zero_or_eq_one _ rcases ha with h | h <;> (rcases hb with h' | h' <;> (simp [h, h'])) +@[simp] protected theorem neg_mul (x y : BitVec w) : -x * y = -(x * y) := by + apply eq_of_toInt_eq + simp [toInt_mul, toInt_neg, Int.bmod_neg_bmod] + +@[simp] protected theorem mul_neg (x y : BitVec w) : x * -y = -(x * y) := by + rw [BitVec.mul_comm, BitVec.neg_mul, BitVec.mul_comm] + +protected theorem neg_mul_neg (x y : BitVec w) : -x * -y = x * y := by simp + +protected theorem neg_mul_comm (x y : BitVec w) : -x * y = x * -y := by simp + /-! ### le and lt -/ @[bv_toNat] theorem le_def {x y : BitVec n} : diff --git a/src/Init/Data/Int/DivModLemmas.lean b/src/Init/Data/Int/DivModLemmas.lean index b8f121043a55..28bf0bde38ed 100644 --- a/src/Init/Data/Int/DivModLemmas.lean +++ b/src/Init/Data/Int/DivModLemmas.lean @@ -1194,6 +1194,16 @@ theorem bmod_sub_bmod_congr : Int.bmod (Int.bmod x n - y) n = Int.bmod (x - y) n rw [Int.sub_eq_add_neg, Int.sub_eq_add_neg, Int.add_right_comm, ←Int.sub_eq_add_neg, ← Int.sub_eq_add_neg] simp [emod_sub_bmod_congr] +theorem add_bmod_eq_add_bmod_right (i : Int) + (H : bmod x n = bmod y n) : bmod (x + i) n = bmod (y + i) n := by + rw [← bmod_add_bmod_congr, ← @bmod_add_bmod_congr y, H] + +theorem bmod_add_cancel_right (i : Int) : bmod (x + i) n = bmod (y + i) n ↔ bmod x n = bmod y n := + ⟨fun H => by + have := add_bmod_eq_add_bmod_right (-i) H + rwa [Int.add_neg_cancel_right, Int.add_neg_cancel_right] at this, + fun H => by rw [← bmod_add_bmod_congr, H, bmod_add_bmod_congr]⟩ + @[simp] theorem add_bmod_bmod : Int.bmod (x + Int.bmod y n) n = Int.bmod (x + y) n := by rw [Int.add_comm x, Int.bmod_add_bmod_congr, Int.add_comm y] @@ -1348,3 +1358,7 @@ theorem bmod_natAbs_plus_one (x : Int) (w : 1 < x.natAbs) : bmod x (x.natAbs + 1 all_goals decide · exact ofNat_nonneg x · exact succ_ofNat_pos (x + 1) + +theorem bmod_neg_bmod : bmod (-(bmod x n)) n = bmod (-x) n := by + apply (bmod_add_cancel_right x).mp + rw [Int.add_left_neg, ← add_bmod_bmod, Int.add_left_neg] diff --git a/src/Std/Tactic/BVDecide/Normalize/BitVec.lean b/src/Std/Tactic/BVDecide/Normalize/BitVec.lean index 92c8d5e37698..67765cdd0903 100644 --- a/src/Std/Tactic/BVDecide/Normalize/BitVec.lean +++ b/src/Std/Tactic/BVDecide/Normalize/BitVec.lean @@ -257,6 +257,14 @@ theorem BitVec.add_const_right' (a b c : BitVec w) : (a + b) + c = (b + c) + a : attribute [bv_normalize] BitVec.mul_zero attribute [bv_normalize] BitVec.zero_mul +@[bv_normalize] +theorem BitVec.neg_mul' (x y : BitVec w) : (~~~x + 1#w) * y = ~~~(x * y) + 1#w := by + rw [← BitVec.neg_eq_not_add, ← BitVec.neg_eq_not_add, BitVec.neg_mul] + +@[bv_normalize] +theorem BitVec.mul_neg' (x y : BitVec w) : x * (~~~y + 1#w) = ~~~(x * y) + 1#w := by + rw [← BitVec.neg_eq_not_add, ← BitVec.neg_eq_not_add, BitVec.mul_neg] + attribute [bv_normalize] BitVec.shiftLeft_zero attribute [bv_normalize] BitVec.zero_shiftLeft diff --git a/tests/lean/run/bv_decide_rewriter.lean b/tests/lean/run/bv_decide_rewriter.lean index 882dc72536a6..afd2df14a92a 100644 --- a/tests/lean/run/bv_decide_rewriter.lean +++ b/tests/lean/run/bv_decide_rewriter.lean @@ -88,6 +88,14 @@ example {x : BitVec 16} : x / (BitVec.ofNat 16 8) = x >>> 3 := by bv_normalize example {x y : Bool} (h1 : x && y) : x || y := by bv_normalize example (a b c: Bool) : (if a then b else c) = (if !a then c else b) := by bv_normalize +-- neg_mul' / mul_neg' +example (x y : BitVec 16) : (-x) * y = -(x * y) := by bv_normalize +example (x y : BitVec 16) : x * (-y) = -(x * y) := by bv_normalize +example (x y : BitVec 16) : -x * -y = x * y := by bv_normalize +example (x y : BitVec 16) : (~~~x + 1) * y = ~~~(x * y) + 1 := by bv_normalize +example (x y : BitVec 16) : x * (~~~y + 1) = ~~~(x * y) + 1 := by bv_normalize +example (x y : BitVec 16) : (~~~x + 1) * (~~~y + 1) = x * y := by bv_normalize + section example (x y : BitVec 256) : x * y = y * x := by