From ad39e8f34d8cfdc934b38032b33a9ad868330bee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20G=C3=B6ttgens?= Date: Wed, 13 Nov 2024 12:34:04 +0100 Subject: [PATCH] Add mutating arithmetic for `SRow`s (#1659) * Make `scale_row!` follow its docstring i.e. don't throw on zero scalars, and coerce scalars if needed * Fix a docstring typo * Let `add_scaled_row` coerce the scalar * Add mutating arithmetics for SRow * Skip deepcopy in addmul! in case of aliasing * Add `submul!` * Add tests * Some fixes (tests run now) * Comment out tests * Adapt to AA changes * Bump AbstractAlgebra compat --- Project.toml | 2 +- src/Sparse/Row.jl | 184 +++++++++++++++++++++++++++++++++++++++----- src/Sparse/ZZRow.jl | 4 +- test/Sparse/Row.jl | 36 +++++++++ 4 files changed, 205 insertions(+), 21 deletions(-) diff --git a/Project.toml b/Project.toml index c15a66b789..326f904601 100644 --- a/Project.toml +++ b/Project.toml @@ -28,7 +28,7 @@ GAPExt = "GAP" PolymakeExt = "Polymake" [compat] -AbstractAlgebra = "^0.43.1" +AbstractAlgebra = "^0.43.10" Dates = "1.6" Distributed = "1.6" GAP = "0.9.6, 0.10, 0.11, 0.12" diff --git a/src/Sparse/Row.jl b/src/Sparse/Row.jl index 96c4eb82d0..653ba87d2c 100644 --- a/src/Sparse/Row.jl +++ b/src/Sparse/Row.jl @@ -105,6 +105,14 @@ function Base.empty!(A::SRow) return A end +function Base.empty(A::SRow) + return sparse_row(base_ring(A)) +end + +function zero(A::SRow) + return empty(A) +end + function swap!(A::SRow, B::SRow) A.pos, B.pos = B.pos, A.pos A.values, B.values = B.values, A.values @@ -447,6 +455,7 @@ end # Inplace scaling # ################################################################################ + @doc raw""" scale_row!(a::SRow, b::NCRingElem) -> SRow @@ -454,8 +463,9 @@ Returns the (left) product of $b \times a$ and reassigns the value of $a$ to thi For rows, the standard multiplication is from the left. """ function scale_row!(a::SRow{T}, b::T) where T - @assert !iszero(b) - if isone(b) + if iszero(b) + return empty!(a) + elseif isone(b) return a end i = 1 @@ -465,20 +475,23 @@ function scale_row!(a::SRow{T}, b::T) where T deleteat!(a.values, i) deleteat!(a.pos, i) else - i += 1 + i += 1 end end return a end +scale_row!(a::SRow, b) = scale_row!(a, base_ring(a)(b)) + @doc raw""" scale_row_right!(a::SRow, b::NCRingElem) -> SRow Returns the (right) product of $a \times b$ and modifies $a$ to this product. """ function scale_row_right!(a::SRow{T}, b::T) where T - @assert !iszero(b) - if isone(b) + if iszero(b) + return empty!(a) + elseif isone(b) return a end i = 1 @@ -488,16 +501,20 @@ function scale_row_right!(a::SRow{T}, b::T) where T deleteat!(a.values, i) deleteat!(a.pos, i) else - i += 1 + i += 1 end end return a end +scale_row_right!(a::SRow, b) = scale_row_right!(a, base_ring(a)(b)) + function scale_row_left!(a::SRow{T}, b::T) where T return scale_row!(a,b) end +scale_row_left!(a::SRow, b) = scale_row_left!(a, base_ring(a)(b)) + ################################################################################ # # Addition @@ -506,9 +523,9 @@ end function +(A::SRow{T}, B::SRow{T}) where T if length(A.values) == 0 - return B + return deepcopy(B) elseif length(B.values) == 0 - return A + return deepcopy(A) end return add_scaled_row(A, B, one(base_ring(A))) end @@ -516,12 +533,12 @@ end function -(A::SRow{T}, B::SRow{T}) where T if length(A) == 0 if length(B) == 0 - return A + return deepcopy(A) else - return add_scaled_row(B, A, base_ring(B)(-1)) + return add_scaled_row(B, A, -1) end end - return add_scaled_row(B, A, base_ring(A)(-1)) + return add_scaled_row(B, A, -1) end function -(A::SRow{T}) where {T} @@ -683,10 +700,10 @@ end Returns the row $c A + B$. """ -add_scaled_row(a::SRow{T}, b::SRow{T}, c::T) where {T} = add_scaled_row!(a, deepcopy(b), c) +add_scaled_row(a::SRow{T}, b::SRow{T}, c) where {T} = add_scaled_row!(a, deepcopy(b), c) -add_left_scaled_row(a::SRow{T}, b::SRow{T}, c::T) where {T} = add_left_scaled_row!(a, deepcopy(b), c) -add_right_scaled_row(a::SRow{T}, b::SRow{T}, c::T) where {T} = add_right_scaled_row!(a, deepcopy(b), c) +add_left_scaled_row(a::SRow{T}, b::SRow{T}, c) where {T} = add_left_scaled_row!(a, deepcopy(b), c) +add_right_scaled_row(a::SRow{T}, b::SRow{T}, c) where {T} = add_right_scaled_row!(a, deepcopy(b), c) @@ -696,7 +713,9 @@ add_right_scaled_row(a::SRow{T}, b::SRow{T}, c::T) where {T} = add_right_scaled_ Adds the left scaled row $c A$ to $B$. """ function add_scaled_row!(a::SRow{T}, b::SRow{T}, c::T, ::Val{left_side} = Val(true)) where {T, left_side} - @assert a !== b + if a === b + a = deepcopy(a) + end i = 1 j = 1 t = base_ring(a)() @@ -735,17 +754,144 @@ function add_scaled_row!(a::SRow{T}, b::SRow{T}, c::T, ::Val{left_side} = Val(tr return b end +add_scaled_row!(a::SRow{T}, b::SRow{T}, c) where {T} = add_scaled_row!(a, b, base_ring(a)(c)) + +add_scaled_row!(a::SRow{T}, b::SRow{T}, c, side::Val) where {T} = add_scaled_row!(a, b, base_ring(a)(c), side) + # ignore tmp argument -add_scaled_row!(a::SRow{T}, b::SRow{T}, c::T, tmp::SRow{T}) where T = add_scaled_row!(a, b, c) +add_scaled_row!(a::SRow{T}, b::SRow{T}, c, tmp::SRow{T}) where T = add_scaled_row!(a, b, c) -add_left_scaled_row!(a::SRow{T}, b::SRow{T}, c::T) where T = add_scaled_row!(a, b, c) +add_left_scaled_row!(a::SRow{T}, b::SRow{T}, c) where T = add_scaled_row!(a, b, c) @doc raw""" add_right_scaled_row!(A::SRow{T}, B::SRow{T}, c::T) -> SRow{T} -Return the right scaled row $c A$ to $B$ by changing $B$ in place. +Return the right scaled row $A c$ to $B$ by changing $B$ in place. """ -add_right_scaled_row!(a::SRow{T}, b::SRow{T}, c::T) where T = add_scaled_row!(a, b, c, Val(false)) +add_right_scaled_row!(a::SRow{T}, b::SRow{T}, c) where T = add_scaled_row!(a, b, c, Val(false)) + + +################################################################################ +# +# Mutating arithmetics +# +################################################################################ + +function zero!(z::SRow) + return empty!(z) +end + +function neg!(z::SRow{T}, x::SRow{T}) where T + if z === x + return neg!(x) + end + swap!(z, -x) + return z +end + +function neg!(z::SRow) + for i in 1:length(z) + z.values[i] = neg!(z.values[i]) + end + return z +end + +function add!(z::SRow{T}, x::SRow{T}, y::SRow{T}) where T + if z === x + return add!(x, y) + elseif z === y + return add!(y, x) + end + swap!(z, x + y) + return z +end + +function add!(z::SRow{T}, x::SRow{T}) where T + if z === x + return scale_row!(z, 2) + end + return add_scaled_row!(x, z, one(base_ring(x))) +end + +function sub!(z::SRow{T}, x::SRow{T}, y::SRow{T}) where T + if z === x + return sub!(x, y) + elseif z === y + return neg!(sub!(y, x)) + end + swap!(z, x - y) + return z +end + +function sub!(z::SRow{T}, x::SRow{T}) where T + if z === x + return empty!(z) + end + return add_scaled_row!(x, z, -1) +end + +function mul!(z::SRow{T}, x::SRow{T}, y::SRow{T}) where T + error("Not implemented") +end + +function mul!(z::SRow{T}, x::SRow{T}, c) where T + if z === x + return scale_row_right!(x, c) + end + swap!(z, x * c) + return z +end + +function mul!(z::SRow{T}, c, y::SRow{T}) where T + if z === y + return scale_row_left!(y, c) + end + swap!(z, c * y) + return z +end + +function addmul!(z::SRow{T}, x::SRow{T}, y::SRow{T}) where T + error("Not implemented") +end + +function addmul!(z::SRow{T}, x::SRow{T}, y) where T + if z === x + return scale_row_right!(x, y+1) + end + return add_right_scaled_row!(x, z, y) +end + +function addmul!(z::SRow{T}, x, y::SRow{T}) where T + if z === x + return scale_row_left!(y, x+1) + end + return add_left_scaled_row!(y, z, x) +end + +function submul!(z::SRow{T}, x::SRow{T}, y::SRow{T}) where T + error("Not implemented") +end + +function submul!(z::SRow{T}, x::SRow{T}, y) where T + if z === x + return scale_row_right!(x, -y+1) + end + return add_right_scaled_row!(x, z, -y) +end + +function submul!(z::SRow{T}, x, y::SRow{T}) where T + if z === x + return scale_row_left!(y, -x+1) + end + return add_left_scaled_row!(y, z, -x) +end + + +# ignore temp variable +addmul!(z::SRow{T}, x::SRow{T}, y, t) where T = addmul!(z, x, y) +addmul!(z::SRow{T}, x, y::SRow{T}, t) where T = addmul!(z, x, y) +submul!(z::SRow{T}, x::SRow{T}, y, t) where T = submul!(z, x, y) +submul!(z::SRow{T}, x, y::SRow{T}, t) where T = submul!(z, x, y) ################################################################################ diff --git a/src/Sparse/ZZRow.jl b/src/Sparse/ZZRow.jl index d4cfc30f04..a9073f5a23 100644 --- a/src/Sparse/ZZRow.jl +++ b/src/Sparse/ZZRow.jl @@ -276,7 +276,6 @@ end function add_scaled_row(Ai::SRow{ZZRingElem}, Aj::SRow{ZZRingElem}, c::ZZRingElem, sr::SRow{ZZRingElem} = sparse_row(ZZ)) empty!(sr) - @assert c != 0 n = ZZRingElem() pi = 1 pj = 1 @@ -323,6 +322,9 @@ function add_scaled_row(Ai::SRow{ZZRingElem}, Aj::SRow{ZZRingElem}, c::ZZRingEle end function add_scaled_row!(Ai::SRow{ZZRingElem}, Aj::SRow{ZZRingElem}, c::ZZRingElem, sr::SRow{ZZRingElem} = sparse_row(ZZ)) + if iszero(c) + return Aj + end _t = sr sr = add_scaled_row(Ai, Aj, c, sr) @assert _t === sr diff --git a/test/Sparse/Row.jl b/test/Sparse/Row.jl index d104b80cca..14c3fb12bf 100644 --- a/test/Sparse/Row.jl +++ b/test/Sparse/Row.jl @@ -204,4 +204,40 @@ B = sparse_row(F,[1],[y]) C = add_scaled_row(A,B,F(1)) @test C == A+B + + # mutating arithmetic + randcoeff() = begin + n = rand((1,1,1,2,5,7,15)) + return rand(-2^n:2^n) + end + Main.equality(A::SRow, B::SRow) = A == B + @testset "mutating arithmetic; R = $R" for R in (ZZ, QQ) + for _ in 1:10 + maxind_A = rand(0:10) + inds_A = Hecke.Random.randsubseq(1:maxind_A, rand()) + vals_A = elem_type(R)[R(rand((-1, 1)) * rand(1:10)) for _ in 1:length(inds_A)] + A = sparse_row(R, inds_A, vals_A) + + maxind_B = rand(0:10) + inds_B = Hecke.Random.randsubseq(1:maxind_B, rand()) + vals_B = elem_type(R)[R(rand((-1, 1)) * rand(1:10)) for _ in 1:length(inds_B)] + B = sparse_row(R, inds_B, vals_B) + + test_mutating_op_like_zero(zero, zero!, A) + + test_mutating_op_like_neg(-, neg!, A) + + test_mutating_op_like_add(+, add!, A, B) + test_mutating_op_like_add(-, sub!, A, B) + test_mutating_op_like_add(*, mul!, A, randcoeff(), SRow) + test_mutating_op_like_add(*, mul!, randcoeff(), A, SRow) + test_mutating_op_like_add(*, mul!, A, ZZ(randcoeff()), SRow) + test_mutating_op_like_add(*, mul!, ZZ(randcoeff()), A, SRow) + + test_mutating_op_like_addmul((a, b, c) -> a + b*c, addmul!, A, B, randcoeff(), SRow) + test_mutating_op_like_addmul((a, b, c) -> a + b*c, addmul!, A, randcoeff(), B, SRow) + test_mutating_op_like_addmul((a, b, c) -> a - b*c, submul!, A, B, randcoeff(), SRow) + test_mutating_op_like_addmul((a, b, c) -> a - b*c, submul!, A, randcoeff(), B, SRow) + end + end end