Skip to content

Commit

Permalink
fix: update +(Ptr, StaticInt) to __add
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 4, 2024
1 parent 9f94ffc commit 4277b26
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 22 deletions.
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ jobs:
os:
- ubuntu-latest
version:
- '1.6'
- '1'
steps:
- uses: actions/checkout@c85c95e3d7251135ab7dc9ce3241c5835cc595a9
Expand Down
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ ManualMemory = "0.1.8"
Polyester = "0.4, 0.5, 0.6, 0.7"
SIMDTypes = "0.1"
SLEEFPirates = "0.6"
Static = "0.8.4, 1"
Static = "1"
StaticArrayInterface = "1"
StaticArrays = "1"
StrideArraysCore = "0.4.7, 0.5"
StrideArraysCore = "0.5"
UnPack = "1"
VectorizationBase = "0.21.40"
VectorizedRNG = "0.2.13"
julia = "1.6"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Expand Down
4 changes: 2 additions & 2 deletions src/chain_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,9 @@ function valgrad_noloss(
@inbounds @simd ivdep for i in eachindex(parg)
parg2[i] = parg[i]
end
pm += aoff
pm = __add(pm, aoff)
g = PtrArray(Ptr{T}(pm), (glen,))
pm += goff
pm = __add(pm, goff)
l, pbl =
chain_valgrad_pullback!(pointer(g), parg2, layers, pointer(params), pm)
end
Expand Down
2 changes: 1 addition & 1 deletion src/dropout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ function valgrad_layer!(
VectorizedRNG.storestate!(rng, state)
end # GC preserve

pg, x, p, align(pu + ((static(7) + N) >>> static(3)))
pg, x, p, align(__add(pu, ((static(7) + N) >>> static(3))))
end

function pullback_arg!(
Expand Down
4 changes: 2 additions & 2 deletions src/simple_chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -820,7 +820,7 @@ function valgrad_core(
) where {T}
@unpack layers = c
g = PtrArray(Ptr{T}(pu), (glen,))
l = unsafe_valgrad!(c, pu + align(glen * static_sizeof(T)), g, params, arg)
l = unsafe_valgrad!(c, __add(pu, align(glen * static_sizeof(T))), g, params, arg)
Base.FastMath.add_fast(
l,
apply_penalty!(g, getpenalty(c), params, static_size(arg))
Expand All @@ -838,7 +838,7 @@ function valgrad_core_sarray(
l = Base.FastMath.add_fast(
unsafe_valgrad!(
c,
pu + align(static(L) * static_sizeof(T)),
__add(pu, align(static(L) * static_sizeof(T))),
g,
params,
arg
Expand Down
4 changes: 4 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -336,3 +336,7 @@ function _add_memory(t::Tuple, p)
(A, B...)
end
_add_memory(::Nothing, p) = nothing

__add(x, y) = x + y
__add(x::Ptr, ::StaticInt{N}) where {N} = x + N
__add(::StaticInt{N}, y::Ptr) where {N} = y + N
16 changes: 3 additions & 13 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
using SimpleChains
using Test, Aqua, ForwardDiff, Zygote, ChainRules, Random
@static if VERSION >= v"1.9"
using JET: @test_opt
else
macro test_opt(ex)
nothing
end
end
using JET: @test_opt

countallocations!(g, sc, x, p) = @allocated valgrad!(g, sc, x, p)
dual(x::T) where {T} = ForwardDiff.Dual(x, 4randn(T), 4randn(T), 4randn(T))
Expand Down Expand Up @@ -84,12 +78,8 @@ InteractiveUtils.versioninfo(; verbose = true)
SquaredLoss"""

@test sprint((io, t) -> show(io, t), sc) == print_str0
if VERSION >= v"1.6"
@test sprint((io, t) -> show(io, t), scflp) == print_str1
else
# typename doesn't work on 1.5
@test_broken sprint((io, t) -> show(io, t), scflp) == print_str1
end
@test sprint((io, t) -> show(io, t), scflp) == print_str1

p = SimpleChains.init_params(scflp, T; rng = Random.default_rng())
g = similar(p)
let sc = SimpleChains.remove_loss(sc)
Expand Down

0 comments on commit 4277b26

Please sign in to comment.