From 2e6b59f8d86837fc424611b0ed3a1e48f3d2ee34 Mon Sep 17 00:00:00 2001
From: Danny Willems <be.danny.willems@gmail.com>
Date: Mon, 26 Aug 2024 17:21:47 -0700
Subject: [PATCH] MVPoly: implement Neg

---
 mvpoly/src/prime.rs   | 27 ++++++++++++++++++++++++++-
 mvpoly/tests/prime.rs | 35 +++++++++++++++++++++++++++++++++++
 2 files changed, 61 insertions(+), 1 deletion(-)

diff --git a/mvpoly/src/prime.rs b/mvpoly/src/prime.rs
index cc6d7a043c..41c42feb4b 100644
--- a/mvpoly/src/prime.rs
+++ b/mvpoly/src/prime.rs
@@ -144,7 +144,7 @@
 use std::{
     collections::HashMap,
     fmt::{Debug, Formatter, Result},
-    ops::{Add, Mul, Sub},
+    ops::{Add, Mul, Neg, Sub},
 };
 
 use ark_ff::{One, PrimeField, Zero};
@@ -395,6 +395,31 @@ impl<F: PrimeField, const N: usize, const D: usize> Sub<&Dense<F, N, D>> for &De
     }
 }
 
+// Negation
+impl<F: PrimeField, const N: usize, const D: usize> Neg for Dense<F, N, D> {
+    type Output = Self;
+
+    fn neg(self) -> Self::Output {
+        let mut result = Dense::new();
+        for i in 0..self.coeff.len() {
+            result.coeff[i] = -self.coeff[i];
+        }
+        result
+    }
+}
+
+impl<F: PrimeField, const N: usize, const D: usize> Neg for &Dense<F, N, D> {
+    type Output = Dense<F, N, D>;
+
+    fn neg(self) -> Self::Output {
+        let mut result = Dense::new();
+        for i in 0..self.coeff.len() {
+            result.coeff[i] = -self.coeff[i];
+        }
+        result
+    }
+}
+
 // Multiplication
 impl<F: PrimeField, const N: usize, const D: usize> Mul<Dense<F, N, D>> for Dense<F, N, D> {
     type Output = Self;
diff --git a/mvpoly/tests/prime.rs b/mvpoly/tests/prime.rs
index 476f08b507..86b88559bb 100644
--- a/mvpoly/tests/prime.rs
+++ b/mvpoly/tests/prime.rs
@@ -202,3 +202,38 @@ fn test_sub_zero() {
     let p2 = p1.clone() - zero.clone();
     assert_eq!(p1.clone(), p2);
 }
+
+#[test]
+fn test_neg() {
+    let mut rng = o1_utils::tests::make_test_rng(None);
+    let p1 = Dense::<Fp, 3, 4>::random(&mut rng);
+    let p2 = -p1.clone();
+
+    // Test that p1 + (-p1) = 0
+    let sum = p1.clone() + p2.clone();
+    assert_eq!(sum, Dense::<Fp, 3, 4>::zero());
+
+    // Test that -(-p1) = p1
+    let p3 = -p2;
+    assert_eq!(p1, p3);
+
+    // Test negation of zero
+    let zero = Dense::<Fp, 3, 4>::zero();
+    let neg_zero = -zero.clone();
+    assert_eq!(zero, neg_zero);
+}
+
+#[test]
+fn test_neg_ref() {
+    let mut rng = o1_utils::tests::make_test_rng(None);
+    let p1 = Dense::<Fp, 3, 4>::random(&mut rng);
+    let p2 = -&p1;
+
+    // Test that p1 + (-&p1) = 0
+    let sum = p1.clone() + p2.clone();
+    assert_eq!(sum, Dense::<Fp, 3, 4>::zero());
+
+    // Test that -(-&p1) = p1
+    let p3 = -&p2;
+    assert_eq!(p1, p3);
+}