diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index b6c0ce79..581f65b3 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -29,7 +29,6 @@ icnf = construct( compute_mode = ZygoteMatrixMode, tspan = (0.0f0, 13.0f0), steer_rate = 0.1f0, - sol_kwargs = ContinuousNormalizingFlows.sol_kwargs_defaults.medium_noad, rng, ) ps, st = Lux.setup(icnf.rng, icnf) @@ -61,7 +60,6 @@ icnf2 = construct( compute_mode = ZygoteMatrixMode, tspan = (0.0f0, 13.0f0), steer_rate = 0.1f0, - sol_kwargs = ContinuousNormalizingFlows.sol_kwargs_defaults.medium_noad, rng, ) diff --git a/src/ContinuousNormalizingFlows.jl b/src/ContinuousNormalizingFlows.jl index 5cf890ad..343415b2 100644 --- a/src/ContinuousNormalizingFlows.jl +++ b/src/ContinuousNormalizingFlows.jl @@ -31,7 +31,6 @@ using AbstractDifferentiation, Statistics, Zygote -include("defaults.jl") include("types.jl") include("planar_layer.jl") diff --git a/src/base.jl b/src/base.jl index 2634e098..8f315a74 100644 --- a/src/base.jl +++ b/src/base.jl @@ -25,7 +25,10 @@ function construct( AutoForwardDiff(), AutoZygote(), ), - sol_kwargs::NamedTuple = sol_kwargs_defaults.medium, + sol_kwargs::NamedTuple = ( + save_everystep = false, + alg = Tsit5(; thread = OrdinaryDiffEq.True()), + ), rng::AbstractRNG = rng_AT(resource), λ₁::AbstractFloat = convert(data_type, 1e-2), λ₂::AbstractFloat = convert(data_type, 1e-2), diff --git a/src/defaults.jl b/src/defaults.jl deleted file mode 100644 index 1a47e94e..00000000 --- a/src/defaults.jl +++ /dev/null @@ -1,40 +0,0 @@ -const sol_kwargs_defaults = ( - high = ( - alg = Tsit5(; thread = OrdinaryDiffEq.True()), - sensealg = QuadratureAdjoint(; autodiff = true, autojacvec = ZygoteVJP()), - reltol = eps(one(Float32)), - abstol = eps(one(Float32)) * eps(one(Float32)), - maxiters = typemax(Int32), - ), - medium = ( - save_everystep = false, - alg = Tsit5(; thread = OrdinaryDiffEq.True()), - sensealg = InterpolatingAdjoint(; - autodiff = true, - autojacvec = ZygoteVJP(), - checkpointing = true, - ), - reltol = sqrt(eps(one(Float32))), - abstol = eps(one(Float32)), - maxiters = typemax(Int32), - ), - medium_noad = ( - save_everystep = false, - alg = Tsit5(; thread = OrdinaryDiffEq.True()), - reltol = sqrt(eps(one(Float32))), - abstol = eps(one(Float32)), - maxiters = typemax(Int32), - ), - low = ( - save_everystep = false, - alg = Tsit5(; thread = OrdinaryDiffEq.True()), - sensealg = InterpolatingAdjoint(; - autodiff = true, - autojacvec = ZygoteVJP(), - checkpointing = true, - ), - reltol = sqrt(sqrt(eps(one(Float32)))), - abstol = sqrt(eps(one(Float32))), - maxiters = typemax(Int32), - ), -) diff --git a/test/call_tests.jl b/test/call_tests.jl index 1e987a6a..b65fcf55 100644 --- a/test/call_tests.jl +++ b/test/call_tests.jl @@ -106,18 +106,8 @@ inplace, resource, steer_rate = convert(data_type, 0.1), - sol_kwargs = ContinuousNormalizingFlows.sol_kwargs_defaults.medium_noad, - ), - construct( - mt, - nn, - nvars; - data_type, - compute_mode, - inplace, - resource, - sol_kwargs = ContinuousNormalizingFlows.sol_kwargs_defaults.medium_noad, ), + construct(mt, nn, nvars; data_type, compute_mode, inplace, resource), ) ps, st = Lux.setup(icnf.rng, icnf) ps = ComponentArrays.ComponentArray(ps) diff --git a/test/fit_tests.jl b/test/fit_tests.jl index 9eeed02a..788b55cb 100644 --- a/test/fit_tests.jl +++ b/test/fit_tests.jl @@ -103,18 +103,8 @@ inplace, resource, steer_rate = convert(data_type, 0.1), - sol_kwargs = ContinuousNormalizingFlows.sol_kwargs_defaults.medium_noad, - ), - construct( - mt, - nn, - nvars; - data_type, - compute_mode, - inplace, - resource, - sol_kwargs = ContinuousNormalizingFlows.sol_kwargs_defaults.medium_noad, ), + construct(mt, nn, nvars; data_type, compute_mode, inplace, resource), ) if mt <: ContinuousNormalizingFlows.AbstractCondICNF model = CondICNFModel(icnf; n_epochs, adtype) diff --git a/test/instability_tests.jl b/test/instability_tests.jl index bdf5c54f..1865ec21 100644 --- a/test/instability_tests.jl +++ b/test/instability_tests.jl @@ -19,7 +19,6 @@ compute_mode = ZygoteMatrixMode, tspan = (0.0f0, 13.0f0), steer_rate = 0.1f0, - sol_kwargs = ContinuousNormalizingFlows.sol_kwargs_defaults.medium_noad, ) ps, st = Lux.setup(icnf.rng, icnf) ps = ComponentArrays.ComponentArray(ps)