From fb7b74794d0da7a7745b960d14cbc012c1141db2 Mon Sep 17 00:00:00 2001
From: Hossein Pourbozorg
Date: Tue, 14 May 2024 22:31:02 +0330
Subject: [PATCH] fix planar error

---
 src/ContinuousNormalizingFlows.jl |  3 ++-
 src/layers/base_layer.jl          | 11 +++++++++++
 src/layers/mul_layer.jl           |  2 +-
 src/layers/planar_layer.jl        | 17 ++++++++---------
 4 files changed, 22 insertions(+), 11 deletions(-)
 create mode 100644 src/layers/base_layer.jl

diff --git a/src/ContinuousNormalizingFlows.jl b/src/ContinuousNormalizingFlows.jl
index 55b54d46..91cc2305 100644
--- a/src/ContinuousNormalizingFlows.jl
+++ b/src/ContinuousNormalizingFlows.jl
@@ -58,9 +58,10 @@ export construct,
     PlanarLayer,
     MulLayer
 
+include(joinpath("layers", "base_layer.jl"))
 include(joinpath("layers", "cond_layer.jl"))
-include(joinpath("layers", "planar_layer.jl"))
 include(joinpath("layers", "mul_layer.jl"))
+include(joinpath("layers", "planar_layer.jl"))
 
 include("types.jl")
 
diff --git a/src/layers/base_layer.jl b/src/layers/base_layer.jl
new file mode 100644
index 00000000..490710bd
--- /dev/null
+++ b/src/layers/base_layer.jl
@@ -0,0 +1,11 @@
+@inline function apply_act(::typeof(identity), x::Any)
+    x
+end
+
+@inline function apply_act(activation::Any, x::Number)
+    activation(x)
+end
+
+@inline function apply_act(activation::Any, x::AbstractArray)
+    activation.(x)
+end
diff --git a/src/layers/mul_layer.jl b/src/layers/mul_layer.jl
index 983505cc..ea9da6de 100644
--- a/src/layers/mul_layer.jl
+++ b/src/layers/mul_layer.jl
@@ -31,5 +31,5 @@ function LuxCore.outputsize(m::MulLayer)
 end
 
 @inline function (m::MulLayer)(x::AbstractVecOrMat, ps::Any, st::NamedTuple)
-    return Lux.apply_activation(m.activation, Octavian.matmul(ps.weight, x)), st
+    apply_act(m.activation, Octavian.matmul(ps.weight, x)), st
 end
diff --git a/src/layers/planar_layer.jl b/src/layers/planar_layer.jl
index 3f43412d..f695e3c3 100644
--- a/src/layers/planar_layer.jl
+++ b/src/layers/planar_layer.jl
@@ -71,34 +71,33 @@ function LuxCore.outputsize(m::PlanarLayer)
 end
 
 @inline function (m::PlanarLayer{true})(z::AbstractVector, ps::Any, st::NamedTuple)
-    ps.u * Lux.apply_bias_activation(m.activation, LinearAlgebra.dot(ps.w, z), only(ps.b)),
-    st
+    ps.u * apply_act(m.activation, LinearAlgebra.dot(ps.w, z) + only(ps.b)), st
 end
 
 @inline function (m::PlanarLayer{true})(z::AbstractMatrix, ps::Any, st::NamedTuple)
-    ps.u * Lux.apply_bias_activation(m.activation, transpose(ps.w) * z, only(ps.b)), st
+    ps.u * apply_act(m.activation, muladd(transpose(ps.w), z, only(ps.b))), st
 end
 
 @inline function (m::PlanarLayer{false})(z::AbstractVector, ps::Any, st::NamedTuple)
-    ps.u * Lux.apply_activation(m.activation, LinearAlgebra.dot(ps.w, z)), st
+    ps.u * apply_act(m.activation, LinearAlgebra.dot(ps.w, z)), st
 end
 
 @inline function (m::PlanarLayer{false})(z::AbstractMatrix, ps::Any, st::NamedTuple)
-    ps.u * Lux.apply_activation(m.activation, transpose(ps.w) * z), st
+    ps.u * apply_act(m.activation, transpose(ps.w) * z), st
 end
 
 @inline function pl_h(m::PlanarLayer{true}, z::AbstractVector, ps::Any, st::NamedTuple)
-    Lux.apply_bias_activation(m.activation, LinearAlgebra.dot(ps.w, z), only(ps.b)), st
+    apply_act(m.activation, LinearAlgebra.dot(ps.w, z) + only(ps.b)), st
 end
 
 @inline function pl_h(m::PlanarLayer{true}, z::AbstractMatrix, ps::Any, st::NamedTuple)
-    Lux.apply_bias_activation(m.activation, transpose(ps.w) * z, only(ps.b)), st
+    apply_act(m.activation, muladd(transpose(ps.w), z, only(ps.b))), st
 end
 
 @inline function pl_h(m::PlanarLayer{false}, z::AbstractVector, ps::Any, st::NamedTuple)
-    Lux.apply_activation(m.activation, LinearAlgebra.dot(ps.w, z)), st
+    apply_act(m.activation, LinearAlgebra.dot(ps.w, z)), st
 end
 
 @inline function pl_h(m::PlanarLayer{false}, z::AbstractMatrix, ps::Any, st::NamedTuple)
-    Lux.apply_activation(m.activation, transpose(ps.w) * z), st
+    apply_act(m.activation, transpose(ps.w) * z), st
 end