From 342d377ea4a5b0016044de66873bde6ee7b5dde8 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 2 Mar 2024 12:04:28 +0330 Subject: [PATCH 1/2] remove defaults --- Project.toml | 4 ++-- benchmark/benchmarks.jl | 2 -- src/ContinuousNormalizingFlows.jl | 3 +-- src/base.jl | 2 +- src/defaults.jl | 40 ------------------------------- test/call_tests.jl | 12 +--------- test/fit_tests.jl | 12 +--------- test/instability_tests.jl | 1 - 8 files changed, 6 insertions(+), 70 deletions(-) delete mode 100644 src/defaults.jl diff --git a/Project.toml b/Project.toml index 6e7b62f4..97ef43f9 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" +DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" @@ -24,7 +25,6 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" -OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" @@ -49,6 +49,7 @@ ComponentArrays = "0.15.5" ComputationalResources = "0.3" DataFrames = "1" Dates = "1" +DifferentialEquations = "7" Distributions = "0.25" DistributionsAD = "0.6" FillArrays = "1" @@ -62,7 +63,6 @@ NNlib = "0.9" Optimisers = "0.3" Optimization = "3.15" OptimizationOptimisers = "0.1, 0.2" -OrdinaryDiffEq = "6" ProgressMeter = "1" Random = "1" SciMLBase = "2" diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index b6c0ce79..581f65b3 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -29,7 +29,6 @@ icnf = construct( compute_mode = ZygoteMatrixMode, tspan = (0.0f0, 13.0f0), steer_rate = 0.1f0, - sol_kwargs = ContinuousNormalizingFlows.sol_kwargs_defaults.medium_noad, rng, ) ps, st = Lux.setup(icnf.rng, icnf) @@ -61,7 +60,6 @@ icnf2 = construct( compute_mode = ZygoteMatrixMode, tspan = (0.0f0, 13.0f0), steer_rate = 0.1f0, - sol_kwargs = ContinuousNormalizingFlows.sol_kwargs_defaults.medium_noad, rng, ) diff --git a/src/ContinuousNormalizingFlows.jl b/src/ContinuousNormalizingFlows.jl index 5cf890ad..9562cf17 100644 --- a/src/ContinuousNormalizingFlows.jl +++ b/src/ContinuousNormalizingFlows.jl @@ -7,6 +7,7 @@ using AbstractDifferentiation, ComponentArrays, ComputationalResources, DataFrames, + DifferentialEquations, Dates, Distributions, DistributionsAD, @@ -21,7 +22,6 @@ using AbstractDifferentiation, Optimisers, Optimization, OptimizationOptimisers, - OrdinaryDiffEq, ProgressMeter, Random, ScientificTypes, @@ -31,7 +31,6 @@ using AbstractDifferentiation, Statistics, Zygote -include("defaults.jl") include("types.jl") include("planar_layer.jl") diff --git a/src/base.jl b/src/base.jl index 2634e098..0d385b55 100644 --- a/src/base.jl +++ b/src/base.jl @@ -25,7 +25,7 @@ function construct( AutoForwardDiff(), AutoZygote(), ), - sol_kwargs::NamedTuple = sol_kwargs_defaults.medium, + sol_kwargs::NamedTuple = (save_everystep = false,), rng::AbstractRNG = rng_AT(resource), λ₁::AbstractFloat = convert(data_type, 1e-2), λ₂::AbstractFloat = convert(data_type, 1e-2), diff --git a/src/defaults.jl b/src/defaults.jl deleted file mode 100644 index 1a47e94e..00000000 --- a/src/defaults.jl +++ /dev/null @@ -1,40 +0,0 @@ -const sol_kwargs_defaults = ( - high = ( - alg = Tsit5(; thread = OrdinaryDiffEq.True()), - sensealg = QuadratureAdjoint(; autodiff = true, autojacvec = ZygoteVJP()), - reltol = eps(one(Float32)), - abstol = eps(one(Float32)) * eps(one(Float32)), - maxiters = typemax(Int32), - ), - medium = ( - save_everystep = false, - alg = Tsit5(; thread = OrdinaryDiffEq.True()), - sensealg = InterpolatingAdjoint(; - autodiff = true, - autojacvec = ZygoteVJP(), - checkpointing = true, - ), - reltol = sqrt(eps(one(Float32))), - abstol = eps(one(Float32)), - maxiters = typemax(Int32), - ), - medium_noad = ( - save_everystep = false, - alg = Tsit5(; thread = OrdinaryDiffEq.True()), - reltol = sqrt(eps(one(Float32))), - abstol = eps(one(Float32)), - maxiters = typemax(Int32), - ), - low = ( - save_everystep = false, - alg = Tsit5(; thread = OrdinaryDiffEq.True()), - sensealg = InterpolatingAdjoint(; - autodiff = true, - autojacvec = ZygoteVJP(), - checkpointing = true, - ), - reltol = sqrt(sqrt(eps(one(Float32)))), - abstol = sqrt(eps(one(Float32))), - maxiters = typemax(Int32), - ), -) diff --git a/test/call_tests.jl b/test/call_tests.jl index 1e987a6a..b65fcf55 100644 --- a/test/call_tests.jl +++ b/test/call_tests.jl @@ -106,18 +106,8 @@ inplace, resource, steer_rate = convert(data_type, 0.1), - sol_kwargs = ContinuousNormalizingFlows.sol_kwargs_defaults.medium_noad, - ), - construct( - mt, - nn, - nvars; - data_type, - compute_mode, - inplace, - resource, - sol_kwargs = ContinuousNormalizingFlows.sol_kwargs_defaults.medium_noad, ), + construct(mt, nn, nvars; data_type, compute_mode, inplace, resource), ) ps, st = Lux.setup(icnf.rng, icnf) ps = ComponentArrays.ComponentArray(ps) diff --git a/test/fit_tests.jl b/test/fit_tests.jl index 9eeed02a..788b55cb 100644 --- a/test/fit_tests.jl +++ b/test/fit_tests.jl @@ -103,18 +103,8 @@ inplace, resource, steer_rate = convert(data_type, 0.1), - sol_kwargs = ContinuousNormalizingFlows.sol_kwargs_defaults.medium_noad, - ), - construct( - mt, - nn, - nvars; - data_type, - compute_mode, - inplace, - resource, - sol_kwargs = ContinuousNormalizingFlows.sol_kwargs_defaults.medium_noad, ), + construct(mt, nn, nvars; data_type, compute_mode, inplace, resource), ) if mt <: ContinuousNormalizingFlows.AbstractCondICNF model = CondICNFModel(icnf; n_epochs, adtype) diff --git a/test/instability_tests.jl b/test/instability_tests.jl index bdf5c54f..1865ec21 100644 --- a/test/instability_tests.jl +++ b/test/instability_tests.jl @@ -19,7 +19,6 @@ compute_mode = ZygoteMatrixMode, tspan = (0.0f0, 13.0f0), steer_rate = 0.1f0, - sol_kwargs = ContinuousNormalizingFlows.sol_kwargs_defaults.medium_noad, ) ps, st = Lux.setup(icnf.rng, icnf) ps = ComponentArrays.ComponentArray(ps) From a67cad63927cc8eb7d64ed99872bb890489c944e Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sun, 3 Mar 2024 21:50:38 +0330 Subject: [PATCH 2/2] keep alg --- Project.toml | 4 ++-- src/ContinuousNormalizingFlows.jl | 2 +- src/base.jl | 5 ++++- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 97ef43f9..6e7b62f4 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,6 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" -DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" @@ -25,6 +24,7 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" +OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" @@ -49,7 +49,6 @@ ComponentArrays = "0.15.5" ComputationalResources = "0.3" DataFrames = "1" Dates = "1" -DifferentialEquations = "7" Distributions = "0.25" DistributionsAD = "0.6" FillArrays = "1" @@ -63,6 +62,7 @@ NNlib = "0.9" Optimisers = "0.3" Optimization = "3.15" OptimizationOptimisers = "0.1, 0.2" +OrdinaryDiffEq = "6" ProgressMeter = "1" Random = "1" SciMLBase = "2" diff --git a/src/ContinuousNormalizingFlows.jl b/src/ContinuousNormalizingFlows.jl index 9562cf17..343415b2 100644 --- a/src/ContinuousNormalizingFlows.jl +++ b/src/ContinuousNormalizingFlows.jl @@ -7,7 +7,6 @@ using AbstractDifferentiation, ComponentArrays, ComputationalResources, DataFrames, - DifferentialEquations, Dates, Distributions, DistributionsAD, @@ -22,6 +21,7 @@ using AbstractDifferentiation, Optimisers, Optimization, OptimizationOptimisers, + OrdinaryDiffEq, ProgressMeter, Random, ScientificTypes, diff --git a/src/base.jl b/src/base.jl index 0d385b55..8f315a74 100644 --- a/src/base.jl +++ b/src/base.jl @@ -25,7 +25,10 @@ function construct( AutoForwardDiff(), AutoZygote(), ), - sol_kwargs::NamedTuple = (save_everystep = false,), + sol_kwargs::NamedTuple = ( + save_everystep = false, + alg = Tsit5(; thread = OrdinaryDiffEq.True()), + ), rng::AbstractRNG = rng_AT(resource), λ₁::AbstractFloat = convert(data_type, 1e-2), λ₂::AbstractFloat = convert(data_type, 1e-2),