Skip to content

Commit

Permalink
fix planar error (#414)
Browse files Browse the repository at this point in the history
* fix planar error

* fix ambiguity

* try new

* fix
  • Loading branch information
prbzrg authored May 14, 2024
1 parent 355fa2c commit b45e114
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/layers/mul_layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@ function LuxCore.outputsize(m::MulLayer)
end

@inline function (m::MulLayer)(x::AbstractVecOrMat, ps::Any, st::NamedTuple)
return Lux.apply_activation(m.activation, Octavian.matmul(ps.weight, x)), st
m.activation.(Octavian.matmul(ps.weight, x)), st
end
17 changes: 8 additions & 9 deletions src/layers/planar_layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,34 +71,33 @@ function LuxCore.outputsize(m::PlanarLayer)
end

@inline function (m::PlanarLayer{true})(z::AbstractVector, ps::Any, st::NamedTuple)
ps.u * Lux.apply_bias_activation(m.activation, LinearAlgebra.dot(ps.w, z), only(ps.b)),
st
ps.u * m.activation.(LinearAlgebra.dot(ps.w, z) + only(ps.b)), st
end

@inline function (m::PlanarLayer{true})(z::AbstractMatrix, ps::Any, st::NamedTuple)
ps.u * Lux.apply_bias_activation(m.activation, transpose(ps.w) * z, only(ps.b)), st
ps.u * m.activation.(muladd(transpose(ps.w), z, only(ps.b))), st
end

@inline function (m::PlanarLayer{false})(z::AbstractVector, ps::Any, st::NamedTuple)
ps.u * Lux.apply_activation(m.activation, LinearAlgebra.dot(ps.w, z)), st
ps.u * m.activation.(LinearAlgebra.dot(ps.w, z)), st
end

@inline function (m::PlanarLayer{false})(z::AbstractMatrix, ps::Any, st::NamedTuple)
ps.u * Lux.apply_activation(m.activation, transpose(ps.w) * z), st
ps.u * m.activation.(transpose(ps.w) * z), st
end

@inline function pl_h(m::PlanarLayer{true}, z::AbstractVector, ps::Any, st::NamedTuple)
Lux.apply_bias_activation(m.activation, LinearAlgebra.dot(ps.w, z), only(ps.b)), st
m.activation.(LinearAlgebra.dot(ps.w, z) + only(ps.b)), st
end

@inline function pl_h(m::PlanarLayer{true}, z::AbstractMatrix, ps::Any, st::NamedTuple)
Lux.apply_bias_activation(m.activation, transpose(ps.w) * z, only(ps.b)), st
m.activation.(muladd(transpose(ps.w), z, only(ps.b))), st
end

@inline function pl_h(m::PlanarLayer{false}, z::AbstractVector, ps::Any, st::NamedTuple)
Lux.apply_activation(m.activation, LinearAlgebra.dot(ps.w, z)), st
m.activation.(LinearAlgebra.dot(ps.w, z)), st
end

@inline function pl_h(m::PlanarLayer{false}, z::AbstractMatrix, ps::Any, st::NamedTuple)
Lux.apply_activation(m.activation, transpose(ps.w) * z), st
m.activation.(transpose(ps.w) * z), st
end

0 comments on commit b45e114

Please sign in to comment.