De-threadid-ing parallel loops #588
Comments
Thank you very much! I didn't think of the pattern of handling the range splitting manually and then using per-chunk rather than per-thread buffers. I am not sure about the performance implications of re-allocating nthreads() arrays in the inner loop; I'd rather avoid it if possible. I think doing the manual transform for that single very performance-critical part and using FLoops everywhere else is a good compromise.
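For concreteness, here is a minimal sketch of the per-chunk-buffer pattern being referred to; the names `xs`, `f!`, and `allocate_buffer` are placeholders, not DFTK code:

```julia
# Minimal sketch: one buffer is allocated per chunk (hence per task), so each
# run of this loop allocates on the order of nthreads() buffers.
allocate_buffer() = zeros(16, 16, 16)   # stand-in for the real scratch array
xs = 1:100                              # stand-in for the real work items
f!(buf, x) = (buf .= x)                 # stand-in for the real per-element work

n = cld(length(xs), Threads.nthreads())
@sync for chunk in Iterators.partition(xs, n)
    Threads.@spawn begin
        buf = allocate_buffer()         # per-chunk, not per-thread
        for x in chunk
            f!(buf, x)                  # buf is private to this task: no data race
        end
    end
end
```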
Also, do I understand correctly that […] is equivalent to […]?
Oops, sorry, the […]. It is equivalent to

```julia
chunks = collect(enumerate(Iterators.partition(xs, n)))
@threads for (i, chunk) in chunks
    for x in chunk
        f(x, i)
    end
end
```

if […].
Thanks, I updated it in the discourse. One thing that confused me is that […].
So in our application we actually have two nested loops; this is currently in a PR but the final version will look like

```julia
T = promote_type(eltype(basis), real(eltype(ψ[1])))
ρ = similar(ψ[1], T, (basis.fft_size..., basis.model.n_spin_components))
ρ .= 0
ψnk_real = zeros(complex(T), basis.fft_size)
for ik = 1:length(basis.kpoints)
    kpt = basis.kpoints[ik]
    for n = 1:size(ψ[ik], 2)
        ψnk = @views ψ[ik][:, n]
        G_to_r!(ψnk_real, basis, kpt, ψnk)
        ρ[:, :, :, kpt.spin] .+= occupation[ik][n] .* basis.kweights[ik] .* abs2.(ψnk_real)
    end
end
```

Right now we don't really parallelize nested loops (we only parallelize the inner one and use MPI on the outer one), but it'd be nice to potentially parallelize at both levels. Of course we can flatten the nested loop into a flat one and apply the techniques you showed above, but that doesn't compose well (this is a simple example, but in other places we have this kind of nested structure with more involved code). I guess there's not really any good way to get composability here without allocating nthreads()^2 arrays... (The inner loop can also be BLAS/FFTW calls, but that's for later.)
FWIW, FLoops.jl supports nested loops like

```julia
@floop ThreadedEx(nestlevel = Val(2), basesize = bs) for ik = 1:length(basis.kpoints), n = 1:size(ψ[ik], 2)
    kpt = basis.kpoints[ik]
    ψnk = @views ψ[ik][:, n]
    # It's unclear if the following statements are data-race free:
    G_to_r!(ψnk_real, basis, kpt, ψnk)
    ρ[:, :, :, kpt.spin] .+= occupation[ik][n] .* basis.kweights[ik] .* abs2.(ψnk_real)
end
```

where […]. Alternatively, you can also do

```julia
@floop ThreadedEx(basesize = bs1) for ik = 1:length(basis.kpoints)
    kpt = basis.kpoints[ik]
    @floop ThreadedEx(basesize = bs2) for n = 1:size(ψ[ik], 2)
        ...
    end
end
```

with some appropriate choices of `bs1` and `bs2`. But I don't know exactly how […].
What I meant by composable is something like

```julia
for k = 1:Nk
    diagonalize(x -> do_stuff(x, k))
end

function diagonalize(A)
    # solve A(x) = lambda x iteratively
end

function do_stuff(x, k)
    for n = 1:N
        ...
    end
end
```

Then if we were to do the above spawn trick or FLoops on each of the for loops, we'd have to allocate nthreads()^2 temporaries, or do I misunderstand? Whereas currently with the threadid() approach we just allocate nthreads() temporaries as global variables.
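To make the nthreads()^2 concern concrete, here is a purely illustrative sketch (all names and sizes are placeholders, not DFTK code) of chunk-local buffers at both nesting levels:

```julia
# Chunking both loop levels by nthreads() means up to nthreads()^2 tasks can be
# live at once, each holding its own freshly allocated buffer.
outer_chunks      = Iterators.partition(1:8, 2)   # placeholder outer work
inner_chunks(k)   = Iterators.partition(1:8, 2)   # placeholder inner work
allocate_buffer() = zeros(16, 16, 16)             # placeholder scratch array

@sync for outer_chunk in outer_chunks
    Threads.@spawn for k in outer_chunk
        @sync for inner_chunk in inner_chunks(k)
            Threads.@spawn begin
                buf = allocate_buffer()           # one buffer per inner chunk
                for n in inner_chunk
                    buf .= k + n                  # stand-in for the real (k, n) work
                end
            end
        end
    end
end
```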
For my future reference, a single 16x16x16 FFT is 23 microseconds. Even on very coarse examples we are much above this (silicon at Ecut 15 is 27), so overhead is not an issue (that might not be the case elsewhere).
We did do some profiling before optimizing, and I think I recall the preallocations turning out to be important (of course we'd love to get rid of them). That said, it was two years ago; things might have improved since then, and we should profile again. I also saw reports that allocations were very slow in a multithreaded context. In the nthreads()^2 case, though, I'd be more concerned about blowing up the memory (the temporary ψnk_real is much bigger than the data it operates on, and storing nthreads()^2 of them is probably too much).
Cool, looking forward to trying it out!
If you want to pre-allocate arrays and disallow overlapping uses in nested parallelism, you can use a channel to manage buffers:

```julia
nbuffers = Threads.nthreads()  # it doesn't have to be nthreads()
buffers = Channel{???}(nbuffers)
for _ in 1:nbuffers
    put!(buffers, allocate_buffer())
end
```

then, using the idiom in the OP:

```julia
n = cld(length(xs), Threads.nthreads())
@sync for (i, chunk) in enumerate(Iterators.partition(xs, n))
    Threads.@spawn begin
        b = take!(buffers)
        try
            for x in chunk
                f!(x, b)
            end
        finally
            put!(buffers, b)
        end
    end
end
```

If […].
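In the DFTK setting, the `???` element type and `allocate_buffer` could, for example, be concretized as below, assuming the same `T` and `basis` as in the density loop quoted earlier; this is just one possible choice, not something from the thread:

```julia
# Hypothetical concretization: fft-sized complex scratch arrays.
nbuffers = Threads.nthreads()
buffers = Channel{Array{complex(T), 3}}(nbuffers)
for _ in 1:nbuffers
    put!(buffers, zeros(complex(T), basis.fft_size))
end
```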
Thanks! That kind of low-level primitive always scares me, but perhaps there's no real reason for that, especially because it would free us to use higher-level primitives (like FLoops) elsewhere. That looks like a good way to go!
The […]
Closing this for now: we've implemented the manual chunking with @spawn; if we need more parallelism, we'll do the Channel thing. Thank you very much @tkf!
Continuing https://discourse.julialang.org/t/behavior-of-threads-threads-for-loop/76042/20, here's how I'd implement the code you quoted (IIUC).
Mechanical transformation that almost always works
First of all, here's a trick you can use almost always. If you have the usual pattern of a `@threads` loop whose body indexes pre-allocated buffers by `threadid()`, you can mechanically convert it to a manually chunked `@spawn` loop, as sketched below.
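A minimal, self-contained sketch of the two forms (the names `xs`, `f`, `buffers`, and `allocate_buffer` are placeholders, not code from the linked thread):

```julia
# Placeholders standing in for the real work and scratch arrays:
allocate_buffer() = zeros(16, 16, 16)
buffers = [allocate_buffer() for _ in 1:Threads.nthreads()]
xs = 1:100
f(x, i) = (buffers[i] .= x)           # f uses buffers[i] as scratch space

# threadid()-indexed version:
Threads.@threads for x in xs
    f(x, Threads.threadid())          # relies on the task staying on one thread
end

# mechanically converted version (the chunked @spawn idiom used elsewhere in the thread):
n = cld(length(xs), Threads.nthreads())
@sync for (i, chunk) in enumerate(Iterators.partition(xs, n))
    Threads.@spawn for x in chunk
        f(x, i)                       # i identifies the chunk, not the thread
    end
end
```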
This is very likely correct if the loop body `f` only uses `threadid()` with arrays allocated only for this parallel loop (e.g., the pre-1.3 reduction pattern).

Array reduction with `@reduce acc .+= x`
I think DFTK.jl/src/densities.jl (lines 24 to 36 at a1791f6) can be re-written using `@floop` with this kind of `@reduce` array reduction.
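A minimal sketch of what such a rewrite could look like, assuming the same variables (`T`, `basis`, `ψ`, `occupation`) as in the density loop quoted earlier; this is a reconstruction for illustration, not the exact code from the comment, and the per-iteration `dρ` array is my own simplification:

```julia
using FLoops

@floop ThreadedEx(nestlevel = Val(2)) for ik = 1:length(basis.kpoints), n = 1:size(ψ[ik], 2)
    kpt = basis.kpoints[ik]
    ψnk = @views ψ[ik][:, n]
    @init ψnk_real = zeros(complex(T), basis.fft_size)   # scratch array, private per basecase
    G_to_r!(ψnk_real, basis, kpt, ψnk)
    # Contribution of this (ik, n) pair, placed in its spin channel:
    dρ = zeros(T, (basis.fft_size..., basis.model.n_spin_components))
    dρ[:, :, :, kpt.spin] .= occupation[ik][n] .* basis.kweights[ik] .* abs2.(ψnk_real)
    @reduce ρ .+= dρ   # ρ is produced by the loop; I believe the explicit ρ .= 0 is then unnecessary
end
```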
This requires FLoops 0.1.12 or above.
Pre-allocated scratch space and TimerOutput
DFTK.jl/src/terms/Hamiltonian.jl, lines 124 to 148 at a1791f6
If you are OK with allocating about `nthreads()` arrays every time this code executes, `ψ_real = H.scratch.ψ_reals[tid]` can simply be re-written as `@init ψ_real = similar(H.scratch.ψ_reals[1])` or something equivalent on the RHS. If you must reuse `H.scratch.ψ_reals`, the easiest approach is probably to use the mechanical transformation I noted above.

I am not sure why you are throwing away the timer info on the non-primary thread, but, if you want to merge all of them, you can do
[…] in the loop body and then […] outside to merge all the timer outputs, provided that `merge!` on `TimerOutput` acts like the method on `Dict`.
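For instance, a hedged sketch of one way to do this merging (not the exact code from the comment; `xs` and `f` are placeholders, and it assumes `merge!` on `TimerOutput` behaves like `merge!` on `Dict`):

```julia
using TimerOutputs

xs = 1:100
f(x, i) = sum(abs2, 1:x)                  # stand-in for the real per-element work

n = cld(length(xs), Threads.nthreads())
chunks = collect(enumerate(Iterators.partition(xs, n)))
timers = [TimerOutput() for _ in chunks]  # one private timer per chunk/task
@sync for (i, chunk) in chunks
    Threads.@spawn for x in chunk
        @timeit timers[i] "work" f(x, i)
    end
end

to = TimerOutput()
for t in timers
    merge!(to, t)                         # assumed to combine like merge! on Dict
end
print_timer(to)
```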