From 01b13d607b7c766857ceec9a417e076fa0042ebf Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 31 Jul 2024 12:40:18 +0530 Subject: [PATCH] refactor: format refactor: format --- ext/SciMLBaseZygoteExt.jl | 7 ++++--- src/clock.jl | 2 +- src/ensemble/ensemble_solutions.jl | 6 ++++-- src/problems/linear_problems.jl | 2 +- src/remake.jl | 2 +- src/scimlfunctions.jl | 5 +++-- src/solutions/ode_solutions.jl | 13 +++++++++---- src/solutions/solution_interface.jl | 18 +++++++++--------- test/downstream/comprehensive_indexing.jl | 12 ++++++++---- test/downstream/solution_interface.jl | 7 ++++--- test/serialization_tests.jl | 10 +++++----- 11 files changed, 49 insertions(+), 35 deletions(-) diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 950b97dcb..486d219aa 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -174,15 +174,16 @@ end VA[sym], ODESolution_getindex_pullback end -@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13 -, T14, T15}(u, +@adjoint function ODESolution{ + T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15}(u, args...) where {T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15} function ODESolutionAdjoint(ȳ) (ȳ, ntuple(_ -> nothing, length(args))...) end - ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15}(u, args...), + ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15}( + u, args...), ODESolutionAdjoint end diff --git a/src/clock.jl b/src/clock.jl index 2db74dd3a..8085f08a4 100644 --- a/src/clock.jl +++ b/src/clock.jl @@ -88,7 +88,7 @@ Base.getindex(c::TimeDomain, idx) = IndexedClock(c, idx) function canonicalize_indexed_clock(ic::IndexedClock, sol::AbstractTimeseriesSolution) c = ic.clock - + return @match c begin PeriodicClock(dt, _...) => ceil(sol.prob.tspan[1] / dt) * dt .+ (ic.idx .- 1) .* dt &SolverStepClock => begin diff --git a/src/ensemble/ensemble_solutions.jl b/src/ensemble/ensemble_solutions.jl index 6bb64e04f..50ca209dc 100644 --- a/src/ensemble/ensemble_solutions.jl +++ b/src/ensemble/ensemble_solutions.jl @@ -211,11 +211,13 @@ end end end -Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s::Integer, i::Integer) +Base.@propagate_inbounds function Base.getindex( + x::AbstractEnsembleSolution, s::Integer, i::Integer) return x.u[s].u[i] end -Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s::Integer, i2::Integer, i3::Integer, idxs::Integer...) +Base.@propagate_inbounds function Base.getindex( + x::AbstractEnsembleSolution, s::Integer, i2::Integer, i3::Integer, idxs::Integer...) return x.u[s][i2, i3, idxs...] end diff --git a/src/problems/linear_problems.jl b/src/problems/linear_problems.jl index 8a1cc5d35..ebb5e3c97 100644 --- a/src/problems/linear_problems.jl +++ b/src/problems/linear_problems.jl @@ -75,4 +75,4 @@ function LinearProblem(A, b, args...; kwargs...) else LinearProblem{isinplace(A, 4)}(A, b, args...; kwargs...) end -end \ No newline at end of file +end diff --git a/src/remake.jl b/src/remake.jl index 621ca46a6..aadb2efd7 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -476,7 +476,7 @@ anydict(d) = Dict{Any, Any}(d) anydict() = Dict{Any, Any}() function _updated_u0_p_internal( - prob, ::Missing, ::Missing; interpret_symbolicmap = true, use_defaults = false) + prob, ::Missing, ::Missing; interpret_symbolicmap = true, use_defaults = false) return state_values(prob), parameter_values(prob) end function _updated_u0_p_internal( diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 176e86b11..1661d596c 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2615,7 +2615,8 @@ end typeof(cache), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec), - typeof(sys), typeof(initializeprob), typeof(initializeprobmap)}(f1, f2, mass_matrix, + typeof(sys), typeof(initializeprob), typeof(initializeprobmap)}( + f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys, initializeprob, initializeprobmap) @@ -2649,7 +2650,7 @@ function SplitFunction{iip, specialize}(f1, f2; sys = __has_sys(f1) ? f1.sys : nothing, initializeprob = __has_initializeprob(f1) ? f1.initializeprob : nothing, initializeprobmap = __has_initializeprobmap(f1) ? f1.initializeprobmap : nothing - ) where {iip, +) where {iip, specialize } sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym) diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index d770ff8c0..2fb0cf3b6 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -275,7 +275,8 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs, ps = parameter_values(discs) for ts_idx in eachindex(discs) partition = discs[ts_idx] - interp_val = ConstantInterpolation(partition.t, partition.u)(t, nothing, deriv, nothing, continuity) + interp_val = ConstantInterpolation(partition.t, partition.u)( + t, nothing, deriv, nothing, continuity) ps = with_updated_parameter_timeseries_values(sol, ps, ts_idx => interp_val) end end @@ -296,7 +297,8 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVect ps = parameter_values(discs) for ts_idx in eachindex(discs) partition = discs[ts_idx] - interp_val = ConstantInterpolation(partition.t, partition.u)(t, nothing, deriv, nothing, continuity) + interp_val = ConstantInterpolation(partition.t, partition.u)( + t, nothing, deriv, nothing, continuity) ps = with_updated_parameter_timeseries_values(sol, ps, ts_idx => interp_val) end end @@ -374,7 +376,9 @@ function get_saveable_values(sys, ps, timeseries_idx) end function save_discretes!(integ::DEIntegrator, timeseries_idx) - save_discretes!(integ.sol, current_time(integ), get_saveable_values(integ, parameter_values(integ), timeseries_idx), timeseries_idx) + save_discretes!(integ.sol, current_time(integ), + get_saveable_values(integ, parameter_values(integ), timeseries_idx), + timeseries_idx) end save_discretes!(args...) = nothing @@ -555,7 +559,8 @@ end mask_discretes(::Nothing, _, _...) = nothing -function mask_discretes(discretes::ParameterTimeseriesCollection, new_t, ::Union{Int, CartesianIndex}) +function mask_discretes( + discretes::ParameterTimeseriesCollection, new_t, ::Union{Int, CartesianIndex}) masked_discretes = map(discretes) do disc i = searchsortedlast(disc.t, new_t) disc[i:i] diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index 57df61dbd..2d9ca7634 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -149,9 +149,9 @@ DEFAULT_PLOT_FUNC(x, y, z) = (x, y, z) # For v0.5.2 bug function isdenseplot(sol) (sol.dense || sol.prob isa AbstractDiscreteProblem) && - !(sol isa AbstractRODESolution) && - !(hasfield(typeof(sol), :interp) && - sol.interp isa SensitivityInterpolation) + !(sol isa AbstractRODESolution) && + !(hasfield(typeof(sol), :interp) && + sol.interp isa SensitivityInterpolation) end @recipe function f(sol::AbstractTimeseriesSolution; @@ -187,7 +187,8 @@ end disc_vars = Tuple[] cont_vars = Tuple[] for var in vars - tsidxs = union(get_all_timeseries_indexes(sol, var[2]), get_all_timeseries_indexes(sol, var[3])) + tsidxs = union(get_all_timeseries_indexes(sol, var[2]), + get_all_timeseries_indexes(sol, var[3])) if ContinuousTimeseries() in tsidxs push!(cont_vars, var) else @@ -209,7 +210,6 @@ end plot_vecs, labels = diffeq_to_arrays(sol, plot_analytic, denseplot, plotdensity, tspan, vars, tscale, plotat) - # Special case labels when idxs = (:x,:y,:z) or (:x) or [:x,:y] ... if idxs isa Tuple && vars[1][1] === DEFAULT_PLOT_FUNC val = hasname(vars[1][2]) ? String(getname(vars[1][2])) : vars[1][2] @@ -311,12 +311,12 @@ end seriestype := :line linestyle --> :dash markershape --> :o - markersize --> repeat([2, 0], length(ts)-1) - markeralpha --> repeat([1, 0], length(ts)-1) + markersize --> repeat([2, 0], length(ts) - 1) + markeralpha --> repeat([1, 0], length(ts) - 1) label --> string(hasname(yvar) ? getname(yvar) : yvar) - x = vec([xvals[1:end-1]'; xvals[2:end]']) - y = repeat(yvals, inner=2)[1:end-1] + x = vec([xvals[1:(end - 1)]'; xvals[2:end]']) + y = repeat(yvals, inner = 2)[1:(end - 1)] x, y end end diff --git a/test/downstream/comprehensive_indexing.jl b/test/downstream/comprehensive_indexing.jl index 390b500ed..56afc58ca 100644 --- a/test/downstream/comprehensive_indexing.jl +++ b/test/downstream/comprehensive_indexing.jl @@ -556,7 +556,8 @@ end newx = [] for i in eachindex(x) if x[i] isa Symbol - push!(newx, allsyms[findfirst(y -> hasname(y) && x[i] == getname(y), allsyms)]) + push!(newx, + allsyms[findfirst(y -> hasname(y) && x[i] == getname(y), allsyms)]) else push!(newx, x[i]) end @@ -590,7 +591,8 @@ end newx = [] for i in eachindex(x) if x[i] isa Symbol - push!(newx, allsyms[findfirst(y -> hasname(y) && x[i] == getname(y), allsyms)]) + push!(newx, + allsyms[findfirst(y -> hasname(y) && x[i] == getname(y), allsyms)]) else push!(newx, x[i]) end @@ -896,9 +898,11 @@ end @test_nowarn plot(sol; idxs = idx) @test_nowarn plot(sol; idxs = [idx]) end - for idx in Iterators.flatten((Iterators.product(all_idxs, all_idxs), Iterators.product(sym_idxs, sym_idxs))) + for idx in Iterators.flatten(( + Iterators.product(all_idxs, all_idxs), Iterators.product(sym_idxs, sym_idxs))) @test_nowarn plot(sol; idxs = collect(idx)) - if !(idx[1] isa Tuple || idx[2] isa Tuple || length(get_all_timeseries_indexes(sol, collect(idx))) > 1) + if !(idx[1] isa Tuple || idx[2] isa Tuple || + length(get_all_timeseries_indexes(sol, collect(idx))) > 1) @test_nowarn plot(sol; idxs = idx) end end diff --git a/test/downstream/solution_interface.jl b/test/downstream/solution_interface.jl index 9d7a8301f..6cc5fd3b1 100644 --- a/test/downstream/solution_interface.jl +++ b/test/downstream/solution_interface.jl @@ -156,14 +156,15 @@ sol10 = sol(0.1, idxs = 2) plotfn(t, u) = (t, 2u) all_idxs = [x, x + p * y, t, (plotfn, 0, 1), (plotfn, t, 1), (plotfn, 0, x), - (plotfn, t, x), (plotfn, t, p * y)] + (plotfn, t, x), (plotfn, t, p * y)] sym_idxs = [:x, :t, (plotfn, :t, 1), (plotfn, 0, :x), - (plotfn, :t, :x)] + (plotfn, :t, :x)] for idx in Iterators.flatten((all_idxs, sym_idxs)) @test_nowarn plot(sol; idxs = idx) @test_nowarn plot(sol; idxs = [idx]) end - for idx in Iterators.flatten((Iterators.product(all_idxs, all_idxs), Iterators.product(sym_idxs, sym_idxs))) + for idx in Iterators.flatten(( + Iterators.product(all_idxs, all_idxs), Iterators.product(sym_idxs, sym_idxs))) @test_nowarn plot(sol; idxs = collect(idx)) if !(idx[1] isa Tuple || idx[2] isa Tuple) @test_nowarn plot(sol; idxs = idx) diff --git a/test/serialization_tests.jl b/test/serialization_tests.jl index bcb2bf0c8..093af8b3a 100644 --- a/test/serialization_tests.jl +++ b/test/serialization_tests.jl @@ -3,11 +3,11 @@ using Serialization using Test for clock in [ - SciMLBase.Clock(0.5), - SciMLBase.Clock(0.5; phase = 0.1), - SciMLBase.SolverStepClock, - SciMLBase.Continuous, - ] + SciMLBase.Clock(0.5), + SciMLBase.Clock(0.5; phase = 0.1), + SciMLBase.SolverStepClock, + SciMLBase.Continuous +] serialize("_tmp.jls", clock) newclock = deserialize("_tmp.jls") @test newclock == clock