Skip to content

Commit

Permalink
Add function wrap_array
Browse files Browse the repository at this point in the history
  • Loading branch information
efaulhaber committed Jun 19, 2024
1 parent a4d8aef commit 8e8efa8
Showing 1 changed file with 17 additions and 28 deletions.
45 changes: 17 additions & 28 deletions src/general/semidiscretization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -350,52 +350,41 @@ end

# We have to pass `system` here for type stability,
# since the type of `system` determines the return type.
@inline function wrap_u(u_ode::Array, system, semi)
(; ranges_u) = semi

range = ranges_u[system_indices(system, semi)]

@boundscheck @assert length(range) == u_nvariables(system) * n_moving_particles(system)

# This is a non-allocating version of:
# return unsafe_wrap(Array{eltype(u_ode), 2}, pointer(view(u_ode, range)),
# (u_nvariables(system), n_moving_particles(system)))
return PtrArray(pointer(view(u_ode, range)),
(StaticInt(u_nvariables(system)), n_moving_particles(system)))
end

@inline function wrap_v(v_ode::Array, system, semi)
(; ranges_v) = semi

range = ranges_v[system_indices(system, semi)]

@boundscheck @assert length(range) == v_nvariables(system) * n_moving_particles(system)

return PtrArray(pointer(view(v_ode, range)),
(StaticInt(v_nvariables(system)), n_moving_particles(system)))
return wrap_array(v_ode, range,
(StaticInt(v_nvariables(system)), n_moving_particles(system)))
end

# For non-`Array`s (typically GPU arrays), just reshape. Calling the `PtrArray` code above
# for a `CuArray` yields another `CuArray` (instead of a `PtrArray`) and is 8 times slower
# with double the allocations.
@inline function wrap_u(u_ode, system, semi)
@inline function wrap_u(u_ode::Array, system, semi)
(; ranges_u) = semi

range = ranges_u[system_indices(system, semi)]

@boundscheck @assert length(range) == u_nvariables(system) * n_moving_particles(system)

return reshape(view(u_ode, range), (u_nvariables(system), n_moving_particles(system)))
return wrap_array(u_ode, range,
(StaticInt(u_nvariables(system)), n_moving_particles(system)))
end

@inline function wrap_v(v_ode, system, semi)
(; ranges_v) = semi

range = ranges_v[system_indices(system, semi)]

@boundscheck @assert length(range) == v_nvariables(system) * n_moving_particles(system)
@inline function wrap_array(array::Array, range, size)
# This is a non-allocating version of:
# return unsafe_wrap(Array{eltype(array), 2}, pointer(view(array, range)), size)
return PtrArray(pointer(view(array, range)), size)
end

return reshape(view(v_ode, range), (v_nvariables(system), n_moving_particles(system)))
@inline function wrap_array(array, range, size)
# For non-`Array`s (typically GPU arrays), just reshape. Calling the `PtrArray` code
# above for a `CuArray` yields another `CuArray` (instead of a `PtrArray`)
# and is 8 times slower with double the allocations.
#
# Note that `size` might contain `StaticInt`s, so convert to `Int` first.
return reshape(view(array, range), Int.(size))
end

function calculate_dt(v_ode, u_ode, cfl_number, semi::Semidiscretization)
Expand Down

0 comments on commit 8e8efa8

Please sign in to comment.