From d1603188d5806c06cb58bd8ebbd1c6e5b6be605e Mon Sep 17 00:00:00 2001
From: Hossein Pourbozorg
Date: Wed, 27 Mar 2024 04:19:43 +0330
Subject: [PATCH] Cleaning & add `MulLayer` (#390)

* cleaning
* paramtype
* fix
* add export
---
 Project.toml                              |   2 +
 .../ContinuousNormalizingFlowsCUDAExt.jl  |   2 +-
 src/ContinuousNormalizingFlows.jl         |  17 +-
 src/base.jl                               | 226 --------
 src/base_icnf.jl                          | 275 +++++++++-
 src/{ => cores}/core.jl                   |  42 +-
 src/{ => cores}/core_cond_icnf.jl         | 508 +++++++++---------
 src/{ => cores}/core_icnf.jl              | 480 ++++++++---------
 src/icnf.jl                               |  20 +-
 src/{ => layers}/cond_layer.jl            |  20 +-
 src/layers/mul_layer.jl                   |  37 ++
 src/layers/planar_layer.jl                | 105 ++++
 src/planar_layer.jl                       | 163 ------
 src/types.jl                              |  11 +-
 src/utils.jl                              |   6 +-
 test/call_tests.jl                        |  14 +-
 test/fit_tests.jl                         |  14 +-
 17 files changed, 958 insertions(+), 984 deletions(-)
 delete mode 100644 src/base.jl
 rename src/{ => cores}/core.jl (79%)
 rename src/{ => cores}/core_cond_icnf.jl (96%)
 rename src/{ => cores}/core_icnf.jl (96%)
 rename src/{ => layers}/cond_layer.jl (68%)
 create mode 100644 src/layers/mul_layer.jl
 create mode 100644 src/layers/planar_layer.jl
 delete mode 100644 src/planar_layer.jl

diff --git a/Project.toml b/Project.toml
index 15fbd56c..56da3f16 100644
--- a/Project.toml
+++ b/Project.toml
@@ -21,6 +21,7 @@ MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
 MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
+Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
 OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
@@ -59,6 +60,7 @@ MLJBase = "1"
 MLJModelInterface = "1"
 MLUtils = "0.4"
 NNlib = "0.9"
+Octavian = "0.3.27"
 Optimisers = "0.3"
 Optimization = "3.15"
 OptimizationOptimisers = "0.2"
diff --git a/ext/ContinuousNormalizingFlowsCUDAExt/ContinuousNormalizingFlowsCUDAExt.jl b/ext/ContinuousNormalizingFlowsCUDAExt/ContinuousNormalizingFlowsCUDAExt.jl
index d4b5060f..9f5fc546 100644
--- a/ext/ContinuousNormalizingFlowsCUDAExt/ContinuousNormalizingFlowsCUDAExt.jl
+++ b/ext/ContinuousNormalizingFlowsCUDAExt/ContinuousNormalizingFlowsCUDAExt.jl
@@ -9,7 +9,7 @@ end

 @inline function ContinuousNormalizingFlows.base_AT(
     ::CUDALibs,
-    ::ContinuousNormalizingFlows.AbstractFlows{T},
+    ::ContinuousNormalizingFlows.AbstractICNF{T},
     dims...,
 ) where {T <: AbstractFloat}
     CuArray{T}(undef, dims...)
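
Context for the dependency bump above: Octavian enters Project.toml because the new MulLayer added later in this patch (src/layers/mul_layer.jl) computes its forward pass with Octavian.matmul. A minimal standalone sketch of that call; the array sizes are arbitrary stand-ins:

    using Octavian

    W = rand(Float32, 4, 4)    # square weight, mirroring MulLayer's (nvars × nvars) parameter
    x = rand(Float32, 4, 8)    # a batch of 8 column vectors
    y = Octavian.matmul(W, x)  # pure-Julia CPU matrix multiply used by the layer
    y ≈ W * x                  # true: same product as the usual BLAS path
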
diff --git a/src/ContinuousNormalizingFlows.jl b/src/ContinuousNormalizingFlows.jl index 344a1d5d..e5c9e641 100644 --- a/src/ContinuousNormalizingFlows.jl +++ b/src/ContinuousNormalizingFlows.jl @@ -18,6 +18,7 @@ using AbstractDifferentiation, MLJModelInterface, MLUtils, NNlib, + Octavian, Optimisers, Optimization, OptimizationOptimisers, @@ -31,22 +32,22 @@ using AbstractDifferentiation, Statistics, Zygote -include("types.jl") +include(joinpath("layers", "cond_layer.jl")) +include(joinpath("layers", "planar_layer.jl")) +include(joinpath("layers", "mul_layer.jl")) -include("planar_layer.jl") -include("cond_layer.jl") +include("types.jl") -include("base.jl") include("base_icnf.jl") -include("core.jl") -include("core_icnf.jl") -include("core_cond_icnf.jl") - include("icnf.jl") include("utils.jl") +include(joinpath("cores", "core.jl")) +include(joinpath("cores", "core_icnf.jl")) +include(joinpath("cores", "core_cond_icnf.jl")) + """ Implementations of Infinitesimal Continuous Normalizing Flows Algorithms in Julia """ diff --git a/src/base.jl b/src/base.jl deleted file mode 100644 index 2e1a1e40..00000000 --- a/src/base.jl +++ /dev/null @@ -1,226 +0,0 @@ -export construct - -function construct( - aicnf::Type{<:AbstractFlows}, - nn::LuxCore.AbstractExplicitLayer, - nvars::Int, - naugmented::Int = 0; - data_type::Type{<:AbstractFloat} = Float32, - compute_mode::Type{<:ComputeMode} = ADVecJacVectorMode, - inplace::Bool = false, - cond::Bool = aicnf <: Union{CondRNODE, CondFFJORD, CondPlanar}, - resource::AbstractResource = CPU1(), - basedist::Distribution = MvNormal( - Zeros{data_type}(nvars + naugmented), - Eye{data_type}(nvars + naugmented), - ), - tspan::NTuple{2} = (zero(data_type), one(data_type)), - steer_rate::AbstractFloat = zero(data_type), - epsdist::Distribution = MvNormal( - Zeros{data_type}(nvars + naugmented), - Eye{data_type}(nvars + naugmented), - ), - differentiation_backend::AbstractDifferentiation.AbstractBackend = AbstractDifferentiation.ZygoteBackend(), - autodiff_backend::ADTypes.AbstractADType = ifelse( - compute_mode <: SDJacVecMatrixMode, - AutoForwardDiff(), - AutoZygote(), - ), - sol_kwargs::NamedTuple = ( - save_everystep = false, - alg = Tsit5(; thread = OrdinaryDiffEq.True()), - ), - rng::AbstractRNG = rng_AT(resource), - λ₁::AbstractFloat = if aicnf <: Union{RNODE, CondRNODE} - convert(data_type, 1e-2) - else - zero(data_type) - end, - λ₂::AbstractFloat = if aicnf <: Union{RNODE, CondRNODE} - convert(data_type, 1e-2) - else - zero(data_type) - end, - λ₃::AbstractFloat = zero(data_type), -) - steerdist = Uniform{data_type}(-steer_rate, steer_rate) - - ICNF{ - data_type, - compute_mode, - inplace, - cond, - !iszero(naugmented), - !iszero(steer_rate), - !iszero(λ₁), - !iszero(λ₂), - !iszero(λ₃), - typeof(nn), - typeof(nvars), - typeof(resource), - typeof(basedist), - typeof(tspan), - typeof(steerdist), - typeof(epsdist), - typeof(differentiation_backend), - typeof(autodiff_backend), - typeof(sol_kwargs), - typeof(rng), - }( - nn, - nvars, - naugmented, - resource, - basedist, - tspan, - steerdist, - epsdist, - differentiation_backend, - autodiff_backend, - sol_kwargs, - rng, - λ₁, - λ₂, - λ₃, - ) -end - -@inline function n_augment(::AbstractFlows, ::Mode) - 0 -end - -# pretty-printing - -function Base.show(io::IO, icnf::AbstractFlows) - print( - io, - typeof(icnf), - "\n\tNumber of Variables: ", - icnf.nvars, - "\n\tNumber of Augmentations: ", - n_augment_input(icnf), - "\n\tTime Span: ", - icnf.tspan, - ) -end - -@inline function n_augment_input( - 
icnf::AbstractFlows{<:AbstractFloat, <:ComputeMode, INPLACE, COND, true}, -) where {INPLACE, COND} - icnf.naugmented -end - -@inline function n_augment_input(::AbstractFlows) - 0 -end - -@inline function steer_tspan( - icnf::AbstractFlows{T, <:ComputeMode, INPLACE, COND, AUGMENTED, true}, - ::TrainMode, -) where {T <: AbstractFloat, INPLACE, COND, AUGMENTED} - t₀, t₁ = icnf.tspan - Δt = abs(t₁ - t₀) - r = convert(T, rand(icnf.rng, icnf.steerdist)) - t₁_new = muladd(Δt, r, t₁) - (t₀, t₁_new) -end - -@inline function steer_tspan(icnf::AbstractFlows, ::Mode) - icnf.tspan -end - -@inline function rng_AT(::AbstractResource) - Random.default_rng() -end - -@inline function base_AT( - ::AbstractResource, - ::AbstractFlows{T}, - dims..., -) where {T <: AbstractFloat} - Array{T}(undef, dims...) -end - -@non_differentiable base_AT(::Any...) - -function inference_sol( - icnf::AbstractFlows{T, <:VectorMode, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG}, - mode::Mode, - prob::SciMLBase.AbstractODEProblem{<:AbstractVector{<:Real}, NTuple{2, T}, INPLACE}, -) where {T <: AbstractFloat, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} - n_aug = n_augment(icnf, mode) - sol = solve(prob; icnf.sol_kwargs...) - fsol = get_fsol(sol) - z = fsol[begin:(end - n_aug - 1)] - Δlogp = fsol[(end - n_aug)] - augs = fsol[(end - n_aug + 1):end] - logpz = oftype(Δlogp, logpdf(icnf.basedist, z)) - logp̂x = logpz - Δlogp - Ȧ = if (NORM_Z_AUG && AUGMENTED) - n_aug_input = n_augment_input(icnf) - z_aug = z[(end - n_aug_input + 1):end] - norm(z_aug) - else - zero(T) - end - (logp̂x, vcat(augs, Ȧ)) -end - -function inference_sol( - icnf::AbstractFlows{T, <:MatrixMode, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG}, - mode::Mode, - prob::SciMLBase.AbstractODEProblem{<:AbstractMatrix{<:Real}, NTuple{2, T}, INPLACE}, -) where {T <: AbstractFloat, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} - n_aug = n_augment(icnf, mode) - sol = solve(prob; icnf.sol_kwargs...) - fsol = get_fsol(sol) - z = fsol[begin:(end - n_aug - 1), :] - Δlogp = fsol[(end - n_aug), :] - augs = fsol[(end - n_aug + 1):end, :] - logpz = oftype(Δlogp, logpdf(icnf.basedist, z)) - logp̂x = logpz - Δlogp - Ȧ = transpose(if (NORM_Z_AUG && AUGMENTED) - n_aug_input = n_augment_input(icnf) - z_aug = z[(end - n_aug_input + 1):end, :] - norm.(eachcol(z_aug)) - else - zrs_aug = similar(augs, size(augs, 2)) - @ignore_derivatives fill!(zrs_aug, zero(T)) - zrs_aug - end) - (logp̂x, eachrow(vcat(augs, Ȧ))) -end - -function generate_sol( - icnf::AbstractFlows{T, <:VectorMode, INPLACE}, - mode::Mode, - prob::SciMLBase.AbstractODEProblem{<:AbstractVector{<:Real}, NTuple{2, T}, INPLACE}, -) where {T <: AbstractFloat, INPLACE} - n_aug = n_augment(icnf, mode) - n_aug_input = n_augment_input(icnf) - sol = solve(prob; icnf.sol_kwargs...) - fsol = get_fsol(sol) - z = fsol[begin:(end - n_aug_input - n_aug - 1)] - z -end - -function generate_sol( - icnf::AbstractFlows{T, <:MatrixMode, INPLACE}, - mode::Mode, - prob::SciMLBase.AbstractODEProblem{<:AbstractMatrix{<:Real}, NTuple{2, T}, INPLACE}, -) where {T <: AbstractFloat, INPLACE} - n_aug = n_augment(icnf, mode) - n_aug_input = n_augment_input(icnf) - sol = solve(prob; icnf.sol_kwargs...) 
- fsol = get_fsol(sol) - z = fsol[begin:(end - n_aug_input - n_aug - 1), :] - z -end - -@inline function get_fsol(sol::SciMLBase.AbstractODESolution) - last(sol.u) -end - -@inline function get_fsol(sol::AbstractArray{T, N}) where {T, N} - selectdim(sol, N, lastindex(sol, N)) -end diff --git a/src/base_icnf.jl b/src/base_icnf.jl index 6c75a337..211df9ae 100644 --- a/src/base_icnf.jl +++ b/src/base_icnf.jl @@ -1,11 +1,236 @@ -export inference, generate, loss +export construct, inference, generate, loss + +function construct( + aicnf::Type{<:AbstractICNF}, + nn::LuxCore.AbstractExplicitLayer, + nvars::Int, + naugmented::Int = 0; + data_type::Type{<:AbstractFloat} = Float32, + compute_mode::Type{<:ComputeMode} = ADVecJacVectorMode, + inplace::Bool = false, + cond::Bool = aicnf <: Union{CondRNODE, CondFFJORD, CondPlanar}, + resource::AbstractResource = CPU1(), + basedist::Distribution = MvNormal( + Zeros{data_type}(nvars + naugmented), + Eye{data_type}(nvars + naugmented), + ), + tspan::NTuple{2} = (zero(data_type), one(data_type)), + steer_rate::AbstractFloat = zero(data_type), + epsdist::Distribution = MvNormal( + Zeros{data_type}(nvars + naugmented), + Eye{data_type}(nvars + naugmented), + ), + differentiation_backend::AbstractDifferentiation.AbstractBackend = AbstractDifferentiation.ZygoteBackend(), + autodiff_backend::ADTypes.AbstractADType = ifelse( + compute_mode <: SDJacVecMatrixMode, + AutoForwardDiff(), + AutoZygote(), + ), + sol_kwargs::NamedTuple = ( + save_everystep = false, + alg = Tsit5(; thread = OrdinaryDiffEq.True()), + ), + rng::AbstractRNG = rng_AT(resource), + λ₁::AbstractFloat = if aicnf <: Union{RNODE, CondRNODE} + convert(data_type, 1e-2) + else + zero(data_type) + end, + λ₂::AbstractFloat = if aicnf <: Union{RNODE, CondRNODE} + convert(data_type, 1e-2) + else + zero(data_type) + end, + λ₃::AbstractFloat = zero(data_type), +) + steerdist = Uniform{data_type}(-steer_rate, steer_rate) + + ICNF{ + data_type, + compute_mode, + inplace, + cond, + !iszero(naugmented), + !iszero(steer_rate), + !iszero(λ₁), + !iszero(λ₂), + !iszero(λ₃), + typeof(nn), + typeof(nvars), + typeof(resource), + typeof(basedist), + typeof(tspan), + typeof(steerdist), + typeof(epsdist), + typeof(differentiation_backend), + typeof(autodiff_backend), + typeof(sol_kwargs), + typeof(rng), + }( + nn, + nvars, + naugmented, + resource, + basedist, + tspan, + steerdist, + epsdist, + differentiation_backend, + autodiff_backend, + sol_kwargs, + rng, + λ₁, + λ₂, + λ₃, + ) +end + +@inline function n_augment(::AbstractICNF, ::Mode) + 0 +end + +# pretty-printing + +function Base.show(io::IO, icnf::AbstractICNF) + print( + io, + typeof(icnf), + "\n\tNumber of Variables: ", + icnf.nvars, + "\n\tNumber of Augmentations: ", + n_augment_input(icnf), + "\n\tTime Span: ", + icnf.tspan, + ) +end + +@inline function n_augment_input( + icnf::AbstractICNF{<:AbstractFloat, <:ComputeMode, INPLACE, COND, true}, +) where {INPLACE, COND} + icnf.naugmented +end + +@inline function n_augment_input(::AbstractICNF) + 0 +end + +@inline function steer_tspan( + icnf::AbstractICNF{T, <:ComputeMode, INPLACE, COND, AUGMENTED, true}, + ::TrainMode, +) where {T <: AbstractFloat, INPLACE, COND, AUGMENTED} + t₀, t₁ = icnf.tspan + Δt = abs(t₁ - t₀) + r = convert(T, rand(icnf.rng, icnf.steerdist)) + t₁_new = muladd(Δt, r, t₁) + (t₀, t₁_new) +end + +@inline function steer_tspan(icnf::AbstractICNF, ::Mode) + icnf.tspan +end + +@inline function rng_AT(::AbstractResource) + Random.default_rng() +end + +@inline function base_AT( + ::AbstractResource, + 
::AbstractICNF{T}, + dims..., +) where {T <: AbstractFloat} + Array{T}(undef, dims...) +end + +@non_differentiable base_AT(::Any...) + +function inference_sol( + icnf::AbstractICNF{T, <:VectorMode, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG}, + mode::Mode, + prob::SciMLBase.AbstractODEProblem{<:AbstractVector{<:Real}, NTuple{2, T}, INPLACE}, +) where {T <: AbstractFloat, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} + n_aug = n_augment(icnf, mode) + sol = solve(prob; icnf.sol_kwargs...) + fsol = get_fsol(sol) + z = fsol[begin:(end - n_aug - 1)] + Δlogp = fsol[(end - n_aug)] + augs = fsol[(end - n_aug + 1):end] + logpz = oftype(Δlogp, logpdf(icnf.basedist, z)) + logp̂x = logpz - Δlogp + Ȧ = if (NORM_Z_AUG && AUGMENTED) + n_aug_input = n_augment_input(icnf) + z_aug = z[(end - n_aug_input + 1):end] + norm(z_aug) + else + zero(T) + end + (logp̂x, vcat(augs, Ȧ)) +end + +function inference_sol( + icnf::AbstractICNF{T, <:MatrixMode, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG}, + mode::Mode, + prob::SciMLBase.AbstractODEProblem{<:AbstractMatrix{<:Real}, NTuple{2, T}, INPLACE}, +) where {T <: AbstractFloat, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} + n_aug = n_augment(icnf, mode) + sol = solve(prob; icnf.sol_kwargs...) + fsol = get_fsol(sol) + z = fsol[begin:(end - n_aug - 1), :] + Δlogp = fsol[(end - n_aug), :] + augs = fsol[(end - n_aug + 1):end, :] + logpz = oftype(Δlogp, logpdf(icnf.basedist, z)) + logp̂x = logpz - Δlogp + Ȧ = transpose(if (NORM_Z_AUG && AUGMENTED) + n_aug_input = n_augment_input(icnf) + z_aug = z[(end - n_aug_input + 1):end, :] + norm.(eachcol(z_aug)) + else + zrs_aug = similar(augs, size(augs, 2)) + @ignore_derivatives fill!(zrs_aug, zero(T)) + zrs_aug + end) + (logp̂x, eachrow(vcat(augs, Ȧ))) +end + +function generate_sol( + icnf::AbstractICNF{T, <:VectorMode, INPLACE}, + mode::Mode, + prob::SciMLBase.AbstractODEProblem{<:AbstractVector{<:Real}, NTuple{2, T}, INPLACE}, +) where {T <: AbstractFloat, INPLACE} + n_aug = n_augment(icnf, mode) + n_aug_input = n_augment_input(icnf) + sol = solve(prob; icnf.sol_kwargs...) + fsol = get_fsol(sol) + z = fsol[begin:(end - n_aug_input - n_aug - 1)] + z +end + +function generate_sol( + icnf::AbstractICNF{T, <:MatrixMode, INPLACE}, + mode::Mode, + prob::SciMLBase.AbstractODEProblem{<:AbstractMatrix{<:Real}, NTuple{2, T}, INPLACE}, +) where {T <: AbstractFloat, INPLACE} + n_aug = n_augment(icnf, mode) + n_aug_input = n_augment_input(icnf) + sol = solve(prob; icnf.sol_kwargs...) 
+ fsol = get_fsol(sol) + z = fsol[begin:(end - n_aug_input - n_aug - 1), :] + z +end + +@inline function get_fsol(sol::SciMLBase.AbstractODESolution) + last(sol.u) +end + +@inline function get_fsol(sol::AbstractArray{T, N}) where {T, N} + selectdim(sol, N, lastindex(sol, N)) +end function inference_prob( icnf::AbstractICNF{T, <:VectorMode, INPLACE, false}, mode::Mode, xs::AbstractVector{<:Real}, ps::Any, - st::Any, + st::NamedTuple, ) where {T <: AbstractFloat, INPLACE} n_aug = n_augment(icnf, mode) n_aug_input = n_augment_input(icnf) @@ -36,7 +261,7 @@ function inference_prob( xs::AbstractVector{<:Real}, ys::AbstractVector{<:Real}, ps::Any, - st::Any, + st::NamedTuple, ) where {T <: AbstractFloat, INPLACE} n_aug = n_augment(icnf, mode) n_aug_input = n_augment_input(icnf) @@ -66,7 +291,7 @@ function inference_prob( mode::Mode, xs::AbstractMatrix{<:Real}, ps::Any, - st::Any, + st::NamedTuple, ) where {T <: AbstractFloat, INPLACE} n_aug = n_augment(icnf, mode) n_aug_input = n_augment_input(icnf) @@ -97,7 +322,7 @@ function inference_prob( xs::AbstractMatrix{<:Real}, ys::AbstractMatrix{<:Real}, ps::Any, - st::Any, + st::NamedTuple, ) where {T <: AbstractFloat, INPLACE} n_aug = n_augment(icnf, mode) n_aug_input = n_augment_input(icnf) @@ -126,7 +351,7 @@ function generate_prob( icnf::AbstractICNF{T, <:VectorMode, INPLACE, false}, mode::Mode, ps::Any, - st::Any, + st::NamedTuple, ) where {T <: AbstractFloat, INPLACE} n_aug = n_augment(icnf, mode) n_aug_input = n_augment_input(icnf) @@ -158,7 +383,7 @@ function generate_prob( mode::Mode, ys::AbstractVector{<:Real}, ps::Any, - st::Any, + st::NamedTuple, ) where {T <: AbstractFloat, INPLACE} n_aug = n_augment(icnf, mode) n_aug_input = n_augment_input(icnf) @@ -189,7 +414,7 @@ function generate_prob( icnf::AbstractICNF{T, <:MatrixMode, INPLACE, false}, mode::Mode, ps::Any, - st::Any, + st::NamedTuple, n::Int, ) where {T <: AbstractFloat, INPLACE} n_aug = n_augment(icnf, mode) @@ -222,7 +447,7 @@ function generate_prob( mode::Mode, ys::AbstractMatrix{<:Real}, ps::Any, - st::Any, + st::NamedTuple, n::Int, ) where {T <: AbstractFloat, INPLACE} n_aug = n_augment(icnf, mode) @@ -255,7 +480,7 @@ end mode::Mode, xs::AbstractVecOrMat{<:Real}, ps::Any, - st::Any, + st::NamedTuple, ) inference_sol(icnf, mode, inference_prob(icnf, mode, xs, ps, st)) end @@ -266,7 +491,7 @@ end xs::AbstractVecOrMat{<:Real}, ys::AbstractVecOrMat{<:Real}, ps::Any, - st::Any, + st::NamedTuple, ) inference_sol(icnf, mode, inference_prob(icnf, mode, xs, ys, ps, st)) end @@ -275,7 +500,7 @@ end icnf::AbstractICNF{<:AbstractFloat, <:VectorMode}, mode::Mode, ps::Any, - st::Any, + st::NamedTuple, ) generate_sol(icnf, mode, generate_prob(icnf, mode, ps, st)) end @@ -285,7 +510,7 @@ end mode::Mode, ys::AbstractVector{<:Real}, ps::Any, - st::Any, + st::NamedTuple, ) generate_sol(icnf, mode, generate_prob(icnf, mode, ys, ps, st)) end @@ -294,7 +519,7 @@ end icnf::AbstractICNF{<:AbstractFloat, <:MatrixMode}, mode::Mode, ps::Any, - st::Any, + st::NamedTuple, n::Int, ) generate_sol(icnf, mode, generate_prob(icnf, mode, ps, st, n)) @@ -305,7 +530,7 @@ end mode::Mode, ys::AbstractMatrix{<:Real}, ps::Any, - st::Any, + st::NamedTuple, n::Int, ) generate_sol(icnf, mode, generate_prob(icnf, mode, ys, ps, st, n)) @@ -316,7 +541,7 @@ end mode::Mode, xs::AbstractVector{<:Real}, ps::Any, - st::Any, + st::NamedTuple, ) -first(inference(icnf, mode, xs, ps, st)) end @@ -327,7 +552,7 @@ end xs::AbstractVector{<:Real}, ys::AbstractVector{<:Real}, ps::Any, - st::Any, + st::NamedTuple, ) 
-first(inference(icnf, mode, xs, ys, ps, st)) end @@ -337,7 +562,7 @@ end mode::Mode, xs::AbstractMatrix{<:Real}, ps::Any, - st::Any, + st::NamedTuple, ) -mean(first(inference(icnf, mode, xs, ps, st))) end @@ -348,24 +573,24 @@ end xs::AbstractMatrix{<:Real}, ys::AbstractMatrix{<:Real}, ps::Any, - st::Any, + st::NamedTuple, ) -mean(first(inference(icnf, mode, xs, ys, ps, st))) end @inline function (icnf::AbstractICNF{T, CM, INPLACE, false})( - xs::Any, + xs::AbstractVecOrMat, ps::Any, - st::Any, + st::NamedTuple, ) where {T, CM, INPLACE} - first(inference(icnf, TrainMode(), xs, ps, st)) + first(inference(icnf, TrainMode(), xs, ps, st)), st end @inline function (icnf::AbstractICNF{T, CM, INPLACE, true})( - xs_ys::Any, + xs_ys::Tuple, ps::Any, - st::Any, + st::NamedTuple, ) where {T, CM, INPLACE} xs, ys = xs_ys - first(inference(icnf, TrainMode(), xs, ys, ps, st)) + first(inference(icnf, TrainMode(), xs, ys, ps, st)), st end diff --git a/src/core.jl b/src/cores/core.jl similarity index 79% rename from src/core.jl rename to src/cores/core.jl index 5122d3e3..c8a35628 100644 --- a/src/core.jl +++ b/src/cores/core.jl @@ -1,21 +1,21 @@ -# SciML interface - -function callback_f(ps, l, ::AbstractFlows, prgr::Progress, itr_n::AbstractArray) - ProgressMeter.next!( - prgr; - showvalues = [ - (:loss_value, l), - (:iteration, only(itr_n)), - (:last_update, Dates.now()), - ], - ) - itr_n[] += one(only(itr_n)) - false -end - -# MLJ interface - -function MLJModelInterface.fitted_params(::MLJICNF, fitresult) - (ps, st) = fitresult - (learned_parameters = ps, states = st) -end +# SciML interface + +function callback_f(ps, l, ::AbstractICNF, prgr::Progress, itr_n::AbstractArray) + ProgressMeter.next!( + prgr; + showvalues = [ + (:loss_value, l), + (:iteration, only(itr_n)), + (:last_update, Dates.now()), + ], + ) + itr_n[] += one(only(itr_n)) + false +end + +# MLJ interface + +function MLJModelInterface.fitted_params(::MLJICNF, fitresult) + (ps, st) = fitresult + (learned_parameters = ps, states = st) +end diff --git a/src/core_cond_icnf.jl b/src/cores/core_cond_icnf.jl similarity index 96% rename from src/core_cond_icnf.jl rename to src/cores/core_cond_icnf.jl index a9030619..8345939e 100644 --- a/src/core_cond_icnf.jl +++ b/src/cores/core_cond_icnf.jl @@ -1,254 +1,254 @@ -export CondICNFModel, CondICNFDist - -# MLJ interface - -mutable struct CondICNFModel <: MLJICNF - m::AbstractICNF - loss::Function - - optimizers::AbstractVector - n_epochs::Int - adtype::ADTypes.AbstractADType - - use_batch::Bool - batch_size::Int - have_callback::Bool - - compute_mode::Type{<:ComputeMode} -end - -function CondICNFModel( - m::AbstractICNF{<:AbstractFloat, CM}, - loss::Function = loss; - optimizers::AbstractVector = Any[Optimisers.Lion(),], - n_epochs::Int = 300, - adtype::ADTypes.AbstractADType = AutoZygote(), - use_batch::Bool = true, - batch_size::Int = 32, - have_callback::Bool = true, -) where {CM <: ComputeMode} - CondICNFModel( - m, - loss, - optimizers, - n_epochs, - adtype, - use_batch, - batch_size, - have_callback, - CM, - ) -end - -function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY) - X, Y = XY - x = collect(transpose(MLJModelInterface.matrix(X))) - y = collect(transpose(MLJModelInterface.matrix(Y))) - ps, st = LuxCore.setup(model.m.rng, model.m) - ps = ComponentArray(ps) - if model.m.resource isa CUDALibs - gdev = gpu_device() - x = gdev(x) - y = gdev(y) - ps = gdev(ps) - st = gdev(st) - end - optfunc = OptimizationFunction( - let mm = model.m, md = TrainMode(), st = st - (u, p, xs_, ys_) -> 
model.loss(mm, md, xs_, ys_, u, st) - end, - model.adtype, - ) - optprob = OptimizationProblem(optfunc, ps) - tst_overall = @timed for opt in model.optimizers - tst_epochs = @timed for ep in 1:(model.n_epochs) - if model.use_batch - if model.compute_mode <: VectorMode - data = DataLoader( - (x, y); - batchsize = -1, - shuffle = true, - partial = true, - parallel = false, - buffer = false, - ) - elseif model.compute_mode <: MatrixMode - data = DataLoader( - (x, y); - batchsize = model.batch_size, - shuffle = true, - partial = true, - parallel = false, - buffer = false, - ) - else - error("Not Implemented") - end - else - data = [(x, y)] - end - optprob_re = remake(optprob; u0 = ps) - if model.have_callback - prgr = Progress( - length(data); - desc = "Fitting (epoch: $ep of $(model.n_epochs)): ", - showspeed = true, - ) - itr_n = ones(Int) - tst_one = @timed res = solve( - optprob_re, - opt, - data; - callback = let mm = model.m, prgr = prgr, itr_n = itr_n - (ps_, l_) -> callback_f(ps_, l_, mm, prgr, itr_n) - end, - ) - ProgressMeter.finish!(prgr) - - else - tst_one = @timed res = solve(optprob_re, opt, data) - end - ps .= res.u - @info( - "Fitting (epoch: $ep of $(model.n_epochs)) - $(typeof(opt).name.name)", - "elapsed time (seconds)" = tst_one.time, - "garbage collection time (seconds)" = tst_one.gctime, - ) - end - @info( - "Fitting (all epochs) - $(typeof(opt).name.name)", - "elapsed time (seconds)" = tst_epochs.time, - "garbage collection time (seconds)" = tst_epochs.gctime, - ) - end - @info( - "Fitting - Overall", - "elapsed time (seconds)" = tst_overall.time, - "garbage collection time (seconds)" = tst_overall.gctime, - ) - - fitresult = (ps, st) - cache = nothing - report = (stats = tst_overall,) - fitresult, cache, report -end - -function MLJModelInterface.transform(model::CondICNFModel, fitresult, XYnew) - Xnew, Ynew = XYnew - xnew = collect(transpose(MLJModelInterface.matrix(Xnew))) - ynew = collect(transpose(MLJModelInterface.matrix(Ynew))) - if model.m.resource isa CUDALibs - gdev = gpu_device() - xnew = gdev(xnew) - ynew = gdev(ynew) - end - (ps, st) = fitresult - - if model.compute_mode <: VectorMode - tst = @timed logp̂x = broadcast( - let mm = model.m, md = TestMode(), ps = ps, st = st - (x, y) -> first(inference(mm, md, x, y, ps, st)) - end, - eachcol(xnew), - eachcol(ynew), - ) - elseif model.compute_mode <: MatrixMode - tst = @timed logp̂x = first(inference(model.m, TestMode(), xnew, ynew, ps, st)) - else - error("Not Implemented") - end - @info( - "Transforming", - "elapsed time (seconds)" = tst.time, - "garbage collection time (seconds)" = tst.gctime, - ) - - DataFrame(; px = exp.(logp̂x)) -end - -MLJBase.metadata_pkg( - CondICNFModel; - package_name = "ContinuousNormalizingFlows", - package_uuid = "00b1973d-5b2e-40bf-8604-5c9c1d8f50ac", - package_url = "https://github.com/impICNF/ContinuousNormalizingFlows.jl", - is_pure_julia = true, - package_license = "MIT", - is_wrapper = false, -) -MLJBase.metadata_model( - CondICNFModel; - input_scitype = Tuple{ - Table{AbstractVector{ScientificTypes.Continuous}}, - Table{AbstractVector{ScientificTypes.Continuous}}, - }, - target_scitype = Table{AbstractVector{ScientificTypes.Continuous}}, - output_scitype = Table{AbstractVector{ScientificTypes.Continuous}}, - supports_weights = false, - load_path = "ContinuousNormalizingFlows.CondICNFModel", -) - -# Distributions interface - -struct CondICNFDist <: ICNFDistribution - m::AbstractICNF - mode::Mode - ys::AbstractVecOrMat{<:Real} - ps::Any - st::Any -end - -function CondICNFDist( 
- mach::Machine{<:CondICNFModel}, - mode::Mode, - ys::AbstractVecOrMat{<:Real}, -) - (ps, st) = fitted_params(mach) - CondICNFDist(mach.model.m, mode, ys, ps, st) -end - -function Base.length(d::CondICNFDist) - d.m.nvars -end -function Base.eltype(d::CondICNFDist) - first(typeof(d.m).parameters) -end -function Distributions._logpdf(d::CondICNFDist, x::AbstractVector{<:Real}) - if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} - first(inference(d.m, d.mode, x, d.ys, d.ps, d.st)) - elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} - first(Distributions._logpdf(d, hcat(x))) - else - error("Not Implemented") - end -end -function Distributions._logpdf(d::CondICNFDist, A::AbstractMatrix{<:Real}) - if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} - broadcast(let d = d - x -> Distributions._logpdf(d, x) - end, eachcol(A)) - elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} - first(inference(d.m, d.mode, A, d.ys[:, begin:size(A, 2)], d.ps, d.st)) - else - error("Not Implemented") - end -end -function Distributions._rand!(rng::AbstractRNG, d::CondICNFDist, x::AbstractVector{<:Real}) - if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} - x .= generate(d.m, d.mode, d.ys, d.ps, d.st) - elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} - x .= Distributions._rand!(rng, d, hcat(x)) - else - error("Not Implemented") - end -end -function Distributions._rand!(rng::AbstractRNG, d::CondICNFDist, A::AbstractMatrix{<:Real}) - if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} - A .= hcat(broadcast(let rng = rng, d = d - x -> Distributions._rand!(rng, d, x) - end, eachcol(A))...) - elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} - A .= generate(d.m, d.mode, d.ys[:, begin:size(A, 2)], d.ps, d.st, size(A, 2)) - else - error("Not Implemented") - end -end +export CondICNFModel, CondICNFDist + +# MLJ interface + +mutable struct CondICNFModel <: MLJICNF + m::AbstractICNF + loss::Function + + optimizers::AbstractVector + n_epochs::Int + adtype::ADTypes.AbstractADType + + use_batch::Bool + batch_size::Int + have_callback::Bool + + compute_mode::Type{<:ComputeMode} +end + +function CondICNFModel( + m::AbstractICNF{<:AbstractFloat, CM}, + loss::Function = loss; + optimizers::AbstractVector = Any[Optimisers.Lion(),], + n_epochs::Int = 300, + adtype::ADTypes.AbstractADType = AutoZygote(), + use_batch::Bool = true, + batch_size::Int = 32, + have_callback::Bool = true, +) where {CM <: ComputeMode} + CondICNFModel( + m, + loss, + optimizers, + n_epochs, + adtype, + use_batch, + batch_size, + have_callback, + CM, + ) +end + +function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY) + X, Y = XY + x = collect(transpose(MLJModelInterface.matrix(X))) + y = collect(transpose(MLJModelInterface.matrix(Y))) + ps, st = LuxCore.setup(model.m.rng, model.m) + ps = ComponentArray(ps) + if model.m.resource isa CUDALibs + gdev = gpu_device() + x = gdev(x) + y = gdev(y) + ps = gdev(ps) + st = gdev(st) + end + optfunc = OptimizationFunction( + let mm = model.m, md = TrainMode(), st = st + (u, p, xs_, ys_) -> model.loss(mm, md, xs_, ys_, u, st) + end, + model.adtype, + ) + optprob = OptimizationProblem(optfunc, ps) + tst_overall = @timed for opt in model.optimizers + tst_epochs = @timed for ep in 1:(model.n_epochs) + if model.use_batch + if model.compute_mode <: VectorMode + data = DataLoader( + (x, y); + batchsize = -1, + shuffle = true, + partial = true, + parallel = false, + buffer = false, + ) + elseif model.compute_mode <: MatrixMode + data = DataLoader( + (x, y); + batchsize 
= model.batch_size, + shuffle = true, + partial = true, + parallel = false, + buffer = false, + ) + else + error("Not Implemented") + end + else + data = [(x, y)] + end + optprob_re = remake(optprob; u0 = ps) + if model.have_callback + prgr = Progress( + length(data); + desc = "Fitting (epoch: $ep of $(model.n_epochs)): ", + showspeed = true, + ) + itr_n = ones(Int) + tst_one = @timed res = solve( + optprob_re, + opt, + data; + callback = let mm = model.m, prgr = prgr, itr_n = itr_n + (ps_, l_) -> callback_f(ps_, l_, mm, prgr, itr_n) + end, + ) + ProgressMeter.finish!(prgr) + + else + tst_one = @timed res = solve(optprob_re, opt, data) + end + ps .= res.u + @info( + "Fitting (epoch: $ep of $(model.n_epochs)) - $(typeof(opt).name.name)", + "elapsed time (seconds)" = tst_one.time, + "garbage collection time (seconds)" = tst_one.gctime, + ) + end + @info( + "Fitting (all epochs) - $(typeof(opt).name.name)", + "elapsed time (seconds)" = tst_epochs.time, + "garbage collection time (seconds)" = tst_epochs.gctime, + ) + end + @info( + "Fitting - Overall", + "elapsed time (seconds)" = tst_overall.time, + "garbage collection time (seconds)" = tst_overall.gctime, + ) + + fitresult = (ps, st) + cache = nothing + report = (stats = tst_overall,) + fitresult, cache, report +end + +function MLJModelInterface.transform(model::CondICNFModel, fitresult, XYnew) + Xnew, Ynew = XYnew + xnew = collect(transpose(MLJModelInterface.matrix(Xnew))) + ynew = collect(transpose(MLJModelInterface.matrix(Ynew))) + if model.m.resource isa CUDALibs + gdev = gpu_device() + xnew = gdev(xnew) + ynew = gdev(ynew) + end + (ps, st) = fitresult + + if model.compute_mode <: VectorMode + tst = @timed logp̂x = broadcast( + let mm = model.m, md = TestMode(), ps = ps, st = st + (x, y) -> first(inference(mm, md, x, y, ps, st)) + end, + eachcol(xnew), + eachcol(ynew), + ) + elseif model.compute_mode <: MatrixMode + tst = @timed logp̂x = first(inference(model.m, TestMode(), xnew, ynew, ps, st)) + else + error("Not Implemented") + end + @info( + "Transforming", + "elapsed time (seconds)" = tst.time, + "garbage collection time (seconds)" = tst.gctime, + ) + + DataFrame(; px = exp.(logp̂x)) +end + +MLJBase.metadata_pkg( + CondICNFModel; + package_name = "ContinuousNormalizingFlows", + package_uuid = "00b1973d-5b2e-40bf-8604-5c9c1d8f50ac", + package_url = "https://github.com/impICNF/ContinuousNormalizingFlows.jl", + is_pure_julia = true, + package_license = "MIT", + is_wrapper = false, +) +MLJBase.metadata_model( + CondICNFModel; + input_scitype = Tuple{ + Table{AbstractVector{ScientificTypes.Continuous}}, + Table{AbstractVector{ScientificTypes.Continuous}}, + }, + target_scitype = Table{AbstractVector{ScientificTypes.Continuous}}, + output_scitype = Table{AbstractVector{ScientificTypes.Continuous}}, + supports_weights = false, + load_path = "ContinuousNormalizingFlows.CondICNFModel", +) + +# Distributions interface + +struct CondICNFDist <: ICNFDistribution + m::AbstractICNF + mode::Mode + ys::AbstractVecOrMat{<:Real} + ps::Any + st::NamedTuple +end + +function CondICNFDist( + mach::Machine{<:CondICNFModel}, + mode::Mode, + ys::AbstractVecOrMat{<:Real}, +) + (ps, st) = fitted_params(mach) + CondICNFDist(mach.model.m, mode, ys, ps, st) +end + +function Base.length(d::CondICNFDist) + d.m.nvars +end +function Base.eltype(d::CondICNFDist) + first(typeof(d.m).parameters) +end +function Distributions._logpdf(d::CondICNFDist, x::AbstractVector{<:Real}) + if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} + first(inference(d.m, d.mode, x, d.ys, 
d.ps, d.st)) + elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} + first(Distributions._logpdf(d, hcat(x))) + else + error("Not Implemented") + end +end +function Distributions._logpdf(d::CondICNFDist, A::AbstractMatrix{<:Real}) + if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} + broadcast(let d = d + x -> Distributions._logpdf(d, x) + end, eachcol(A)) + elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} + first(inference(d.m, d.mode, A, d.ys[:, begin:size(A, 2)], d.ps, d.st)) + else + error("Not Implemented") + end +end +function Distributions._rand!(rng::AbstractRNG, d::CondICNFDist, x::AbstractVector{<:Real}) + if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} + x .= generate(d.m, d.mode, d.ys, d.ps, d.st) + elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} + x .= Distributions._rand!(rng, d, hcat(x)) + else + error("Not Implemented") + end +end +function Distributions._rand!(rng::AbstractRNG, d::CondICNFDist, A::AbstractMatrix{<:Real}) + if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} + A .= hcat(broadcast(let rng = rng, d = d + x -> Distributions._rand!(rng, d, x) + end, eachcol(A))...) + elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} + A .= generate(d.m, d.mode, d.ys[:, begin:size(A, 2)], d.ps, d.st, size(A, 2)) + else + error("Not Implemented") + end +end diff --git a/src/core_icnf.jl b/src/cores/core_icnf.jl similarity index 96% rename from src/core_icnf.jl rename to src/cores/core_icnf.jl index 8985069a..1c103a36 100644 --- a/src/core_icnf.jl +++ b/src/cores/core_icnf.jl @@ -1,240 +1,240 @@ -export ICNFModel, ICNFDist - -# MLJ interface - -mutable struct ICNFModel <: MLJICNF - m::AbstractICNF - loss::Function - - optimizers::AbstractVector - n_epochs::Int - adtype::ADTypes.AbstractADType - - use_batch::Bool - batch_size::Int - have_callback::Bool - - compute_mode::Type{<:ComputeMode} -end - -function ICNFModel( - m::AbstractICNF{<:AbstractFloat, CM}, - loss::Function = loss; - optimizers::AbstractVector = Any[Optimisers.Lion(),], - n_epochs::Int = 300, - adtype::ADTypes.AbstractADType = AutoZygote(), - use_batch::Bool = true, - batch_size::Int = 32, - have_callback::Bool = true, -) where {CM <: ComputeMode} - ICNFModel( - m, - loss, - optimizers, - n_epochs, - adtype, - use_batch, - batch_size, - have_callback, - CM, - ) -end - -function MLJModelInterface.fit(model::ICNFModel, verbosity, X) - x = collect(transpose(MLJModelInterface.matrix(X))) - ps, st = LuxCore.setup(model.m.rng, model.m) - ps = ComponentArray(ps) - if model.m.resource isa CUDALibs - gdev = gpu_device() - x = gdev(x) - ps = gdev(ps) - st = gdev(st) - end - optfunc = OptimizationFunction( - let mm = model.m, md = TrainMode(), st = st - (u, p, xs_) -> model.loss(mm, md, xs_, u, st) - end, - model.adtype, - ) - optprob = OptimizationProblem(optfunc, ps) - - tst_overall = @timed for opt in model.optimizers - tst_epochs = @timed for ep in 1:(model.n_epochs) - if model.use_batch - if model.compute_mode <: VectorMode - data = DataLoader( - (x,); - batchsize = -1, - shuffle = true, - partial = true, - parallel = false, - buffer = false, - ) - elseif model.compute_mode <: MatrixMode - data = DataLoader( - (x,); - batchsize = model.batch_size, - shuffle = true, - partial = true, - parallel = false, - buffer = false, - ) - else - error("Not Implemented") - end - else - data = [(x,)] - end - optprob_re = remake(optprob; u0 = ps) - if model.have_callback - prgr = Progress( - length(data); - desc = "Fitting (epoch: $ep of $(model.n_epochs)): ", - showspeed = true, - ) - 
itr_n = ones(Int) - tst_one = @timed res = solve( - optprob_re, - opt, - data; - callback = let mm = model.m, prgr = prgr, itr_n = itr_n - (ps_, l_) -> callback_f(ps_, l_, mm, prgr, itr_n) - end, - ) - ProgressMeter.finish!(prgr) - else - tst_one = @timed res = solve(optprob_re, opt, data) - end - ps .= res.u - @info( - "Fitting (epoch: $ep of $(model.n_epochs)) - $(typeof(opt).name.name)", - "elapsed time (seconds)" = tst_one.time, - "garbage collection time (seconds)" = tst_one.gctime, - ) - end - @info( - "Fitting (all epochs) - $(typeof(opt).name.name)", - "elapsed time (seconds)" = tst_epochs.time, - "garbage collection time (seconds)" = tst_epochs.gctime, - ) - end - @info( - "Fitting - Overall", - "elapsed time (seconds)" = tst_overall.time, - "garbage collection time (seconds)" = tst_overall.gctime, - ) - - fitresult = (ps, st) - cache = nothing - report = (stats = tst_overall,) - fitresult, cache, report -end - -function MLJModelInterface.transform(model::ICNFModel, fitresult, Xnew) - xnew = collect(transpose(MLJModelInterface.matrix(Xnew))) - if model.m.resource isa CUDALibs - gdev = gpu_device() - xnew = gdev(xnew) - end - (ps, st) = fitresult - - if model.compute_mode <: VectorMode - tst = @timed logp̂x = broadcast( - let mm = model.m, md = TestMode(), ps = ps, st = st - x -> first(inference(model.m, TestMode(), x, ps, st)) - end, - eachcol(xnew), - ) - elseif model.compute_mode <: MatrixMode - tst = @timed logp̂x = first(inference(model.m, TestMode(), xnew, ps, st)) - else - error("Not Implemented") - end - - @info( - "Transforming", - "elapsed time (seconds)" = tst.time, - "garbage collection time (seconds)" = tst.gctime, - ) - - DataFrame(; px = exp.(logp̂x)) -end - -MLJBase.metadata_pkg( - ICNFModel; - package_name = "ContinuousNormalizingFlows", - package_uuid = "00b1973d-5b2e-40bf-8604-5c9c1d8f50ac", - package_url = "https://github.com/impICNF/ContinuousNormalizingFlows.jl", - is_pure_julia = true, - package_license = "MIT", - is_wrapper = false, -) -MLJBase.metadata_model( - ICNFModel; - input_scitype = Table{AbstractVector{ScientificTypes.Continuous}}, - target_scitype = Table{AbstractVector{ScientificTypes.Continuous}}, - output_scitype = Table{AbstractVector{ScientificTypes.Continuous}}, - supports_weights = false, - load_path = "ContinuousNormalizingFlows.ICNFModel", -) - -# Distributions interface - -struct ICNFDist <: ICNFDistribution - m::AbstractICNF - mode::Mode - ps::Any - st::Any -end - -function ICNFDist(mach::Machine{<:ICNFModel}, mode::Mode) - (ps, st) = fitted_params(mach) - ICNFDist(mach.model.m, mode, ps, st) -end - -function Base.length(d::ICNFDist) - d.m.nvars -end -function Base.eltype(d::ICNFDist) - first(typeof(d.m).parameters) -end -function Distributions._logpdf(d::ICNFDist, x::AbstractVector{<:Real}) - if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} - first(inference(d.m, d.mode, x, d.ps, d.st)) - elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} - first(Distributions._logpdf(d, hcat(x))) - else - error("Not Implemented") - end -end -function Distributions._logpdf(d::ICNFDist, A::AbstractMatrix{<:Real}) - if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} - broadcast(let d = d - x -> Distributions._logpdf(d, x) - end, eachcol(A)) - elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} - first(inference(d.m, d.mode, A, d.ps, d.st)) - else - error("Not Implemented") - end -end -function Distributions._rand!(rng::AbstractRNG, d::ICNFDist, x::AbstractVector{<:Real}) - if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} - x 
.= generate(d.m, d.mode, d.ps, d.st) - elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} - x .= Distributions._rand!(rng, d, hcat(x)) - else - error("Not Implemented") - end -end -function Distributions._rand!(rng::AbstractRNG, d::ICNFDist, A::AbstractMatrix{<:Real}) - if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} - A .= hcat(broadcast(let rng = rng, d = d - x -> Distributions._rand!(rng, d, x) - end, eachcol(A))...) - elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} - A .= generate(d.m, d.mode, d.ps, d.st, size(A, 2)) - else - error("Not Implemented") - end -end +export ICNFModel, ICNFDist + +# MLJ interface + +mutable struct ICNFModel <: MLJICNF + m::AbstractICNF + loss::Function + + optimizers::AbstractVector + n_epochs::Int + adtype::ADTypes.AbstractADType + + use_batch::Bool + batch_size::Int + have_callback::Bool + + compute_mode::Type{<:ComputeMode} +end + +function ICNFModel( + m::AbstractICNF{<:AbstractFloat, CM}, + loss::Function = loss; + optimizers::AbstractVector = Any[Optimisers.Lion(),], + n_epochs::Int = 300, + adtype::ADTypes.AbstractADType = AutoZygote(), + use_batch::Bool = true, + batch_size::Int = 32, + have_callback::Bool = true, +) where {CM <: ComputeMode} + ICNFModel( + m, + loss, + optimizers, + n_epochs, + adtype, + use_batch, + batch_size, + have_callback, + CM, + ) +end + +function MLJModelInterface.fit(model::ICNFModel, verbosity, X) + x = collect(transpose(MLJModelInterface.matrix(X))) + ps, st = LuxCore.setup(model.m.rng, model.m) + ps = ComponentArray(ps) + if model.m.resource isa CUDALibs + gdev = gpu_device() + x = gdev(x) + ps = gdev(ps) + st = gdev(st) + end + optfunc = OptimizationFunction( + let mm = model.m, md = TrainMode(), st = st + (u, p, xs_) -> model.loss(mm, md, xs_, u, st) + end, + model.adtype, + ) + optprob = OptimizationProblem(optfunc, ps) + + tst_overall = @timed for opt in model.optimizers + tst_epochs = @timed for ep in 1:(model.n_epochs) + if model.use_batch + if model.compute_mode <: VectorMode + data = DataLoader( + (x,); + batchsize = -1, + shuffle = true, + partial = true, + parallel = false, + buffer = false, + ) + elseif model.compute_mode <: MatrixMode + data = DataLoader( + (x,); + batchsize = model.batch_size, + shuffle = true, + partial = true, + parallel = false, + buffer = false, + ) + else + error("Not Implemented") + end + else + data = [(x,)] + end + optprob_re = remake(optprob; u0 = ps) + if model.have_callback + prgr = Progress( + length(data); + desc = "Fitting (epoch: $ep of $(model.n_epochs)): ", + showspeed = true, + ) + itr_n = ones(Int) + tst_one = @timed res = solve( + optprob_re, + opt, + data; + callback = let mm = model.m, prgr = prgr, itr_n = itr_n + (ps_, l_) -> callback_f(ps_, l_, mm, prgr, itr_n) + end, + ) + ProgressMeter.finish!(prgr) + else + tst_one = @timed res = solve(optprob_re, opt, data) + end + ps .= res.u + @info( + "Fitting (epoch: $ep of $(model.n_epochs)) - $(typeof(opt).name.name)", + "elapsed time (seconds)" = tst_one.time, + "garbage collection time (seconds)" = tst_one.gctime, + ) + end + @info( + "Fitting (all epochs) - $(typeof(opt).name.name)", + "elapsed time (seconds)" = tst_epochs.time, + "garbage collection time (seconds)" = tst_epochs.gctime, + ) + end + @info( + "Fitting - Overall", + "elapsed time (seconds)" = tst_overall.time, + "garbage collection time (seconds)" = tst_overall.gctime, + ) + + fitresult = (ps, st) + cache = nothing + report = (stats = tst_overall,) + fitresult, cache, report +end + +function 
MLJModelInterface.transform(model::ICNFModel, fitresult, Xnew) + xnew = collect(transpose(MLJModelInterface.matrix(Xnew))) + if model.m.resource isa CUDALibs + gdev = gpu_device() + xnew = gdev(xnew) + end + (ps, st) = fitresult + + if model.compute_mode <: VectorMode + tst = @timed logp̂x = broadcast( + let mm = model.m, md = TestMode(), ps = ps, st = st + x -> first(inference(model.m, TestMode(), x, ps, st)) + end, + eachcol(xnew), + ) + elseif model.compute_mode <: MatrixMode + tst = @timed logp̂x = first(inference(model.m, TestMode(), xnew, ps, st)) + else + error("Not Implemented") + end + + @info( + "Transforming", + "elapsed time (seconds)" = tst.time, + "garbage collection time (seconds)" = tst.gctime, + ) + + DataFrame(; px = exp.(logp̂x)) +end + +MLJBase.metadata_pkg( + ICNFModel; + package_name = "ContinuousNormalizingFlows", + package_uuid = "00b1973d-5b2e-40bf-8604-5c9c1d8f50ac", + package_url = "https://github.com/impICNF/ContinuousNormalizingFlows.jl", + is_pure_julia = true, + package_license = "MIT", + is_wrapper = false, +) +MLJBase.metadata_model( + ICNFModel; + input_scitype = Table{AbstractVector{ScientificTypes.Continuous}}, + target_scitype = Table{AbstractVector{ScientificTypes.Continuous}}, + output_scitype = Table{AbstractVector{ScientificTypes.Continuous}}, + supports_weights = false, + load_path = "ContinuousNormalizingFlows.ICNFModel", +) + +# Distributions interface + +struct ICNFDist <: ICNFDistribution + m::AbstractICNF + mode::Mode + ps::Any + st::NamedTuple +end + +function ICNFDist(mach::Machine{<:ICNFModel}, mode::Mode) + (ps, st) = fitted_params(mach) + ICNFDist(mach.model.m, mode, ps, st) +end + +function Base.length(d::ICNFDist) + d.m.nvars +end +function Base.eltype(d::ICNFDist) + first(typeof(d.m).parameters) +end +function Distributions._logpdf(d::ICNFDist, x::AbstractVector{<:Real}) + if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} + first(inference(d.m, d.mode, x, d.ps, d.st)) + elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} + first(Distributions._logpdf(d, hcat(x))) + else + error("Not Implemented") + end +end +function Distributions._logpdf(d::ICNFDist, A::AbstractMatrix{<:Real}) + if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} + broadcast(let d = d + x -> Distributions._logpdf(d, x) + end, eachcol(A)) + elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} + first(inference(d.m, d.mode, A, d.ps, d.st)) + else + error("Not Implemented") + end +end +function Distributions._rand!(rng::AbstractRNG, d::ICNFDist, x::AbstractVector{<:Real}) + if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} + x .= generate(d.m, d.mode, d.ps, d.st) + elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} + x .= Distributions._rand!(rng, d, hcat(x)) + else + error("Not Implemented") + end +end +function Distributions._rand!(rng::AbstractRNG, d::ICNFDist, A::AbstractMatrix{<:Real}) + if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} + A .= hcat(broadcast(let rng = rng, d = d + x -> Distributions._rand!(rng, d, x) + end, eachcol(A))...) 
+ elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} + A .= generate(d.m, d.mode, d.ps, d.st, size(A, 2)) + else + error("Not Implemented") + end +end diff --git a/src/icnf.jl b/src/icnf.jl index 11618645..36603810 100644 --- a/src/icnf.jl +++ b/src/icnf.jl @@ -8,7 +8,7 @@ struct Planar{ AUGMENTED, STEER, NORM_Z_AUG, -} <: AbstractFlows{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end +} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end struct CondPlanar{ T <: AbstractFloat, CM <: ComputeMode, @@ -17,7 +17,7 @@ struct CondPlanar{ AUGMENTED, STEER, NORM_Z_AUG, -} <: AbstractFlows{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end +} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end struct FFJORD{ T <: AbstractFloat, @@ -27,7 +27,7 @@ struct FFJORD{ AUGMENTED, STEER, NORM_Z_AUG, -} <: AbstractFlows{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end +} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end struct CondFFJORD{ T <: AbstractFloat, CM <: ComputeMode, @@ -36,7 +36,7 @@ struct CondFFJORD{ AUGMENTED, STEER, NORM_Z_AUG, -} <: AbstractFlows{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end +} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end struct RNODE{ T <: AbstractFloat, @@ -46,7 +46,7 @@ struct RNODE{ AUGMENTED, STEER, NORM_Z_AUG, -} <: AbstractFlows{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end +} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end struct CondRNODE{ T <: AbstractFloat, CM <: ComputeMode, @@ -55,7 +55,7 @@ struct CondRNODE{ AUGMENTED, STEER, NORM_Z_AUG, -} <: AbstractFlows{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end +} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end """ Implementation of ICNF. 
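
With construct now exported from src/base_icnf.jl and the flow structs renamed to AbstractICNF subtypes, a self-contained usage sketch of the construction and inference entry points shown in this patch. The two-layer Lux chain and the sample point are arbitrary stand-ins, and the defaults noted in comments follow the keyword defaults of construct:

    using ContinuousNormalizingFlows, Lux, ComponentArrays

    nvars = 2
    nn = Chain(Dense(nvars => 4 * nvars, tanh), Dense(4 * nvars => nvars, tanh))
    icnf = construct(RNODE, nn, nvars)  # defaults: Float32 data, ADVecJacVectorMode
    ps, st = Lux.setup(icnf.rng, icnf)
    ps = ComponentArray(ps)             # flat parameter vector, as in the MLJ fit path
    x = rand(Float32, nvars)
    logp̂x, _ = inference(icnf, TestMode(), x, ps, st)  # log-density estimate at x
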
@@ -633,7 +633,7 @@ end mode::TrainMode, xs::AbstractVector{<:Real}, ps::Any, - st::Any, + st::NamedTuple, ) logp̂x, (Ė, ṅ, Ȧ) = inference(icnf, mode, xs, ps, st) -logp̂x + icnf.λ₁ * Ė + icnf.λ₂ * ṅ + icnf.λ₃ * Ȧ @@ -645,7 +645,7 @@ end xs::AbstractVector{<:Real}, ys::AbstractVector{<:Real}, ps::Any, - st::Any, + st::NamedTuple, ) logp̂x, (Ė, ṅ, Ȧ) = inference(icnf, mode, xs, ys, ps, st) -logp̂x + icnf.λ₁ * Ė + icnf.λ₂ * ṅ + icnf.λ₃ * Ȧ @@ -656,7 +656,7 @@ end mode::TrainMode, xs::AbstractMatrix{<:Real}, ps::Any, - st::Any, + st::NamedTuple, ) logp̂x, (Ė, ṅ, Ȧ) = inference(icnf, mode, xs, ps, st) mean(-logp̂x + icnf.λ₁ * Ė + icnf.λ₂ * ṅ + icnf.λ₃ * Ȧ) @@ -668,7 +668,7 @@ end xs::AbstractMatrix{<:Real}, ys::AbstractMatrix{<:Real}, ps::Any, - st::Any, + st::NamedTuple, ) logp̂x, (Ė, ṅ, Ȧ) = inference(icnf, mode, xs, ys, ps, st) mean(-logp̂x + icnf.λ₁ * Ė + icnf.λ₂ * ṅ + icnf.λ₃ * Ȧ) diff --git a/src/cond_layer.jl b/src/layers/cond_layer.jl similarity index 68% rename from src/cond_layer.jl rename to src/layers/cond_layer.jl index ff6a2bb6..fabaaf2b 100644 --- a/src/cond_layer.jl +++ b/src/layers/cond_layer.jl @@ -1,9 +1,11 @@ -struct CondLayer{NN <: LuxCore.AbstractExplicitLayer, AT <: AbstractArray} <: - LuxCore.AbstractExplicitContainerLayer{(:nn,)} - nn::NN - ys::AT -end - -@inline function (m::CondLayer)(z::AbstractArray, ps::Any, st::Any) - LuxCore.apply(m.nn, vcat(z, m.ys), ps, st) -end +export CondLayer + +struct CondLayer{NN <: LuxCore.AbstractExplicitLayer, AT <: AbstractArray} <: + LuxCore.AbstractExplicitContainerLayer{(:nn,)} + nn::NN + ys::AT +end + +@inline function (m::CondLayer)(z::AbstractVecOrMat, ps::Any, st::NamedTuple) + LuxCore.apply(m.nn, vcat(z, m.ys), ps, st) +end diff --git a/src/layers/mul_layer.jl b/src/layers/mul_layer.jl new file mode 100644 index 00000000..f2939ae9 --- /dev/null +++ b/src/layers/mul_layer.jl @@ -0,0 +1,37 @@ +export MulLayer + +struct MulLayer{F1, F2, NVARS <: Int} <: LuxCore.AbstractExplicitLayer + activation::F1 + nvars::NVARS + init_weight::F2 +end + +function MulLayer( + nvars::Int, + activation::Any = identity; + init_weight::Any = Lux.glorot_uniform, + allow_fast_activation::Bool = true, +) + activation = ifelse(allow_fast_activation, NNlib.fast_act(activation), activation) + MulLayer{typeof(activation), typeof(init_weight), typeof(nvars)}( + activation, + nvars, + init_weight, + ) +end + +function LuxCore.initialparameters(rng::AbstractRNG, m::MulLayer) + (weight = m.init_weight(rng, m.nvars, m.nvars),) +end + +function LuxCore.parameterlength(m::MulLayer) + m.nvars * m.nvars +end + +function LuxCore.outputsize(m::MulLayer) + (m.nvars,) +end + +@inline function (m::MulLayer)(x::AbstractVecOrMat, ps::Any, st::NamedTuple) + return Lux.__apply_activation(m.activation, Octavian.matmul(ps.weight, x)), st +end diff --git a/src/layers/planar_layer.jl b/src/layers/planar_layer.jl new file mode 100644 index 00000000..c1dc3610 --- /dev/null +++ b/src/layers/planar_layer.jl @@ -0,0 +1,105 @@ +export PlanarLayer + +""" +Implementation of Planar Layer from + +[Chen, Ricky TQ, Yulia Rubanova, Jesse Bettencourt, and David Duvenaud. "Neural Ordinary Differential Equations." 
arXiv preprint arXiv:1806.07366 (2018).](https://arxiv.org/abs/1806.07366) +""" +struct PlanarLayer{use_bias, cond, F1, F2, F3, NVARS <: Int} <: + LuxCore.AbstractExplicitLayer + activation::F1 + nvars::NVARS + init_weight::F2 + init_bias::F3 + n_cond::NVARS +end + +function PlanarLayer( + nvars::Int, + activation::Any = identity; + init_weight::Any = Lux.glorot_uniform, + init_bias::Any = Lux.zeros32, + use_bias::Bool = true, + allow_fast_activation::Bool = true, + n_cond::Int = 0, +) + activation = ifelse(allow_fast_activation, NNlib.fast_act(activation), activation) + PlanarLayer{ + use_bias, + !iszero(n_cond), + typeof(activation), + typeof(init_weight), + typeof(init_bias), + typeof(nvars), + }( + activation, + nvars, + init_weight, + init_bias, + n_cond, + ) +end + +function LuxCore.initialparameters( + rng::AbstractRNG, + layer::PlanarLayer{use_bias, cond}, +) where {use_bias, cond} + ifelse( + use_bias, + ( + u = layer.init_weight(rng, layer.nvars), + w = layer.init_weight( + rng, + ifelse(cond, (layer.nvars + layer.n_cond), layer.nvars), + ), + b = layer.init_bias(rng, 1), + ), + ( + u = layer.init_weight(rng, layer.nvars), + w = layer.init_weight( + rng, + ifelse(cond, (layer.nvars + layer.n_cond), layer.nvars), + ), + ), + ) +end + +function LuxCore.parameterlength(m::PlanarLayer{use_bias, cond}) where {use_bias, cond} + m.nvars + ifelse(cond, (m.nvars + m.n_cond), m.nvars) + ifelse(use_bias, 1, 0) +end + +function LuxCore.outputsize(m::PlanarLayer) + (m.nvars,) +end + +@inline function (m::PlanarLayer{true})(z::AbstractVector, ps::Any, st::NamedTuple) + ps.u * Lux.__apply_activation(m.activation, (ps.w ⋅ z) + only(ps.b)), st +end + +@inline function (m::PlanarLayer{true})(z::AbstractMatrix, ps::Any, st::NamedTuple) + ps.u * Lux.__apply_activation(m.activation, muladd(transpose(ps.w), z, only(ps.b))), st +end + +@inline function (m::PlanarLayer{false})(z::AbstractVector, ps::Any, st::NamedTuple) + ps.u * Lux.__apply_activation(m.activation, ps.w ⋅ z), st +end + +@inline function (m::PlanarLayer{false})(z::AbstractMatrix, ps::Any, st::NamedTuple) + ps.u * Lux.__apply_activation(m.activation, transpose(ps.w) * z), st +end + +@inline function pl_h(m::PlanarLayer{true}, z::AbstractVector, ps::Any, st::NamedTuple) + Lux.__apply_activation(m.activation, (ps.w ⋅ z) + only(ps.b)), st +end + +@inline function pl_h(m::PlanarLayer{true}, z::AbstractMatrix, ps::Any, st::NamedTuple) + Lux.__apply_activation(m.activation, muladd(transpose(ps.w), z, only(ps.b))), st +end + +@inline function pl_h(m::PlanarLayer{false}, z::AbstractVector, ps::Any, st::NamedTuple) + Lux.__apply_activation(m.activation, ps.w ⋅ z), st +end + +@inline function pl_h(m::PlanarLayer{false}, z::AbstractMatrix, ps::Any, st::NamedTuple) + Lux.__apply_activation(m.activation, transpose(ps.w) * z), st +end diff --git a/src/planar_layer.jl b/src/planar_layer.jl deleted file mode 100644 index 2bc42b9e..00000000 --- a/src/planar_layer.jl +++ /dev/null @@ -1,163 +0,0 @@ -export PlanarLayer - -""" -Implementation of Planar Layer from - -[Chen, Ricky TQ, Yulia Rubanova, Jesse Bettencourt, and David Duvenaud. "Neural Ordinary Differential Equations." 
arXiv preprint arXiv:1806.07366 (2018).](https://arxiv.org/abs/1806.07366) -""" -struct PlanarLayer{use_bias, cond, F1, F2, F3} <: LuxCore.AbstractExplicitLayer - activation::F1 - nvars::Int - init_weight::F2 - init_bias::F3 - n_cond::Int -end - -function PlanarLayer( - nvars::Int, - activation = identity; - init_weight = Lux.glorot_uniform, - init_bias = Lux.zeros32, - use_bias::Bool = true, - allow_fast_activation::Bool = true, - n_cond::Int = 0, -) - activation = ifelse(allow_fast_activation, NNlib.fast_act(activation), activation) - PlanarLayer{ - use_bias, - !iszero(n_cond), - typeof(activation), - typeof(init_weight), - typeof(init_bias), - }( - activation, - nvars, - init_weight, - init_bias, - n_cond, - ) -end - -function LuxCore.initialparameters( - rng::AbstractRNG, - layer::PlanarLayer{use_bias, cond}, -) where {use_bias, cond} - ifelse( - use_bias, - ( - u = layer.init_weight(rng, layer.nvars), - w = layer.init_weight( - rng, - ifelse(cond, (layer.nvars + layer.n_cond), layer.nvars), - ), - b = layer.init_bias(rng, 1), - ), - ( - u = layer.init_weight(rng, layer.nvars), - w = layer.init_weight( - rng, - ifelse(cond, (layer.nvars + layer.n_cond), layer.nvars), - ), - ), - ) -end - -@inline function (m::PlanarLayer{true})(z::AbstractVector, ps::Any, st::Any) - ps.u * m.activation((ps.w ⋅ z) + only(ps.b)), st -end - -@inline function (m::PlanarLayer{true})(z::AbstractMatrix, ps::Any, st::Any) - ps.u * m.activation.(muladd(transpose(ps.w), z, only(ps.b))), st -end - -@inline function (::PlanarLayer{true, cond, typeof(identity)})( - z::AbstractVector, - ps::Any, - st::Any, -) where {cond} - ps.u * ((ps.w ⋅ z) + only(ps.b)), st -end - -@inline function (::PlanarLayer{true, cond, typeof(identity)})( - z::AbstractMatrix, - ps::Any, - st::Any, -) where {cond} - ps.u * muladd(transpose(ps.w), z, only(ps.b)), st -end - -@inline function (m::PlanarLayer{false})(z::AbstractVector, ps::Any, st::Any) - ps.u * m.activation(ps.w ⋅ z), st -end - -@inline function (m::PlanarLayer{false})(z::AbstractMatrix, ps::Any, st::Any) - ps.u * m.activation.(transpose(ps.w) * z), st -end - -@inline function (::PlanarLayer{false, cond, typeof(identity)})( - z::AbstractVector, - ps::Any, - st::Any, -) where {cond} - ps.u * (ps.w ⋅ z), st -end - -@inline function (::PlanarLayer{false, cond, typeof(identity)})( - z::AbstractMatrix, - ps::Any, - st::Any, -) where {cond} - ps.u * (transpose(ps.w) * z), st -end - -@inline function pl_h(m::PlanarLayer{true}, z::AbstractVector, ps::Any, st::Any) - m.activation((ps.w ⋅ z) + only(ps.b)), st -end - -@inline function pl_h(m::PlanarLayer{true}, z::AbstractMatrix, ps::Any, st::Any) - m.activation.(muladd(transpose(ps.w), z, only(ps.b))), st -end - -@inline function pl_h( - ::PlanarLayer{true, cond, typeof(identity)}, - z::AbstractVector, - ps::Any, - st::Any, -) where {cond} - (ps.w ⋅ z) + only(ps.b), st -end - -@inline function pl_h( - ::PlanarLayer{true, cond, typeof(identity)}, - z::AbstractMatrix, - ps::Any, - st::Any, -) where {cond} - muladd(transpose(ps.w), z, only(ps.b)), st -end - -@inline function pl_h(m::PlanarLayer{false}, z::AbstractVector, ps::Any, st::Any) - m.activation(ps.w ⋅ z), st -end - -@inline function pl_h(m::PlanarLayer{false}, z::AbstractMatrix, ps::Any, st::Any) - m.activation.(transpose(ps.w) * z), st -end - -@inline function pl_h( - ::PlanarLayer{false, cond, typeof(identity)}, - z::AbstractVector, - ps::Any, - st::Any, -) where {cond} - (ps.w ⋅ z), st -end - -@inline function pl_h( - ::PlanarLayer{false, cond, typeof(identity)}, - 
z::AbstractMatrix, - ps::Any, - st::Any, -) where {cond} - (transpose(ps.w) * z), st -end diff --git a/src/types.jl b/src/types.jl index 57090a27..851a46d4 100644 --- a/src/types.jl +++ b/src/types.jl @@ -26,15 +26,6 @@ struct SDJacVecMatrixMode <: SDMatrixMode end struct ZygoteVectorMode <: VectorMode end struct ZygoteMatrixMode <: MatrixMode end -abstract type AbstractFlows{ - T <: AbstractFloat, - CM <: ComputeMode, - INPLACE, - COND, - AUGMENTED, - STEER, - NORM_Z_AUG, -} <: LuxCore.AbstractExplicitContainerLayer{(:nn,)} end abstract type AbstractICNF{ T <: AbstractFloat, CM <: ComputeMode, @@ -43,7 +34,7 @@ abstract type AbstractICNF{ AUGMENTED, STEER, NORM_Z_AUG, -} <: AbstractFlows{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end +} <: LuxCore.AbstractExplicitContainerLayer{(:nn,)} end # MLJ interface diff --git a/src/utils.jl b/src/utils.jl index af92b740..3acb8915 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,5 +1,5 @@ function jacobian_batched( - icnf::AbstractFlows{T, <:SDVecJacMatrixMode}, + icnf::AbstractICNF{T, <:SDVecJacMatrixMode}, f, xs::AbstractMatrix{<:Real}, ) where {T <: AbstractFloat} @@ -17,7 +17,7 @@ function jacobian_batched( end function jacobian_batched( - icnf::AbstractFlows{T, <:SDJacVecMatrixMode}, + icnf::AbstractICNF{T, <:SDJacVecMatrixMode}, f, xs::AbstractMatrix{<:Real}, ) where {T <: AbstractFloat} @@ -35,7 +35,7 @@ function jacobian_batched( end function jacobian_batched( - ::AbstractFlows{T, <:ZygoteMatrixMode}, + ::AbstractICNF{T, <:ZygoteMatrixMode}, f, xs::AbstractMatrix{<:Real}, ) where {T <: AbstractFloat} diff --git a/test/call_tests.jl b/test/call_tests.jl index 9ebd7bbc..82b3710a 100644 --- a/test/call_tests.jl +++ b/test/call_tests.jl @@ -1,18 +1,18 @@ @testset "Call Tests" begin mts = if GROUP == "RNODE" - Type{<:ContinuousNormalizingFlows.AbstractFlows}[RNODE] + Type{<:ContinuousNormalizingFlows.AbstractICNF}[RNODE] elseif GROUP == "FFJORD" - Type{<:ContinuousNormalizingFlows.AbstractFlows}[FFJORD] + Type{<:ContinuousNormalizingFlows.AbstractICNF}[FFJORD] elseif GROUP == "Planar" - Type{<:ContinuousNormalizingFlows.AbstractFlows}[Planar] + Type{<:ContinuousNormalizingFlows.AbstractICNF}[Planar] elseif GROUP == "CondRNODE" - Type{<:ContinuousNormalizingFlows.AbstractFlows}[CondRNODE] + Type{<:ContinuousNormalizingFlows.AbstractICNF}[CondRNODE] elseif GROUP == "CondFFJORD" - Type{<:ContinuousNormalizingFlows.AbstractFlows}[CondFFJORD] + Type{<:ContinuousNormalizingFlows.AbstractICNF}[CondFFJORD] elseif GROUP == "CondPlanar" - Type{<:ContinuousNormalizingFlows.AbstractFlows}[CondPlanar] + Type{<:ContinuousNormalizingFlows.AbstractICNF}[CondPlanar] else - Type{<:ContinuousNormalizingFlows.AbstractFlows}[ + Type{<:ContinuousNormalizingFlows.AbstractICNF}[ RNODE, FFJORD, Planar, diff --git a/test/fit_tests.jl b/test/fit_tests.jl index a0a4d09a..a76f6a26 100644 --- a/test/fit_tests.jl +++ b/test/fit_tests.jl @@ -1,18 +1,18 @@ @testset "Fit Tests" begin mts = if GROUP == "RNODE" - Type{<:ContinuousNormalizingFlows.AbstractFlows}[RNODE] + Type{<:ContinuousNormalizingFlows.AbstractICNF}[RNODE] elseif GROUP == "FFJORD" - Type{<:ContinuousNormalizingFlows.AbstractFlows}[FFJORD] + Type{<:ContinuousNormalizingFlows.AbstractICNF}[FFJORD] elseif GROUP == "Planar" - Type{<:ContinuousNormalizingFlows.AbstractFlows}[Planar] + Type{<:ContinuousNormalizingFlows.AbstractICNF}[Planar] elseif GROUP == "CondRNODE" - Type{<:ContinuousNormalizingFlows.AbstractFlows}[CondRNODE] + Type{<:ContinuousNormalizingFlows.AbstractICNF}[CondRNODE] elseif GROUP == 
"CondFFJORD" - Type{<:ContinuousNormalizingFlows.AbstractFlows}[CondFFJORD] + Type{<:ContinuousNormalizingFlows.AbstractICNF}[CondFFJORD] elseif GROUP == "CondPlanar" - Type{<:ContinuousNormalizingFlows.AbstractFlows}[CondPlanar] + Type{<:ContinuousNormalizingFlows.AbstractICNF}[CondPlanar] else - Type{<:ContinuousNormalizingFlows.AbstractFlows}[ + Type{<:ContinuousNormalizingFlows.AbstractICNF}[ RNODE, FFJORD, Planar,