Skip to content

Commit

Permalink
refactor: add DiffEqArray constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Apr 30, 2024
1 parent 52d8c3f commit bb07a04
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 12 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@ ChainRulesCore = "1"
ForwardDiff = "0.10.3"
MacroTools = "0.5"
PreallocationTools = "0.4"
RecursiveArrayTools = "2,3"
RecursiveArrayTools = "3"
StaticArrays = "1.0"
julia = "1.6"

[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
Expand Down
1 change: 1 addition & 0 deletions src/LabelledArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import RecursiveArrayTools, PreallocationTools, ForwardDiff
include("slarray.jl")
include("larray.jl")
include("chainrules.jl")
include("diffeqarray.jl")

# Common
@generated function __getindex(x::Union{LArray, SLArray}, ::Val{s}) where {s}
Expand Down
7 changes: 7 additions & 0 deletions src/diffeqarray.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
for LArrayType in [LArray, SLArray]
@eval function RecursiveArrayTools.DiffEqArray(vec::AbstractVector{<:$LArrayType},

Check warning on line 2 in src/diffeqarray.jl

View check run for this annotation

Codecov / codecov/patch

src/diffeqarray.jl#L2

Added line #L2 was not covered by tests
ts::AbstractVector,
p = nothing)
RecursiveArrayTools.DiffEqArray(vec, ts, p; variables = collect(symbols(vec[1])))

Check warning on line 5 in src/diffeqarray.jl

View check run for this annotation

Codecov / codecov/patch

src/diffeqarray.jl#L5

Added line #L5 was not covered by tests
end
end
6 changes: 6 additions & 0 deletions test/recursivearraytools.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
using RecursiveArrayTools, LabelledArrays, Test

ABC = @SLVector (:a, :b, :c);
A = ABC(1, 2, 3);
B = RecursiveArrayTools.DiffEqArray([A, A], [0.0, 2.0]);
@test getindex(B, :a) == [1, 1]
32 changes: 21 additions & 11 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,27 @@ using StaticArrays
using InteractiveUtils
using ChainRulesTestUtils

@time begin
@time @testset "SLArrays" begin
include("slarrays.jl")
end
@time @testset "LArrays" begin
include("larrays.jl")
end
@time @testset "DiffEq" begin
include("diffeq.jl")
const GROUP = get(ENV, "GROUP", "All")

if GROUP == "All"
@time begin
@time @testset "SLArrays" begin
include("slarrays.jl")
end
@time @testset "LArrays" begin
include("larrays.jl")
end
@time @testset "DiffEq" begin
include("diffeq.jl")
end
@time @testset "ChainRules" begin
include("chainrules.jl")
end
end
@time @testset "ChainRules" begin
include("chainrules.jl")
end

if GROUP == "All" || GROUP == "RecursiveArrayTools"
@time @testset "RecursiveArrayTools" begin
include("recursivearraytools.jl")
end
end

0 comments on commit bb07a04

Please sign in to comment.