diff --git a/Project.toml b/Project.toml
index 8c89b8a64..2565a0945 100644
--- a/Project.toml
+++ b/Project.toml
@@ -11,7 +11,9 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
 DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
 FastPow = "c0e83750-1142-43a8-81cf-6c956b72b4d1"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
 JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
+KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
 PointNeighbors = "1c4d5385-0a27-49de-8e2c-43b175c8985c"
@@ -35,6 +37,7 @@ FastPow = "0.1"
 ForwardDiff = "0.10"
 JSON = "0.21"
 MuladdMacro = "0.2"
+PointNeighbors = "0.2.3"
 Polyester = "0.7.5"
 RecipesBase = "1"
 Reexport = "1"
@@ -43,6 +46,5 @@ StaticArrays = "1"
 StrideArrays = "0.1"
 TimerOutputs = "0.5"
 TrixiBase = "0.1.3"
-PointNeighbors = "0.2"
 WriteVTK = "1"
 julia = "1.9"
diff --git a/examples/fluid/dam_break_2d.jl b/examples/fluid/dam_break_2d.jl
index e58a7cc06..a1eeb2f38 100644
--- a/examples/fluid/dam_break_2d.jl
+++ b/examples/fluid/dam_break_2d.jl
@@ -73,8 +73,10 @@ boundary_system = BoundarySPHSystem(tank.boundary, boundary_model, adhesion_coef
 # ==========================================================================================
 # ==== Simulation
-semi = Semidiscretization(fluid_system, boundary_system, threaded_nhs_update=true)
-ode = semidiscretize(semi, tspan)
+semi = Semidiscretization(fluid_system, boundary_system,
+                          neighborhood_search=GridNeighborhoodSearch,
+                          threaded_nhs_update=true)
+ode = semidiscretize(semi, tspan, data_type=nothing)
 
 info_callback = InfoCallback(interval=100)
 
diff --git a/examples/n_body/n_body_system.jl b/examples/n_body/n_body_system.jl
index c7226357a..556147ed4 100644
--- a/examples/n_body/n_body_system.jl
+++ b/examples/n_body/n_body_system.jl
@@ -1,7 +1,9 @@
 using TrixiParticles
 using LinearAlgebra
 
-struct NBodySystem{NDIMS, ELTYPE <: Real} <: TrixiParticles.System{NDIMS, Nothing}
+# The second type parameter of `System` can't be `Nothing`, or TrixiParticles will launch
+# GPU kernels for `for_particle_neighbor` loops.
+struct NBodySystem{NDIMS, ELTYPE <: Real} <: TrixiParticles.System{NDIMS, 0} initial_condition :: InitialCondition{ELTYPE} mass :: Array{ELTYPE, 1} # [particle] G :: ELTYPE diff --git a/src/TrixiParticles.jl b/src/TrixiParticles.jl index aa97ab720..9fe23e3d4 100644 --- a/src/TrixiParticles.jl +++ b/src/TrixiParticles.jl @@ -9,7 +9,9 @@ using DataFrames: DataFrame using DiffEqCallbacks: PeriodicCallback, PeriodicCallbackAffect, PresetTimeCallback using FastPow: @fastpow using ForwardDiff: ForwardDiff +using GPUArrays: AbstractGPUArray using JSON: JSON +using KernelAbstractions: KernelAbstractions, @kernel, @index using LinearAlgebra: norm, dot, I, tr, inv, pinv, det using MuladdMacro: @muladd using Polyester: Polyester, @batch @@ -27,7 +29,9 @@ using TrixiBase: trixi_include, @trixi_timeit, timer, timeit_debug_enabled, using PointNeighbors: PointNeighbors, for_particle_neighbor using WriteVTK: vtk_grid, MeshCell, VTKCellTypes, paraview_collection, vtk_save -# util needs to be first because of macro @trixi_timeit +# `util.jl` depends on the `GPUSystem` type defined in `system.jl` +include("general/system.jl") +# `util.jl` needs to be next because of the macros `@trixi_timeit` and `@threaded` include("util.jl") include("callbacks/callbacks.jl") include("general/general.jl") diff --git a/src/general/corrections.jl b/src/general/corrections.jl index cf64f6d16..ee894fec4 100644 --- a/src/general/corrections.jl +++ b/src/general/corrections.jl @@ -440,7 +440,7 @@ function compute_gradient_correction_matrix!(corr_matrix::AbstractArray, system, end function correction_matrix_inversion_step!(corr_matrix, system) - @threaded for particle in eachparticle(system) + @threaded system for particle in eachparticle(system) L = extract_smatrix(corr_matrix, system, particle) # The matrix `L` only becomes singular when the particle and all neighbors diff --git a/src/general/general.jl b/src/general/general.jl index 774d076c9..dfd709861 100644 --- a/src/general/general.jl +++ b/src/general/general.jl @@ -1,37 +1,11 @@ -# Abstract supertype for all system types. We additionally store the type of the system's -# initial condition, which is `Nothing` when using KernelAbstractions.jl. -abstract type System{NDIMS, IC} end - -# When using KernelAbstractions.jl, the initial condition has been replaced by `nothing` -GPUSystem = System{NDIMS, Nothing} where {NDIMS} - -abstract type FluidSystem{NDIMS, IC} <: System{NDIMS, IC} end -timer_name(::FluidSystem) = "fluid" -vtkname(system::FluidSystem) = "fluid" - -abstract type SolidSystem{NDIMS, IC} <: System{NDIMS, IC} end -timer_name(::SolidSystem) = "solid" -vtkname(system::SolidSystem) = "solid" - -abstract type BoundarySystem{NDIMS, IC} <: System{NDIMS, IC} end -timer_name(::BoundarySystem) = "boundary" -vtkname(system::BoundarySystem) = "boundary" - -@inline function set_zero!(du) - du .= zero(eltype(du)) - - return du -end - -# Note that `semidiscretization.jl` depends on the system types and has to be -# included later. +# Note that `system.jl` has already been included. +# `semidiscretization.jl` depends on the system types and has to be included later. # `density_calculators.jl` needs to be included before `corrections.jl`. 
include("density_calculators.jl") include("corrections.jl") include("smoothing_kernels.jl") include("initial_condition.jl") include("buffer.jl") -include("system.jl") include("interpolation.jl") include("custom_quantities.jl") include("neighborhood_search.jl") diff --git a/src/general/gpu.jl b/src/general/gpu.jl index 377ea8b72..8aa481cff 100644 --- a/src/general/gpu.jl +++ b/src/general/gpu.jl @@ -13,6 +13,7 @@ Adapt.@adapt_structure DensityDiffusionAntuono Adapt.@adapt_structure BoundarySPHSystem Adapt.@adapt_structure BoundaryModelDummyParticles Adapt.@adapt_structure BoundaryModelMonaghanKajtar +Adapt.@adapt_structure TotalLagrangianSPHSystem # The initial conditions are only used for initialization, which happens before `adapt`ing # the semidiscretization, so we don't need to store `InitialCondition`s on the GPU. @@ -32,3 +33,10 @@ end function Adapt.adapt_structure(to::typeof(Array), range::UnitRange) return range end + +KernelAbstractions.get_backend(::PtrArray) = KernelAbstractions.CPU() +KernelAbstractions.get_backend(system::System) = KernelAbstractions.get_backend(system.mass) + +function KernelAbstractions.get_backend(system::BoundarySPHSystem) + KernelAbstractions.get_backend(system.coordinates) +end diff --git a/src/general/neighborhood_search.jl b/src/general/neighborhood_search.jl index 0713f153d..afcd2c9b3 100644 --- a/src/general/neighborhood_search.jl +++ b/src/general/neighborhood_search.jl @@ -9,3 +9,14 @@ function PointNeighbors.for_particle_neighbor(f, system, neighbor_system, for_particle_neighbor(f, system_coords, neighbor_coords, neighborhood_search, particles=particles, parallel=parallel) end + +function PointNeighbors.for_particle_neighbor(f, system::GPUSystem, neighbor_system, + system_coords, neighbor_coords, + neighborhood_search; + particles=each_moving_particle(system), + parallel=true) + @threaded system for particle in particles + PointNeighbors.foreach_neighbor(f, system_coords, neighbor_coords, + neighborhood_search, particle) + end +end diff --git a/src/general/semidiscretization.jl b/src/general/semidiscretization.jl index 73c1fe56d..7f6c791db 100644 --- a/src/general/semidiscretization.jl +++ b/src/general/semidiscretization.jl @@ -368,6 +368,17 @@ end # We have to pass `system` here for type stability, # since the type of `system` determines the return type. 
+@inline function wrap_v(v_ode, system, semi) + (; ranges_v) = semi + + range = ranges_v[system_indices(system, semi)] + + @boundscheck @assert length(range) == v_nvariables(system) * n_moving_particles(system) + + return wrap_array(v_ode, range, + (StaticInt(v_nvariables(system)), n_moving_particles(system))) +end + @inline function wrap_u(u_ode, system, semi) (; ranges_u) = semi @@ -375,22 +386,23 @@ end @boundscheck @assert length(range) == u_nvariables(system) * n_moving_particles(system) - # This is a non-allocating version of: - # return unsafe_wrap(Array{eltype(u_ode), 2}, pointer(view(u_ode, range)), - # (u_nvariables(system), n_moving_particles(system))) - return PtrArray(pointer(view(u_ode, range)), - (StaticInt(u_nvariables(system)), n_moving_particles(system))) + return wrap_array(u_ode, range, + (StaticInt(u_nvariables(system)), n_moving_particles(system))) end -@inline function wrap_v(v_ode, system, semi) - (; ranges_v) = semi - - range = ranges_v[system_indices(system, semi)] - - @boundscheck @assert length(range) == v_nvariables(system) * n_moving_particles(system) +@inline function wrap_array(array::Array, range, size) + # This is a non-allocating version of: + # return unsafe_wrap(Array{eltype(array), 2}, pointer(view(array, range)), size) + return PtrArray(pointer(view(array, range)), size) +end - return PtrArray(pointer(view(v_ode, range)), - (StaticInt(v_nvariables(system)), n_moving_particles(system))) +@inline function wrap_array(array, range, size) + # For non-`Array`s (typically GPU arrays), just reshape. Calling the `PtrArray` code + # above for a `CuArray` yields another `CuArray` (instead of a `PtrArray`) + # and is 8 times slower with double the allocations. + # + # Note that `size` might contain `StaticInt`s, so convert to `Int` first. + return reshape(view(array, range), Int.(size)) end function calculate_dt(v_ode, u_ode, cfl_number, semi::Semidiscretization) @@ -409,7 +421,7 @@ function drift!(du_ode, v_ode, u_ode, semi, t) du = wrap_u(du_ode, system, semi) v = wrap_v(v_ode, system, semi) - @threaded for particle in each_moving_particle(system) + @threaded system for particle in each_moving_particle(system) # This can be dispatched per system add_velocity!(du, v, particle, system) end @@ -508,7 +520,7 @@ function add_source_terms!(dv_ode, v_ode, u_ode, semi) v = wrap_v(v_ode, system, semi) u = wrap_u(u_ode, system, semi) - @threaded for particle in each_moving_particle(system) + @threaded system for particle in each_moving_particle(system) # Dispatch by system type to exclude boundary systems add_acceleration!(dv, particle, system) add_source_terms_inner!(dv, v, u, particle, system, source_terms(system)) diff --git a/src/general/system.jl b/src/general/system.jl index c3f10caad..674555aea 100644 --- a/src/general/system.jl +++ b/src/general/system.jl @@ -1,3 +1,28 @@ +# Abstract supertype for all system types. We additionally store the type of the system's +# initial condition, which is `Nothing` when using KernelAbstractions.jl. 
+abstract type System{NDIMS, IC} end + +# When using KernelAbstractions.jl, the initial condition has been replaced by `nothing` +GPUSystem = System{NDIMS, Nothing} where {NDIMS} + +abstract type FluidSystem{NDIMS, IC} <: System{NDIMS, IC} end +timer_name(::FluidSystem) = "fluid" +vtkname(system::FluidSystem) = "fluid" + +abstract type SolidSystem{NDIMS, IC} <: System{NDIMS, IC} end +timer_name(::SolidSystem) = "solid" +vtkname(system::SolidSystem) = "solid" + +abstract type BoundarySystem{NDIMS, IC} <: System{NDIMS, IC} end +timer_name(::BoundarySystem) = "boundary" +vtkname(system::BoundarySystem) = "boundary" + +@inline function set_zero!(du) + du .= zero(eltype(du)) + + return du +end + initialize!(system, neighborhood_search) = system @inline Base.ndims(::System{NDIMS}) where {NDIMS} = NDIMS diff --git a/src/schemes/boundary/dummy_particles/dummy_particles.jl b/src/schemes/boundary/dummy_particles/dummy_particles.jl index 0f7b42696..3e3abcbd9 100644 --- a/src/schemes/boundary/dummy_particles/dummy_particles.jl +++ b/src/schemes/boundary/dummy_particles/dummy_particles.jl @@ -291,7 +291,7 @@ function compute_pressure!(boundary_model, ::Union{SummationDensity, ContinuityD # Limit pressure to be non-negative to avoid attractive forces between fluid and # boundary particles at free surfaces (sticking artifacts). - @threaded for particle in eachparticle(system) + @threaded system for particle in eachparticle(system) apply_state_equation!(boundary_model, particle_density(v, boundary_model, particle), particle) end @@ -346,14 +346,15 @@ function compute_pressure!(boundary_model, ::AdamiPressureExtrapolation, system_coords, neighbor_coords, v_neighbor_system, nhs) end - for particle in eachparticle(system) + + @threaded system for particle in eachparticle(system) # Limit pressure to be non-negative to avoid attractive forces between fluid and # boundary particles at free surfaces (sticking artifacts). pressure[particle] = max(pressure[particle], 0.0) end end - @trixi_timeit timer() "inverse state equation" @threaded for particle in eachparticle(system) + @trixi_timeit timer() "inverse state equation" @threaded system for particle in eachparticle(system) compute_adami_density!(boundary_model, system, system_coords, particle) end end diff --git a/src/schemes/boundary/open_boundary/system.jl b/src/schemes/boundary/open_boundary/system.jl index 43c0a6ff1..b5950856e 100644 --- a/src/schemes/boundary/open_boundary/system.jl +++ b/src/schemes/boundary/open_boundary/system.jl @@ -171,7 +171,7 @@ end reference_velocity, reference_pressure, reference_density) = system # Update quantities based on the characteristic variables - @threaded for particle in each_moving_particle(system) + @threaded system for particle in each_moving_particle(system) particle_position = current_coords(u, system, particle) J1 = characteristics[1, particle] @@ -230,7 +230,7 @@ function evaluate_characteristics!(system, v, u, v_ode, u_ode, semi, t) # Thus, we compute the characteristics for the particles that are outside the influence # of fluid particles by using the average of the values of the previous time step. # See eq. 
27 in Negi (2020) https://doi.org/10.1016/j.cma.2020.113119 - @threaded for particle in each_moving_particle(system) + @threaded system for particle in each_moving_particle(system) # Particle is outside of the influence of fluid particles if isapprox(volume[particle], 0.0) diff --git a/src/schemes/boundary/system.jl b/src/schemes/boundary/system.jl index 73af6a843..f36d515ac 100644 --- a/src/schemes/boundary/system.jl +++ b/src/schemes/boundary/system.jl @@ -27,13 +27,13 @@ struct BoundarySPHSystem{BM, NDIMS, ELTYPE <: Real, IC, CO, M, IM, # This constructor is necessary for Adapt.jl to work with this struct. # See the comments in general/gpu.jl for more details. function BoundarySPHSystem(initial_condition, coordinates, boundary_model, movement, - ismoving, adhesion_coefficient, cache) + ismoving, adhesion_coefficient, cache, buffer) ELTYPE = eltype(coordinates) new{typeof(boundary_model), size(coordinates, 1), ELTYPE, typeof(initial_condition), typeof(coordinates), typeof(movement), typeof(ismoving), typeof(cache)}(initial_condition, coordinates, boundary_model, movement, - ismoving, adhesion_coefficient, cache, nothing) + ismoving, adhesion_coefficient, cache, buffer) end end @@ -54,7 +54,7 @@ function BoundarySPHSystem(initial_condition, model; movement=nothing, # Because of dispatches boundary model needs to be first! return BoundarySPHSystem(initial_condition, coordinates, model, movement, - ismoving, adhesion_coefficient, cache) + ismoving, adhesion_coefficient, cache, nothing) end """ @@ -209,7 +209,7 @@ function (movement::BoundaryMovement)(system, t) is_moving(t) || return system - @threaded for particle in moving_particles + @threaded system for particle in moving_particles pos_new = initial_coords(system, particle) + movement_function(t) vel = ForwardDiff.derivative(movement_function, t) acc = ForwardDiff.derivative(t_ -> ForwardDiff.derivative(movement_function, t_), t) diff --git a/src/schemes/fluid/weakly_compressible_sph/system.jl b/src/schemes/fluid/weakly_compressible_sph/system.jl index a8e28670b..44e71695b 100644 --- a/src/schemes/fluid/weakly_compressible_sph/system.jl +++ b/src/schemes/fluid/weakly_compressible_sph/system.jl @@ -311,7 +311,7 @@ function reinit_density!(system, v, u, v_ode, u_ode, semi) end function compute_pressure!(system, v) - @threaded for particle in eachparticle(system) + @threaded system for particle in eachparticle(system) apply_state_equation!(system, particle_density(v, system, particle), particle) end end diff --git a/src/schemes/solid/total_lagrangian_sph/system.jl b/src/schemes/solid/total_lagrangian_sph/system.jl index 9203425fb..125bb5507 100644 --- a/src/schemes/solid/total_lagrangian_sph/system.jl +++ b/src/schemes/solid/total_lagrangian_sph/system.jl @@ -70,55 +70,50 @@ struct TotalLagrangianSPHSystem{BM, NDIMS, ELTYPE <: Real, IC, ARRAY1D, ARRAY2D, penalty_force :: PF source_terms :: ST buffer :: Nothing +end - function TotalLagrangianSPHSystem(initial_condition, - smoothing_kernel, smoothing_length, - young_modulus, poisson_ratio; - n_fixed_particles=0, boundary_model=nothing, - acceleration=ntuple(_ -> 0.0, - ndims(smoothing_kernel)), - penalty_force=nothing, source_terms=nothing) - NDIMS = ndims(initial_condition) - ELTYPE = eltype(initial_condition) - n_particles = nparticles(initial_condition) - - if ndims(smoothing_kernel) != NDIMS - throw(ArgumentError("smoothing kernel dimensionality must be $NDIMS for a $(NDIMS)D problem")) - end +function TotalLagrangianSPHSystem(initial_condition, + smoothing_kernel, smoothing_length, + 
young_modulus, poisson_ratio; + n_fixed_particles=0, boundary_model=nothing, + acceleration=ntuple(_ -> 0.0, + ndims(smoothing_kernel)), + penalty_force=nothing, source_terms=nothing) + NDIMS = ndims(initial_condition) + ELTYPE = eltype(initial_condition) + n_particles = nparticles(initial_condition) - # Make acceleration an SVector - acceleration_ = SVector(acceleration...) - if length(acceleration_) != NDIMS - throw(ArgumentError("`acceleration` must be of length $NDIMS for a $(NDIMS)D problem")) - end + if ndims(smoothing_kernel) != NDIMS + throw(ArgumentError("smoothing kernel dimensionality must be $NDIMS for a $(NDIMS)D problem")) + end - initial_coordinates = copy(initial_condition.coordinates) - current_coordinates = copy(initial_condition.coordinates) - mass = copy(initial_condition.mass) - material_density = copy(initial_condition.density) - correction_matrix = Array{ELTYPE, 3}(undef, NDIMS, NDIMS, n_particles) - pk1_corrected = Array{ELTYPE, 3}(undef, NDIMS, NDIMS, n_particles) - deformation_grad = Array{ELTYPE, 3}(undef, NDIMS, NDIMS, n_particles) - - n_moving_particles = n_particles - n_fixed_particles - - lame_lambda = young_modulus * poisson_ratio / - ((1 + poisson_ratio) * (1 - 2 * poisson_ratio)) - lame_mu = 0.5 * young_modulus / (1 + poisson_ratio) - - return new{typeof(boundary_model), NDIMS, ELTYPE, - typeof(initial_condition), - typeof(mass), typeof(initial_coordinates), - typeof(deformation_grad), typeof(smoothing_kernel), - typeof(penalty_force), - typeof(source_terms)}(initial_condition, initial_coordinates, - current_coordinates, mass, correction_matrix, - pk1_corrected, deformation_grad, material_density, - n_moving_particles, young_modulus, poisson_ratio, - lame_lambda, lame_mu, smoothing_kernel, - smoothing_length, acceleration_, boundary_model, - penalty_force, source_terms, nothing) + # Make acceleration an SVector + acceleration_ = SVector(acceleration...) 
+    if length(acceleration_) != NDIMS
+        throw(ArgumentError("`acceleration` must be of length $NDIMS for a $(NDIMS)D problem"))
     end
+
+    initial_coordinates = copy(initial_condition.coordinates)
+    current_coordinates = copy(initial_condition.coordinates)
+    mass = copy(initial_condition.mass)
+    material_density = copy(initial_condition.density)
+    correction_matrix = Array{ELTYPE, 3}(undef, NDIMS, NDIMS, n_particles)
+    pk1_corrected = Array{ELTYPE, 3}(undef, NDIMS, NDIMS, n_particles)
+    deformation_grad = Array{ELTYPE, 3}(undef, NDIMS, NDIMS, n_particles)
+
+    n_moving_particles = n_particles - n_fixed_particles
+
+    lame_lambda = young_modulus * poisson_ratio /
+                  ((1 + poisson_ratio) * (1 - 2 * poisson_ratio))
+    lame_mu = 0.5 * young_modulus / (1 + poisson_ratio)
+
+    return TotalLagrangianSPHSystem(initial_condition, initial_coordinates,
+                                    current_coordinates, mass, correction_matrix,
+                                    pk1_corrected, deformation_grad, material_density,
+                                    n_moving_particles, young_modulus, poisson_ratio,
+                                    lame_lambda, lame_mu, smoothing_kernel,
+                                    smoothing_length, acceleration_, boundary_model,
+                                    penalty_force, source_terms, nothing)
 end
 
 function Base.show(io::IO, system::TotalLagrangianSPHSystem)
@@ -260,7 +255,7 @@ end
     calc_deformation_grad!(deformation_grad, neighborhood_search, system)
 
-    @threaded for particle in eachparticle(system)
+    @threaded system for particle in eachparticle(system)
         F_particle = deformation_gradient(system, particle)
         pk1_particle = pk1_stress_tensor(F_particle, system)
         pk1_particle_corrected = pk1_particle * correction_matrix(system, particle)
@@ -397,7 +392,7 @@ end
 function von_mises_stress(system::TotalLagrangianSPHSystem)
     von_mises_stress_vector = zeros(eltype(system.pk1_corrected), nparticles(system))
 
-    @threaded for particle in each_moving_particle(system)
+    @threaded system for particle in each_moving_particle(system)
         von_mises_stress_vector[particle] = von_mises_stress(system, particle)
     end
 
@@ -430,7 +425,7 @@ function cauchy_stress(system::TotalLagrangianSPHSystem)
     cauchy_stress_tensors = zeros(eltype(system.pk1_corrected), NDIMS, NDIMS,
                                   nparticles(system))
 
-    @threaded for particle in each_moving_particle(system)
+    @threaded system for particle in each_moving_particle(system)
         F = deformation_gradient(system, particle)
         J = det(F)
         P = pk1_corrected(system, particle)
diff --git a/src/util.jl b/src/util.jl
index 338a94415..172f0a800 100644
--- a/src/util.jl
+++ b/src/util.jl
@@ -27,57 +27,74 @@ function print_startup_message()
 end
 
 """
-    @threaded for ... end
+    @threaded system for ... end
 
+Run either a threaded CPU loop or launch a kernel on the GPU, depending on the type of `system`.
 Semantically the same as `Threads.@threads` when iterating over a `AbstractUnitRange`
 but without guarantee that the underlying implementation uses `Threads.@threads`
 or works for more general `for` loops.
-In particular, there may be an additional check whether only one thread is used
-to reduce the overhead of serial execution or the underlying threading capabilities
-might be provided by other packages such as [Polyester.jl](https://github.com/JuliaSIMD/Polyester.jl).
+The first argument must be either a particle system or an array, from which it can be derived
+whether the loop should run multithreaded on the CPU or be launched as a kernel on the GPU.
+
+In particular, the underlying threading capabilities might be provided by other packages
+such as [Polyester.jl](https://github.com/JuliaSIMD/Polyester.jl).
 
 !!! warn
     This macro does not necessarily work for general `for` loops.
For example, it does not necessarily support general iterables such as `eachline(filename)`. - -Some discussion can be found at -[https://discourse.julialang.org/t/overhead-of-threads-threads/53964](https://discourse.julialang.org/t/overhead-of-threads-threads/53964) -and -[https://discourse.julialang.org/t/threads-threads-with-one-thread-how-to-remove-the-overhead/58435](https://discourse.julialang.org/t/threads-threads-with-one-thread-how-to-remove-the-overhead/58435). - -Copied from [Trixi.jl](https://github.com/trixi-framework/Trixi.jl). """ -macro threaded(expr) - # Use `esc(quote ... end)` for nested macro calls as suggested in - # https://github.com/JuliaLang/julia/issues/23221 - # - # The following code is a simple version using only `Threads.@threads` from the - # standard library with an additional check whether only a single thread is used - # to reduce some overhead (and allocations) for serial execution. - # - # return esc(quote - # let - # if Threads.nthreads() == 1 - # $(expr) - # else - # Threads.@threads $(expr) - # end - # end - # end) - # - # However, the code below using `@batch` from Polyester.jl is more efficient, - # since this packages provides threads with less overhead. Since it is written - # by Chris Elrod, the author of LoopVectorization.jl, we expect this package - # to provide the most efficient and useful implementation of threads (as we use - # them) available in Julia. - # !!! danger "Heisenbug" - # Look at the comments for `wrap_array` when considering to change this macro. - +macro threaded(system, expr) + # Reverse-engineer the for loop. + # `expr.args[1]` is the head of the for loop, like `i = eachindex(x)`. + # So, `expr.args[1].args[2]` is the iterator `eachindex(x)` + # and `expr.args[1].args[1]` is the loop variable `i`. + iterator = expr.args[1].args[2] + i = expr.args[1].args[1] + inner_loop = expr.args[2] + + # Assemble the for loop again as a call to `parallel_foreach`, using `$i` to use the + # same loop variable as used in the for loop. return esc(quote - TrixiParticles.@batch $(expr) + TrixiParticles.parallel_foreach($iterator, $system) do $i + $inner_loop + end end) end +# Use `Polyester.@batch` for low-overhead threading +@inline function parallel_foreach(f, iterator, system) + Polyester.@batch for i in iterator + @inline f(i) + end +end + +# On GPUs, execute `f` inside a GPU kernel with KernelAbstractions.jl +@inline function parallel_foreach(f, iterator, system::Union{GPUSystem, AbstractGPUArray}) + # On the GPU, we can only loop over `1:N`. Therefore, we loop over `1:length(iterator)` + # and index with `iterator[eachindex(iterator)[i]]`. + # Note that this only works with vector-like iterators that support arbitrary indexing. + indices = eachindex(iterator) + ndrange = length(indices) + + # Skip empty loops + ndrange == 0 && return + + backend = KernelAbstractions.get_backend(system) + + # Call the generic kernel that is defined below, which only calls a function with + # the global GPU index. 
+ generic_kernel(backend)(ndrange=ndrange) do i + @inline f(iterator[indices[i]]) + end + + KernelAbstractions.synchronize(backend) +end + +@kernel function generic_kernel(f) + i = @index(Global) + @inline f(i) +end + @doc raw""" examples_dir() diff --git a/test/count_allocations.jl b/test/count_allocations.jl index 1c597f9a0..5957bc67a 100644 --- a/test/count_allocations.jl +++ b/test/count_allocations.jl @@ -16,14 +16,15 @@ function copy_semi_with_no_update_nhs(semi) neighborhood_searches) end -# Forward `for_particle_neighbor` to wrapped neighborhood search -@inline function TrixiParticles.for_particle_neighbor(f, system_coords, neighbor_coords, - neighborhood_search::NoUpdateNeighborhoodSearch; - particles=axes(system_coords, 2), - parallel=true) - TrixiParticles.for_particle_neighbor(f, system_coords, neighbor_coords, - neighborhood_search.nhs, - particles=particles, parallel=parallel) +# Forward `foreach_neighbor` to wrapped neighborhood search +@inline function TrixiParticles.PointNeighbors.foreach_neighbor(f, system_coords, + neighbor_coords, + neighborhood_search::NoUpdateNeighborhoodSearch, + particle; + search_radius=neighborhood_search.nhs.search_radius) + TrixiParticles.PointNeighbors.foreach_neighbor(f, system_coords, neighbor_coords, + neighborhood_search.nhs, particle, + search_radius=search_radius) end # No update diff --git a/test/examples/examples.jl b/test/examples/examples.jl index dbe03a130..19b6ba2b9 100644 --- a/test/examples/examples.jl +++ b/test/examples/examples.jl @@ -87,6 +87,19 @@ @test count_rhs_allocations(sol, semi) == 0 end + @trixi_testset "fluid/dam_break_2d.jl with KernelAbstractions.jl" begin + # Emulate the GPU code on the CPU by passing `data_type = Array` + @test_nowarn_mod trixi_include(@__MODULE__, + joinpath(examples_dir(), "fluid", + "dam_break_2d.jl"), tspan=(0.0, 0.1), + data_type=Array) [ + r"┌ Info: The desired tank length in y-direction .*\n", + r"└ New tank length in y-direction.*\n", + ] + @test sol.retcode == ReturnCode.Success + @test count_rhs_allocations(sol, semi) == 0 + end + @trixi_testset "fluid/dam_break_3d.jl" begin @test_nowarn_mod trixi_include(@__MODULE__, joinpath(examples_dir(), "fluid", diff --git a/test/general/semidiscretization.jl b/test/general/semidiscretization.jl index 4aff1f4fd..457b5b84d 100644 --- a/test/general/semidiscretization.jl +++ b/test/general/semidiscretization.jl @@ -1,9 +1,15 @@ # Use `@trixi_testset` to isolate the mock functions in a separate namespace @trixi_testset "Semidiscretization" begin - # Mock systems - struct System1 <: TrixiParticles.System{3, Nothing} end + # Mock systems. `System1` will use the CPU backend, `System2` is a `GPUSystem`, using + # the GPU backend (emulated on the CPU). 
+ struct System1 <: TrixiParticles.System{3, String} end struct System2 <: TrixiParticles.System{3, Nothing} end + # `System2` has no field `mass`, so we have to manually define the backend + function TrixiParticles.KernelAbstractions.get_backend(::System2) + return TrixiParticles.KernelAbstractions.CPU() + end + system1 = System1() system2 = System2() diff --git a/test/schemes/solid/total_lagrangian_sph/rhs.jl b/test/schemes/solid/total_lagrangian_sph/rhs.jl index dfc9fc234..b64fb6c97 100644 --- a/test/schemes/solid/total_lagrangian_sph/rhs.jl +++ b/test/schemes/solid/total_lagrangian_sph/rhs.jl @@ -32,7 +32,7 @@ [0.0, 0.0], ] - @testset "Test $i" for i in 1:4 + @testset verbose=true "Test $i" for i in 1:4 #### Setup each_moving_particle = [particle[i]] # Only calculate dv for this one particle eachparticle = [particle[i], neighbor[i]] @@ -53,18 +53,27 @@ kernel_deriv = 1.0 #### Mocking - # Mock the system - system = Val{:mock_system_interact}() - TrixiParticles.ndims(::Val{:mock_system_interact}) = 2 - Base.ntuple(f, ::Symbol) = ntuple(f, 2) # Make `extract_svector` work + # Mock a CPU system to test CPU code + struct MockSystemInteractCPU <: TrixiParticles.System{2, String} end + system = MockSystemInteractCPU() - function TrixiParticles.initial_coordinates(::Val{:mock_system_interact}) + # Mock a GPU system to emulate GPU code on the CPU + struct MockSystemInteractGPU <: TrixiParticles.System{2, Nothing} end + system_gpu = MockSystemInteractGPU() + + function TrixiParticles.KernelAbstractions.get_backend(::MockSystemInteractGPU) + return TrixiParticles.KernelAbstractions.CPU() + end + + MockSystemType = Union{MockSystemInteractCPU, MockSystemInteractGPU} + + function TrixiParticles.initial_coordinates(::MockSystemType) return initial_coordinates end # Unpack calls should return predefined values or # another mock object of the type Val{:mock_property_name}. 
-        function Base.getproperty(::Val{:mock_system_interact}, f::Symbol)
+        function Base.getproperty(::MockSystemType, f::Symbol)
             if f === :current_coordinates
                 return current_coordinates
             elseif f === :material_density
@@ -81,10 +90,13 @@ return Val(Symbol("mock_" * string(f)))
         end
 
-        function TrixiParticles.each_moving_particle(::Val{:mock_system_interact})
+        function TrixiParticles.each_moving_particle(::MockSystemType)
             each_moving_particle
         end
-        TrixiParticles.eachparticle(::Val{:mock_system_interact}) = eachparticle
+        TrixiParticles.eachparticle(::MockSystemType) = eachparticle
+
+        # Mock the neighborhood search
+        nhs = Val{:nhs}()
 
         TrixiParticles.PointNeighbors.eachneighbor(_, ::Val{:nhs}) = eachneighbor
 
         function Base.getproperty(::Val{:nhs}, f::Symbol)
@@ -99,26 +111,25 @@ end
        TrixiParticles.ndims(::Val{:nhs}) = 2
 
-        function TrixiParticles.pk1_corrected(::Val{:mock_system_dv}, particle_)
-            if particle_ == particle[i]
-                return pk1_particle_corrected[i]
-            end
-            return pk1_neighbor_corrected[i]
-        end
-
-        function TrixiParticles.add_acceleration!(_, _, ::Val{:mock_system_interact})
+        function TrixiParticles.add_acceleration!(_, _, ::MockSystemType)
             nothing
         end
        TrixiParticles.kernel_deriv(::Val{:mock_smoothing_kernel}, _, _) = kernel_deriv
 
        #### Verification
-        dv = zeros(ndims(system), 10)
-        dv_expected = copy(dv)
-        dv_expected[:, particle[i]] = dv_particle_expected[i]
+        systems = [system, system_gpu]
+        names = ["CPU code", "Emulate GPU"]
+        @testset "$(names[j])" for j in eachindex(names)
+            system_ = systems[j]
+
+            dv = zeros(ndims(system_), 10)
+            dv_expected = copy(dv)
+            dv_expected[:, particle[i]] = dv_particle_expected[i]
 
-        TrixiParticles.interact_solid_solid!(dv, Val(:nhs), system, system)
+            TrixiParticles.interact_solid_solid!(dv, nhs, system_, system_)
 
-        @test dv ≈ dv_expected
+            @test dv ≈ dv_expected
+        end
     end
 end
@@ -140,7 +151,7 @@
                                 10 / 1000^2 * 1.5400218087591082 * 324.67072684047224 * 1.224,
                                 0.0,
                             ])
-    @testset "Deformation Function: $deformation" for deformation in keys(deformations)
+    @testset verbose=true "Deformation Function: $deformation" for deformation in keys(deformations)
         J = deformations[deformation]
         u = zeros(2, 81)
         v = zeros(2, 81)
@@ -176,22 +187,50 @@ semi = Semidiscretization(system)
         tspan = (0.0, 1.0)
-        semidiscretize(semi, tspan)
-        # Apply the deformation matrix
-        for particle in axes(u, 2)
-            # Apply deformation
-            u[1:2, particle] = deformations[deformation](coordinates[:, particle])
+        # Define this constructor to make `PtrArray` usable as a `data_type` below
+        function TrixiParticles.PtrArray{Float64}(::UndefInitializer, length)
+            TrixiParticles.PtrArray(zeros(length))
         end
 
-        #### Verification for the particle in the middle
-        particle = 41
+        # We can pass the data type `Array` to convert all systems to `GPUSystem`s
+        # and emulate the GPU kernels on the CPU.
+        # But this doesn't test `wrap_v` and `wrap_u` for non-`Array` types.
+        # In order to test this as well, we need a different data type, so we also
+        # pass `PtrArray`.
+ names = ["CPU code", "GPU code with CPU wrapping", "GPU code with GPU wrapping"] + data_types = [nothing, Array, TrixiParticles.PtrArray] + @testset "$(names[i])" for i in eachindex(names) + data_type = data_types[i] + ode = semidiscretize(semi, tspan, data_type=data_type) + + # Apply the deformation matrix + for particle in axes(u, 2) + # Apply deformation + u[1:2, particle] = deformations[deformation](coordinates[:, particle]) + end + + v_ode = ode.u0.x[1] + if isnothing(data_type) + u_ode = vec(u) + else + u_ode = data_type(vec(u)) + end + + @test typeof(v_ode) == typeof(u_ode) + @test length(v_ode) == length(u_ode) + + #### Verification for the particle in the middle + particle = 41 - dv = zeros(ndims(system), 81) - TrixiParticles.kick!(dv, v, u, semi, 0.0) + dv_ode = zero(v_ode) + TrixiParticles.kick!(dv_ode, v_ode, u_ode, ode.p, 0.0) - @test isapprox(dv[:, particle], dv_expected_41[deformation], - rtol=sqrt(eps()), atol=sqrt(eps())) + dv = TrixiParticles.wrap_v(dv_ode, system, semi) + + @test isapprox(dv[:, particle], dv_expected_41[deformation], + rtol=sqrt(eps()), atol=sqrt(eps())) + end end end -end +end;
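
The heart of the change is the dispatch behind `@threaded system for ...` in `src/util.jl`: the macro rewrites the loop into a call to `parallel_foreach`, which uses `Polyester.@batch` on the CPU and launches a KernelAbstractions.jl kernel for a `GPUSystem`. The following standalone Julia sketch mirrors that mechanism outside of TrixiParticles so it can be tried in isolation. The name `parallel_foreach` and the `generic_kernel` pattern are taken from the patch; dispatching on the KernelAbstractions.jl backend instead of on the system type is a simplification made here purely for illustration.

using Polyester: Polyester
using KernelAbstractions: KernelAbstractions, @kernel, @index

# Kernel that only forwards the global index to `f` (same pattern as `src/util.jl`)
@kernel function generic_kernel(f)
    i = @index(Global)
    @inline f(i)
end

# CPU path: low-overhead threading with Polyester.jl
@inline function parallel_foreach(f, iterator, ::KernelAbstractions.CPU)
    Polyester.@batch for i in iterator
        @inline f(i)
    end
end

# GPU path (any other KernelAbstractions.jl backend): loop over `1:length(iterator)`
# inside a kernel, index into `iterator`, and synchronize afterwards
@inline function parallel_foreach(f, iterator, backend)
    indices = eachindex(iterator)
    ndrange = length(indices)

    # Skip empty loops
    ndrange == 0 && return

    generic_kernel(backend)(ndrange=ndrange) do i
        @inline f(iterator[indices[i]])
    end

    KernelAbstractions.synchronize(backend)
end

# Usage: the array type selects the backend, analogous to how the system type
# selects the code path for `@threaded system for ...`
x = rand(10^5)
y = similar(x)
parallel_foreach(eachindex(x), KernelAbstractions.get_backend(x)) do i
    y[i] = 2 * x[i]
end

With a GPU array (for example a `CuArray` from CUDA.jl), `KernelAbstractions.get_backend` returns the corresponding GPU backend and the same call runs as a kernel, which is the path taken by `for_particle_neighbor` for `GPUSystem`s in `src/general/neighborhood_search.jl`.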
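On the user side, the array type is exposed through `semidiscretize(semi, tspan, data_type=...)`, and the new example test emulates the GPU code path on the CPU by overriding `data_type` via `trixi_include`. The first call below mirrors that test; the commented-out variant with `CuArray` is an assumption based on the `CuArray` remark in `wrap_array` and is not part of this patch, so treat it as an untested sketch that additionally requires CUDA.jl and a compatible GPU.

using TrixiParticles

# Emulate the GPU code path on the CPU, as in `test/examples/examples.jl`
trixi_include(joinpath(examples_dir(), "fluid", "dam_break_2d.jl"),
              tspan=(0.0, 0.1), data_type=Array)

# Hypothetical GPU run: pass a GPU array type instead (assumes CUDA.jl is installed)
# using CUDA
# trixi_include(joinpath(examples_dir(), "fluid", "dam_break_2d.jl"),
#               tspan=(0.0, 0.1), data_type=CuArray)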