diff --git a/README.md b/README.md index c15eccda..893ebd65 100644 --- a/README.md +++ b/README.md @@ -69,7 +69,7 @@ model = ICNFModel( icnf; batch_size = 256, # have bigger batchs # n_epochs = 100, # have less epochs - # optimizers = [Lion()], # use a different optimizer + # optimizers = (Lion(),), # use a different optimizer # adtype = AutoForwardDiff(), # use ForwardDiff ) mach = machine(model, df) diff --git a/src/cores/core_cond_icnf.jl b/src/cores/core_cond_icnf.jl index ddba2d88..8f1c4cd8 100644 --- a/src/cores/core_cond_icnf.jl +++ b/src/cores/core_cond_icnf.jl @@ -6,7 +6,7 @@ mutable struct CondICNFModel <: MLJICNF m::AbstractICNF loss::Function - optimizers::AbstractVector + optimizers::Tuple n_epochs::Int adtype::ADTypes.AbstractADType @@ -19,7 +19,7 @@ end function CondICNFModel( m::AbstractICNF{<:AbstractFloat, CM}, loss::Function = loss; - optimizers::AbstractVector = Any[PolyOpt(),], + optimizers::Tuple = (PolyOpt(),), n_epochs::Int = 300, adtype::ADTypes.AbstractADType = AutoZygote(), use_batch::Bool = true, diff --git a/src/cores/core_icnf.jl b/src/cores/core_icnf.jl index 93af59d2..49cd67e6 100644 --- a/src/cores/core_icnf.jl +++ b/src/cores/core_icnf.jl @@ -6,7 +6,7 @@ mutable struct ICNFModel <: MLJICNF m::AbstractICNF loss::Function - optimizers::AbstractVector + optimizers::Tuple n_epochs::Int adtype::ADTypes.AbstractADType @@ -19,7 +19,7 @@ end function ICNFModel( m::AbstractICNF{<:AbstractFloat, CM}, loss::Function = loss; - optimizers::AbstractVector = Any[PolyOpt(),], + optimizers::Tuple = (PolyOpt(),), n_epochs::Int = 300, adtype::ADTypes.AbstractADType = AutoZygote(), use_batch::Bool = true,