From fb7b74794d0da7a7745b960d14cbc012c1141db2 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Tue, 14 May 2024 22:31:02 +0330 Subject: [PATCH 1/4] fix planar error --- src/ContinuousNormalizingFlows.jl | 3 ++- src/layers/base_layer.jl | 11 +++++++++++ src/layers/mul_layer.jl | 2 +- src/layers/planar_layer.jl | 17 ++++++++--------- 4 files changed, 22 insertions(+), 11 deletions(-) create mode 100644 src/layers/base_layer.jl diff --git a/src/ContinuousNormalizingFlows.jl b/src/ContinuousNormalizingFlows.jl index 55b54d46..91cc2305 100644 --- a/src/ContinuousNormalizingFlows.jl +++ b/src/ContinuousNormalizingFlows.jl @@ -58,9 +58,10 @@ export construct, PlanarLayer, MulLayer +include(joinpath("layers", "base_layer.jl")) include(joinpath("layers", "cond_layer.jl")) -include(joinpath("layers", "planar_layer.jl")) include(joinpath("layers", "mul_layer.jl")) +include(joinpath("layers", "planar_layer.jl")) include("types.jl") diff --git a/src/layers/base_layer.jl b/src/layers/base_layer.jl new file mode 100644 index 00000000..490710bd --- /dev/null +++ b/src/layers/base_layer.jl @@ -0,0 +1,11 @@ +@inline function apply_act(::typeof(identity), x::Any) + x +end + +@inline function apply_act(activation::Any, x::Number) + activation(x) +end + +@inline function apply_act(activation::Any, x::AbstractArray) + activation.(x) +end diff --git a/src/layers/mul_layer.jl b/src/layers/mul_layer.jl index 983505cc..ea9da6de 100644 --- a/src/layers/mul_layer.jl +++ b/src/layers/mul_layer.jl @@ -31,5 +31,5 @@ function LuxCore.outputsize(m::MulLayer) end @inline function (m::MulLayer)(x::AbstractVecOrMat, ps::Any, st::NamedTuple) - return Lux.apply_activation(m.activation, Octavian.matmul(ps.weight, x)), st + apply_act(m.activation, Octavian.matmul(ps.weight, x)), st end diff --git a/src/layers/planar_layer.jl b/src/layers/planar_layer.jl index 3f43412d..f695e3c3 100644 --- a/src/layers/planar_layer.jl +++ b/src/layers/planar_layer.jl @@ -71,34 +71,33 @@ function LuxCore.outputsize(m::PlanarLayer) end @inline function (m::PlanarLayer{true})(z::AbstractVector, ps::Any, st::NamedTuple) - ps.u * Lux.apply_bias_activation(m.activation, LinearAlgebra.dot(ps.w, z), only(ps.b)), - st + ps.u * apply_act(m.activation, LinearAlgebra.dot(ps.w, z) + only(ps.b)), st end @inline function (m::PlanarLayer{true})(z::AbstractMatrix, ps::Any, st::NamedTuple) - ps.u * Lux.apply_bias_activation(m.activation, transpose(ps.w) * z, only(ps.b)), st + ps.u * apply_act(m.activation, muladd(transpose(ps.w), z, only(ps.b))), st end @inline function (m::PlanarLayer{false})(z::AbstractVector, ps::Any, st::NamedTuple) - ps.u * Lux.apply_activation(m.activation, LinearAlgebra.dot(ps.w, z)), st + ps.u * apply_act(m.activation, LinearAlgebra.dot(ps.w, z)), st end @inline function (m::PlanarLayer{false})(z::AbstractMatrix, ps::Any, st::NamedTuple) - ps.u * Lux.apply_activation(m.activation, transpose(ps.w) * z), st + ps.u * apply_act(m.activation, transpose(ps.w) * z), st end @inline function pl_h(m::PlanarLayer{true}, z::AbstractVector, ps::Any, st::NamedTuple) - Lux.apply_bias_activation(m.activation, LinearAlgebra.dot(ps.w, z), only(ps.b)), st + apply_act(m.activation, LinearAlgebra.dot(ps.w, z) + only(ps.b)), st end @inline function pl_h(m::PlanarLayer{true}, z::AbstractMatrix, ps::Any, st::NamedTuple) - Lux.apply_bias_activation(m.activation, transpose(ps.w) * z, only(ps.b)), st + apply_act(m.activation, muladd(transpose(ps.w), z, only(ps.b))), st end @inline function pl_h(m::PlanarLayer{false}, z::AbstractVector, ps::Any, st::NamedTuple) - Lux.apply_activation(m.activation, LinearAlgebra.dot(ps.w, z)), st + apply_act(m.activation, LinearAlgebra.dot(ps.w, z)), st end @inline function pl_h(m::PlanarLayer{false}, z::AbstractMatrix, ps::Any, st::NamedTuple) - Lux.apply_activation(m.activation, transpose(ps.w) * z), st + apply_act(m.activation, transpose(ps.w) * z), st end From 2fbd26ec872e510dd71529ac03d3a93045898a33 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Tue, 14 May 2024 23:28:03 +0330 Subject: [PATCH 2/4] fix ambiguity --- src/layers/base_layer.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/layers/base_layer.jl b/src/layers/base_layer.jl index 490710bd..3586ea2a 100644 --- a/src/layers/base_layer.jl +++ b/src/layers/base_layer.jl @@ -1,11 +1,11 @@ -@inline function apply_act(::typeof(identity), x::Any) +@inline function apply_act(::typeof(identity), x::Union{Number, AbstractArray}) x end -@inline function apply_act(activation::Any, x::Number) +@inline function apply_act(activation::Function, x::Number) activation(x) end -@inline function apply_act(activation::Any, x::AbstractArray) +@inline function apply_act(activation::Function, x::AbstractArray) activation.(x) end From cfb6c13da0645271543ef8741153110ed5d1d73f Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Wed, 15 May 2024 00:32:21 +0330 Subject: [PATCH 3/4] try new --- src/ContinuousNormalizingFlows.jl | 2 +- src/layers/base_layer.jl | 11 ----------- src/layers/mul_layer.jl | 2 +- src/layers/planar_layer.jl | 16 ++++++++-------- 4 files changed, 10 insertions(+), 21 deletions(-) delete mode 100644 src/layers/base_layer.jl diff --git a/src/ContinuousNormalizingFlows.jl b/src/ContinuousNormalizingFlows.jl index 91cc2305..c18c77fd 100644 --- a/src/ContinuousNormalizingFlows.jl +++ b/src/ContinuousNormalizingFlows.jl @@ -60,8 +60,8 @@ export construct, include(joinpath("layers", "base_layer.jl")) include(joinpath("layers", "cond_layer.jl")) -include(joinpath("layers", "mul_layer.jl")) include(joinpath("layers", "planar_layer.jl")) +include(joinpath("layers", "mul_layer.jl")) include("types.jl") diff --git a/src/layers/base_layer.jl b/src/layers/base_layer.jl deleted file mode 100644 index 3586ea2a..00000000 --- a/src/layers/base_layer.jl +++ /dev/null @@ -1,11 +0,0 @@ -@inline function apply_act(::typeof(identity), x::Union{Number, AbstractArray}) - x -end - -@inline function apply_act(activation::Function, x::Number) - activation(x) -end - -@inline function apply_act(activation::Function, x::AbstractArray) - activation.(x) -end diff --git a/src/layers/mul_layer.jl b/src/layers/mul_layer.jl index ea9da6de..02804a40 100644 --- a/src/layers/mul_layer.jl +++ b/src/layers/mul_layer.jl @@ -31,5 +31,5 @@ function LuxCore.outputsize(m::MulLayer) end @inline function (m::MulLayer)(x::AbstractVecOrMat, ps::Any, st::NamedTuple) - apply_act(m.activation, Octavian.matmul(ps.weight, x)), st + m.activation.(Octavian.matmul(ps.weight, x)), st end diff --git a/src/layers/planar_layer.jl b/src/layers/planar_layer.jl index f695e3c3..6b087abe 100644 --- a/src/layers/planar_layer.jl +++ b/src/layers/planar_layer.jl @@ -71,33 +71,33 @@ function LuxCore.outputsize(m::PlanarLayer) end @inline function (m::PlanarLayer{true})(z::AbstractVector, ps::Any, st::NamedTuple) - ps.u * apply_act(m.activation, LinearAlgebra.dot(ps.w, z) + only(ps.b)), st + ps.u * m.activation.(LinearAlgebra.dot(ps.w, z) + only(ps.b)), st end @inline function (m::PlanarLayer{true})(z::AbstractMatrix, ps::Any, st::NamedTuple) - ps.u * apply_act(m.activation, muladd(transpose(ps.w), z, only(ps.b))), st + ps.u * m.activation.(muladd(transpose(ps.w), z, only(ps.b))), st end @inline function (m::PlanarLayer{false})(z::AbstractVector, ps::Any, st::NamedTuple) - ps.u * apply_act(m.activation, LinearAlgebra.dot(ps.w, z)), st + ps.u * m.activation.(LinearAlgebra.dot(ps.w, z)), st end @inline function (m::PlanarLayer{false})(z::AbstractMatrix, ps::Any, st::NamedTuple) - ps.u * apply_act(m.activation, transpose(ps.w) * z), st + ps.u * m.activation.(transpose(ps.w) * z), st end @inline function pl_h(m::PlanarLayer{true}, z::AbstractVector, ps::Any, st::NamedTuple) - apply_act(m.activation, LinearAlgebra.dot(ps.w, z) + only(ps.b)), st + m.activation.(LinearAlgebra.dot(ps.w, z) + only(ps.b)), st end @inline function pl_h(m::PlanarLayer{true}, z::AbstractMatrix, ps::Any, st::NamedTuple) - apply_act(m.activation, muladd(transpose(ps.w), z, only(ps.b))), st + m.activation.(muladd(transpose(ps.w), z, only(ps.b))), st end @inline function pl_h(m::PlanarLayer{false}, z::AbstractVector, ps::Any, st::NamedTuple) - apply_act(m.activation, LinearAlgebra.dot(ps.w, z)), st + m.activation.(LinearAlgebra.dot(ps.w, z)), st end @inline function pl_h(m::PlanarLayer{false}, z::AbstractMatrix, ps::Any, st::NamedTuple) - apply_act(m.activation, transpose(ps.w) * z), st + m.activation.(transpose(ps.w) * z), st end From 539be3c5ae069ae582aea805461d2af556bdcf15 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Wed, 15 May 2024 00:35:11 +0330 Subject: [PATCH 4/4] fix --- src/ContinuousNormalizingFlows.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/ContinuousNormalizingFlows.jl b/src/ContinuousNormalizingFlows.jl index c18c77fd..55b54d46 100644 --- a/src/ContinuousNormalizingFlows.jl +++ b/src/ContinuousNormalizingFlows.jl @@ -58,7 +58,6 @@ export construct, PlanarLayer, MulLayer -include(joinpath("layers", "base_layer.jl")) include(joinpath("layers", "cond_layer.jl")) include(joinpath("layers", "planar_layer.jl")) include(joinpath("layers", "mul_layer.jl"))