From 0841a5e7df96d18cc4882b044054351287e41c4f Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 7 Sep 2024 07:52:07 -0400 Subject: [PATCH] ensemble error depwarn fixes --- Project.toml | 2 +- src/ensemble/ensemble_solutions.jl | 26 +++++++++++++++++++++----- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 752e585da..70d180a60 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SciMLBase" uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" authors = ["Chris Rackauckas and contributors"] -version = "2.52.1" +version = "2.52.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/ensemble/ensemble_solutions.jl b/src/ensemble/ensemble_solutions.jl index 6dbca4b78..804657334 100644 --- a/src/ensemble/ensemble_solutions.jl +++ b/src/ensemble/ensemble_solutions.jl @@ -89,6 +89,9 @@ function calculate_ensemble_errors(u; elapsedTime = 0.0, converged = false, errors = Dict{Symbol, Vector{eltype(u[1].u[1])}}() #Should add type information error_means = Dict{Symbol, eltype(u[1].u[1])}() error_medians = Dict{Symbol, eltype(u[1].u[1])}() + + analyticvoa = u[1].u_analytic isa AbstractVectorOfArray ? true : false + for k in keys(u[1].errors) errors[k] = [sol.errors[k] for sol in u] error_means[k] = mean(errors[k]) @@ -98,12 +101,24 @@ function calculate_ensemble_errors(u; elapsedTime = 0.0, converged = false, weak_errors = Dict{Symbol, eltype(u[1].u[1])}() # Final m_final = mean([s.u[end] for s in u]) - m_final_analytic = mean([s.u_analytic[end] for s in u]) + + if analyticvoa + m_final_analytic = mean([s.u_analytic.u[end] for s in u]) + else + m_final_analytic = mean([s.u_analytic[end] for s in u]) + end + res = norm(m_final - m_final_analytic) weak_errors[:weak_final] = res if weak_timeseries_errors - ts_weak_errors = [mean([u[j].u[i] - u[j].u_analytic[i] for j in 1:length(u)]) - for i in 1:length(u[1])] + + if analyticvoa + ts_weak_errors = [mean([u[j].u[i] - u[j].u_analytic.u[i] for j in 1:length(u)]) + for i in 1:length(u[1])] + else + ts_weak_errors = [mean([u[j].u[i] - u[j].u_analytic[i] for j in 1:length(u)]) + for i in 1:length(u[1])] + end ts_l2_errors = [sqrt.(sum(abs2, err) / length(err)) for err in ts_weak_errors] l2_tmp = sqrt(sum(abs2, ts_l2_errors) / length(ts_l2_errors)) max_tmp = maximum([maximum(abs.(err)) for err in ts_weak_errors]) @@ -113,8 +128,9 @@ function calculate_ensemble_errors(u; elapsedTime = 0.0, converged = false, if weak_dense_errors densetimes = collect(range(u[1].t[1], stop = u[1].t[end], length = 100)) u_analytic = [[sol.prob.f.analytic(sol.prob.u0, sol.prob.p, densetimes[i], - sol.W(densetimes[i])[1]) - for i in eachindex(densetimes)] for sol in u] + sol.W(densetimes[i])[1]) + for i in eachindex(densetimes)] for sol in u] + udense = [u[j](densetimes) for j in 1:length(u)] dense_weak_errors = [mean([udense[j].u[i] - u_analytic[j][i] for j in 1:length(u)]) for i in eachindex(densetimes)]