From db5270048b60337d6aa4b7554930fd5d70950c43 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 27 Dec 2023 15:40:24 -0500 Subject: [PATCH] fix analytical plots --- src/solutions/solution_interface.jl | 56 +++++++++++++---------------- test/solution_interface.jl | 8 ++--- 2 files changed, 26 insertions(+), 38 deletions(-) diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index 6962f9a0a..56486edb4 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -412,34 +412,14 @@ function add_labels!(labels, x, dims, sol, strs) end function add_analytic_labels!(labels, x, dims, sol, strs) - lys = [] - for j in 3:dims - if x[j] == 0 && dims == 2 - push!(lys, "$(getindepsym_defaultt(sol)),") - else - if strs !== nothing - push!(lys, string("True ", strs[x[j]], ",")) - else - push!(lys, "True u$(x[j]),") - end - end - end - lys[end] = lys[end][1:(end - 1)] # Take off the last comma - if x[2] == 0 - tmp_lab = "$(lys...)($(getindepsym_defaultt(sol)))" - else - if strs !== nothing - tmp = string("True ", strs[x[2]]) - tmp_lab = "($tmp,$(lys...))" - else - tmp_lab = "(True u$(x[2]),$(lys...))" - end - end - if x[1] != DEFAULT_PLOT_FUNC - push!(labels, "$(x[1])$(tmp_lab)") + if ((x[2] isa Integer && x[2] == 0) || isequal(x[2],getindepsym_defaultt(sol))) && dims == 2 + push!(labels, "True $(strs[end])") + elseif x[1] !== DEFAULT_PLOT_FUNC + push!(labels, "True f($(join(strs, ',')))") else - push!(labels, tmp_lab) + push!(labels, "True ($(join(strs, ',')))") end + labels end function solplot_vecs_and_labels(dims, vars, plott, sol, plot_analytic, @@ -486,27 +466,39 @@ function solplot_vecs_and_labels(dims, vars, plott, sol, plot_analytic, for x in vars tmp = [] strs = String[] - for j in 2:dims + for j in 2:length(x) if (x[j] isa Integer && x[j] == 0) push!(tmp, plott) push!(strs, "t") elseif isequal(x[j],getindepsym_defaultt(sol)) push!(tmp, plott) push!(strs, String(getname(x[j]))) - elseif n == 1 && !(sol[:, 1] isa Union{AbstractArray, ArrayPartition}) - push!(tmp,timeseries) - push!(strs, String(getname(x[j]))) + elseif x[j] == 1 && !(sol[:, 1] isa Union{AbstractArray, ArrayPartition}) + push!(tmp,plot_analytic_timeseries) + if !isempty(varsyms) && x[j] isa Integer + push!(strs, String(getname(varsyms[x[j]]))) + elseif hasname(x[j]) + push!(strs, String(getname(x[j]))) + else + push!(strs, "u[$(x[j])]") + end else _tmp = Vector{eltype(sol[1])}(undef, length(plot_timeseries)) for j in 1:length(plot_timeseries) _tmp[j] = plot_timeseries[j][n] end push!(tmp,_tmp) - push!(strs, String(getname(x[j]))) + if !isempty(varsyms) && x[j] isa Integer + push!(strs, String(getname(varsyms[x[j]]))) + elseif hasname(x[j]) + push!(strs, String(getname(x[j]))) + else + push!(strs, "u[$(x[j])]") + end end end f = x[1] - tmp = f.(tmp...) + tmp = map(f,tmp...) tmp = tuple((getindex.(tmp, i) for i in eachindex(tmp[1]))...) for i in eachindex(tmp) push!(plot_vecs[i], tmp[i]) diff --git a/test/solution_interface.jl b/test/solution_interface.jl index 798aa31d7..ed8edb31b 100644 --- a/test/solution_interface.jl +++ b/test/solution_interface.jl @@ -14,18 +14,14 @@ end push!(sol.u, ode.u0) end - syms = SciMLBase.interpret_vars(nothing, sol, SciMLBase.getsyms(sol)) - int_vars = SciMLBase.interpret_vars(nothing, sol, syms) # nothing = idxs + int_vars = SciMLBase.interpret_vars(nothing, sol) # nothing = idxs plot_vecs, labels = SciMLBase.diffeq_to_arrays(sol, true, # plot_analytic true, # denseplot 10, # plotdensity ode.tspan, - 0.1, # axis_safety - nothing, # idxs int_vars, - :identity, # tscale - nothing) # strs + :identity) # tscale @test plot_vecs[2][:, 2] ≈ @. exp(-plot_vecs[1][:, 2]) end