Skip to content

Commit

Permalink
Cleaning & add MulLayer (#390)
Browse files Browse the repository at this point in the history
* cleaning

* paramtype

* fix

* add export
  • Loading branch information
prbzrg authored Mar 27, 2024
1 parent f999611 commit d160318
Show file tree
Hide file tree
Showing 17 changed files with 958 additions and 984 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
Expand Down Expand Up @@ -59,6 +60,7 @@ MLJBase = "1"
MLJModelInterface = "1"
MLUtils = "0.4"
NNlib = "0.9"
Octavian = "0.3.27"
Optimisers = "0.3"
Optimization = "3.15"
OptimizationOptimisers = "0.2"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ end

@inline function ContinuousNormalizingFlows.base_AT(
::CUDALibs,
::ContinuousNormalizingFlows.AbstractFlows{T},
::ContinuousNormalizingFlows.AbstractICNF{T},
dims...,
) where {T <: AbstractFloat}
CuArray{T}(undef, dims...)
Expand Down
17 changes: 9 additions & 8 deletions src/ContinuousNormalizingFlows.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ using AbstractDifferentiation,
MLJModelInterface,
MLUtils,
NNlib,
Octavian,
Optimisers,
Optimization,
OptimizationOptimisers,
Expand All @@ -31,22 +32,22 @@ using AbstractDifferentiation,
Statistics,
Zygote

include("types.jl")
include(joinpath("layers", "cond_layer.jl"))
include(joinpath("layers", "planar_layer.jl"))
include(joinpath("layers", "mul_layer.jl"))

include("planar_layer.jl")
include("cond_layer.jl")
include("types.jl")

include("base.jl")
include("base_icnf.jl")

include("core.jl")
include("core_icnf.jl")
include("core_cond_icnf.jl")

include("icnf.jl")

include("utils.jl")

include(joinpath("cores", "core.jl"))
include(joinpath("cores", "core_icnf.jl"))
include(joinpath("cores", "core_cond_icnf.jl"))

"""
Implementations of Infinitesimal Continuous Normalizing Flows Algorithms in Julia
"""
Expand Down
226 changes: 0 additions & 226 deletions src/base.jl

This file was deleted.

Loading

0 comments on commit d160318

Please sign in to comment.