diff --git a/Project.toml b/Project.toml index 80ecb93c4..5e6f13b7e 100644 --- a/Project.toml +++ b/Project.toml @@ -34,8 +34,9 @@ LoadAllPackages = "b37bcd2d-1570-475d-a8c6-9b4fae6d0ba9" MicroCollections = "128add7d-3638-4c79-886c-908ea0c25c34" PerformanceTestTools = "dc46b164-d16f-48ec-a853-60448fc869fe" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestFunctionRunner = "792026f5-ac9a-4a19-adcb-47b0ce2deb5d" [targets] -test = ["Distributed", "Documenter", "Future", "LiterateTest", "LoadAllPackages", "MicroCollections", "PerformanceTestTools", "Random", "Test", "TestFunctionRunner"] +test = ["Distributed", "Documenter", "Future", "LiterateTest", "LoadAllPackages", "MicroCollections", "PerformanceTestTools", "Random", "StaticArrays", "Test", "TestFunctionRunner"] diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 3ec5e6066..81b83cbfa 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -296,6 +296,12 @@ git-tree-sha1 = "39c9f91521de844bad65049efd4f9223e7ed43f9" uuid = "171d559e-b47b-412a-8079-5efa626c420e" version = "0.1.14" +[[StaticArrays]] +deps = ["LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "2884859916598f974858ff01df7dfc6c708dd895" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.3.3" + [[Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" diff --git a/docs/Project.toml b/docs/Project.toml index 1d6f4fea5..795bac979 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -8,3 +8,4 @@ LiterateTest = "d77d25b0-90d3-4a16-b10a-412a9d48f625" LoadAllPackages = "b37bcd2d-1570-475d-a8c6-9b4fae6d0ba9" MicroCollections = "128add7d-3638-4c79-886c-908ea0c25c34" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" diff --git a/docs/src/reference/api.md b/docs/src/reference/api.md index aca97df20..6acbfc63d 100644 --- a/docs/src/reference/api.md +++ b/docs/src/reference/api.md @@ -12,12 +12,24 @@ FLoops.@floop FLoops.@reduce ``` +## `@combine` + +```@docs +FLoops.@combine +``` + ## `@init` ```@docs FLoops.@init ``` +## `@completebasecase` + +```@docs +FLoops.@completebasecase +``` + ## [`SequentialEx`, `ThreadedEx` and `DistributedEx` executors](@id executor) An *executor* controls how a given `@floop` is executed. FLoops.jl re-exports diff --git a/docs/src/tutorials/parallel.md b/docs/src/tutorials/parallel.md index d92d30d25..bb84b1d39 100644 --- a/docs/src/tutorials/parallel.md +++ b/docs/src/tutorials/parallel.md @@ -1,17 +1,23 @@ # [Parallel loops](@id tutorials-parallel) `@floop` supports parallel loops not only for side-effect (as in -`Threads.@threads`) but also for complex reductions using the optional -`@reduce` syntax. +`Threads.@threads`) but also for complex reductions using the `@combine` and +`@reduce` macros. + +If you already know how `mapreduce` works, [Relation to `mapreduce`](@ref +floop-and-mapreduce) may be the best first step for understanding the `@floop` +syntax. + +```@contents +Pages = ["parallel.md"] +Depth = 3 +``` + +!!! note + This tutorial can be read without reading the subsections with "Advanced:" + prefix. -`@floop` is useful even without `@reduce` because it supports multiple -[executors](@ref tutorials-executor) for selecting specific execution -mechanisms without rewriting your code. 
For example, -[FoldsThreads.jl](https://github.com/JuliaFolds/FoldsThreads.jl) provides -additional rich set of thread-based executors from which you can choose -an appropriate executor to maximize the performance of your program. -[FoldsCUDA.jl](https://github.com/JuliaFolds/FoldsCUDA.jl) provides an -executor for GPU. FLoops.jl also provide a simple distributed executor. +## Independent execution For in-place update operations (i.e., `Threads.@threads`-like operations), you can use `@floop ThreadedEx() for`: @@ -33,8 +39,10 @@ julia> floop_map!(x -> x + 1, zeros(3), 1:3) 4.0 ``` -For a parallel algorithm that requires reductions, you can use -`@reduce(acc op= x)` syntax: +## Reduction using `@reduce acc ⊗= x` syntax + +For a parallel algorithm that requires reductions, you can use `@reduce acc ⊗= +x` syntax: ```jldoctest julia> using FLoops @@ -42,18 +50,267 @@ julia> using FLoops julia> @floop for (x, y) in zip(1:3, 1:2:6) a = x + y b = x - y - @reduce(s += a, t += b) + @reduce s += a + @reduce t += b end (s, t) (15, -3) ``` -With `@reduce`, the default executor is `ThreadedEx`. +## Combining explicit sequential reduction results using `@combine` + +FLoops.jl parallelizes a given loop by dividing the iteration space into +*basecases* and then execute the serial reduction on each basecase. These +sub-results are combined using the function specified by `@combine` or +`@reduce` syntax. + +!!! note + + Exactly how the executor schedules the basecases and the computation for + combining them depends on the type (e.g., threads/GPU/distributed) and the + scheduling options. However, the loop using `@floop` works with all of them + provided that `@combine` and `@reduce` define associative function. + +```jldoctest +julia> using FLoops + +julia> pidigits = string(BigFloat(π; precision = 2^20))[3:end]; + +julia> @floop begin + @init hist = zeros(Int, 10) # (1) initialization + for char in pidigits # (2) basecase + n = char - '0' + hist[n+1] += 1 + end + @combine hist .= hist .+ _ # (3) combine + # Or, use a short hand notation: + # @combine hist .= _ + end + hist +10-element Vector{Int64}: + 31559 + 31597 + 31392 + 31712 + 31407 + 31835 + 31530 + 31807 + 31469 + 31345 +``` + +!!! note + + Above example uses string to show that FLoops.jl (and also other JuliaFolds + packages) support strings. But this is of course not a very good format for + playing with the digits of pi. + +_Conceptually_, this produces a program that acts like (but is more optimized +than) the following code: + +```julia +# `chunks` is prepared such that: +@assert pidigits == reduce(vcat, chunks) +# i.e., pidigits == [chunks[1]; chunks[2]; ...; chunks[end]] + +hists = Vector{Any}(undef, length(chunks)) +@sync for (i, pidigitsᵢ) in enumerate(chunks) + @spawn begin + local hist = zeros(Int, 10) # (1) initialization + for char in pidigitsᵢ # (2) basecase + n = char - '0' + hist[n+1] += 1 + end + hists[i] = hist # "sub-solution" of this basecase + end +end +hist = hists[1] +for hist′ in hists[2:end] + hist .= hist .+ hist′ # (3) combine the sub-solutions +end +``` + +(1) The basecase-local accumulators are initialized using the [`@init`](@ref) +statements. + +(2) Each basecase loop is executed with its own local accumulators. + +(3) The sub-solutions `hists` are combined using the expression specified by +`@combine`. 
In the above pseudo code, given the expression `hist .= hist .+ _`
+(or equivalently `hist .+= _`), the symbol `hist` is substituted by the
+sub-solution `hist` of the first basecase and the symbol `_` is substituted by
+the sub-solution `hist` of the second basecase. Evaluation of this expression
+produces a sub-solution `hist` combining the first and the second basecases.
+The sub-solutions of the third and later basecases are combined into `hist`
+using the same procedure.
+
+In general, the expression
+
+```julia
+@combine acc = op(acc, _)
+```
+
+indicates that a sub-solution `acc` computed for a certain subset of the input
+collection (e.g., `pidigits` in the example) is combined with the sub-solution
+`acc_right` using
-## Initialization with `@reduce(acc = op(init, x))` syntax
+```julia
+acc = op(acc, acc_right)
+```
+
+The binary function/operator `op` must be
+[associative](https://en.wikipedia.org/wiki/Associative_property). However,
+`op` does not have to be side-effect-free. In fact, if invoking an in-place
+`op` on the sub-solutions does not cause thread-safety issues, there is no
+problem in using in-place mutation. For example, the above usage of `@combine
+hist .= hist .+ _` is correct because `hist` is created for each basecase; i.e.,
+no combine step can mutate the vector `hist` while another combine step tries to
+read from or write to the same vector.
+
+!!! warning
+    All three pieces of the above `@floop begin ... end` code (i.e., (1) `@init
+    ...`, (2) `for`-loop body, and (3) `@combine ...`) _may_ (and likely will)
+    be executed concurrently. Thus, **they must be written in such a way that
+    concurrent execution in an _arbitrary number_ of tasks is correct** (e.g.,
+    no data race is possible). In particular, the above pseudo code is
+    inaccurate in that it executes the `@combine` expression serially. Such
+    serial execution is typically not guaranteed by the
+    [executor](@ref tutorials-executor) provided by JuliaFolds.
+
+!!! note
+    The combine steps of the above pseudo code are different from how most of
+    the executors in JuliaFolds execute `@floop` loops. Typically, the combine
+    steps are executed in parallel; i.e., they are organized in a more
+    tree-like fashion to provide a greater amount of
+    [_parallelism_](https://www.cprogramming.com/parallelism.html).
+
+Only the variables available "after" the `for` loop (but not the variables local
+to the loop body) can be used as the arguments to `@combine`. Typically, this
+means the symbols specified by `@init`. However, it is possible to introduce
+new variables for `@combine` by placing the code that introduces them after the
+`for` loop (see [Executing code at the end of basecase](@ref
+simple-completebasecase)). Note also that `@init`'ed variables do not have to
+be `@combine`d. For example, `@init` can be used for allocating a local buffer
+for intermediate computation (see [Local buffers using `@init`](@ref
+local-buffer)).
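+
+As a small runnable illustration of the associativity requirement above (a
+hypothetical sketch with made-up sub-solution values, not tied to any
+particular executor), note that `@combine acc = op(acc, _)` only specifies
+*what* a combine step computes, not the grouping in which the executor
+applies `op`:
+
+```julia
+op = +                                    # an associative combine function
+acc1, acc2, acc3 = 10, 20, 30             # hypothetical basecase sub-solutions
+left_to_right = op(op(acc1, acc2), acc3)  # purely sequential combining
+tree_like = op(acc1, op(acc2, acc3))      # grouping a parallel executor may use
+@assert left_to_right == tree_like == 60  # associativity makes them agree
+```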
+ +## Advanced: Understanding `@combine` in terms of `mapreduce` + +Alternatively, a more concise way to understand `@floop` and `@combine` is to +conceptualized it as a lowering to a call to `mapreduce`: + +```julia +function basecase(pidigitsᵢ) + local hist = zeros(Int, 10) # (1) initialization + for char in pidigitsᵢ # (2) basecase + n = char - '0' + hist[n+1] += 1 + end + return hist +end + +function combine!(hist, hist′) + hist .= hist .+ hist′ # (3) combine the sub-solutions + return hist +end + +hist = mapreduce(basecase, combine!, chunks) +``` + +where `mapreduce` is a parallel implementation of `Base.mapreduce` (e.g., +`Folds.mapreduce`). Although this picture still does not reflect the actual +internal of FLoops.jl (and Transducers.jl), this is a much more accurate mental +model than the pseudo code above. + +## Advanced: Unifying sequential and cross-basecase reductions -Use `acc = op(init, x)` to specify that the identity element for the -binary function `op` is `init`: +To accumulate numbers into a vector, we can use `push!` in the basecase and +combine the vectors from different basecases using `append!`. + +```jldoctest +julia> using FLoops + +julia> @floop begin + @init odds = Int[] + @init evens = Int[] + for x in 1:5 + if isodd(x) + push!(odds, x) + else + push!(evens, x) + end + end + @combine odds = append!(odds, _) + @combine evens = append!(evens, _) + end + (odds, evens) +([1, 3, 5], [2, 4]) +``` + +Although this code works without an issue, it is redundant to use `push!` and +`append!` in this example. Since `push!(xs, x)` and `append!(xs, [x])` are +equivalent, these functions are quite similar. The intermediate value `[x]` is +referred to as a *singleton solution* because it is the value that would be used +if the input collection to the `for` loop contain only one item. + +Indeed, once we have the singleton solution, we can simplify the above code by +using the syntax + + @reduce acc = op(init, input) + +The expression `init` in the first argument position specifies how to initialize +the reduction result `acc`. The expression `input` specifies the value defined +in the loop body which is accumulated into the reduction result `acc`. The +current accumulation state `acc` is updated by + + acc = op(acc, input) + +Using this notation, the above code can be simplified to + +```jldoctest +julia> using FLoops + +julia> @floop for x in 1:5 + ys = [x] # "singleton solution" + if isodd(x) + @reduce odds = append!(Int[], ys) + else + @reduce evens = append!(Int[], ys) + end + end + (odds, evens) +([1, 3, 5], [2, 4]) +``` + +```jldoctest +julia> let + odds = Int[] # \___ The expression in the first argument is + evens = Int[] # / used for the initialization + for x in 1:5 + ys = [x] + if isodd(x) + odds = append!(odds, ys) + # ----- + # LHS `odds` inserted to the first argument + else + evens = append!(evens, ys) + # ----- + # LHS `evens` inserted to the first argument + end + end + (odds, evens) + end +([1, 3, 5], [2, 4]) +``` + +### Handling unknown element types + +In the above code, we assumed that we know the type of the elements that are +accumulated into a vector. However, when writing generic code, it is often +impossible to know the element types in advance. We can use BangBang.jl and +MicroCollections.jl to create a vector of items with unknown types in such a way +that the compiler can optimize very well. 
```jldoctest julia> using FLoops @@ -65,16 +322,16 @@ julia> using MicroCollections # for `EmptyVector` and `SingletonVector` julia> @floop for x in 1:5 ys = SingletonVector(x) if isodd(x) - @reduce(odds = append!!(EmptyVector(), ys)) + @reduce odds = append!!(EmptyVector(), ys) else - @reduce(evens = append!!(EmptyVector(), ys)) + @reduce evens = append!!(EmptyVector(), ys) end end (odds, evens) ([1, 3, 5], [2, 4]) ``` -## Initialization with `@reduce(acc = init op x)` syntax +### Initialization with `@reduce(acc = init op x)` syntax When `op` is a binary operator, the infix syntax `acc = init op x` can also be used: @@ -99,7 +356,90 @@ first argument is replaced by the corresponding LHS, i.e., `odds = append!!(odds, ys)` and `s = s + a`, are evaluated for the bulk of the loop. -## Complex reduction with `@reduce() do` syntax +## [Local buffers using `@init`](@id local-buffer) + +`@init` can be used without the reduction syntaxes. It is useful when some +basecase-local buffers are required (for avoiding data races): + +```jldoctest +julia> using FLoops + +julia> ys = zeros(5); + +julia> @floop begin + @init buffer = zeros(100) + for i in 1:5 + buffer .= sin.(i .* range(0, pi; length = length(buffer))) + ys[i] = sum(buffer) + end + end +``` + +!!! note + + `@init` can also be used inside of the `for` loop body with the `@floop for` + syntax as in + + ```julia + @floop for i in 1:5 + @init buffer = zeros(100) + buffer .= sin.(i .* range(0, pi; length = length(buffer))) + ys[i] = sum(buffer) + end + ``` + + However, `@floop begin ... end` syntax is recommended. + +## [Executing code at the end of basecase](@id simple-completebasecase) + +On GPU, the reduction result must be an immutable value (and not contain any +GC-manged objects). This is often not a problem since Julia ecosystem has a +rich set of tooling for programming with immutable values. For example, we can +use [`StaticArrays.SVector`](https://github.com/JuliaArrays/StaticArrays.jl) for +a histogram with a small number of bins. However, indexing update on `SVector` +is very inefficient compared to `StaticArrays.MVector`. Thus, it is better to +execute the basecase reduction using `MVector` while the cross-basecase +reduction uses `SVector`. The transformation from `MVector` to `SVector` can be +done by inserting the code after the `for` loop and before the `@combine` +expression. + +```jldoctest +julia> using FLoops + +julia> using StaticArrays + +julia> pidigits = string(BigFloat(π; precision = 2^20))[3:end]; + +julia> @floop begin + @init buf = zero(MVector{10,Int32}) + for char in pidigits + n = char - '0' + buf[n+1] += 1 + end + hist = SVector(buf) + @combine hist .+= _ + end + hist +10-element SVector{10, Int32} with indices SOneTo(10): + 31559 + 31597 + 31392 + 31712 + 31407 + 31835 + 31530 + 31807 + 31469 + 31345 +``` + +!!! note + + To run this on GPU, specific executor library like FoldsCUDA.jl has to be + used. Furthermore, `pidigits` has to be transformed into a GPU-compatible + format (e.g., `CuVector{Int8}`). + +## Advanced: Complex reduction with `@reduce() do` syntax For more complex reduction, use `@reduce() do` syntax: @@ -120,8 +460,6 @@ julia> @floop for (i, v) in pairs([0, 1, 3, 2]), (j, w) in pairs([3, 1, 5]) (5, 1, 3) ``` -### How to read a loop with `@reduce() do` syntax - When reading code with `@reduce() do`, a quick way to understand it is to mentally comment out the line with `@reduce() do` and the corresponding `end`. 
To get a full picture, move the initialization @@ -149,29 +487,55 @@ julia> let ``` This exact transformation is used for defining the sequential -basecase. Consecutive basecases are combined using the code in the -`do` block body. +basecase. + +Consecutive basecases are combined using the code in the `do` block body. That +is to say, the accumulation result `acc = (dmax, imax, jmax)` from a basecase +and the accumulation result `acc_right = (dmax, imax, jmax)` from then next +basecase are combined using the following function + +```julia +function combine(acc, acc_right) + (dmax, imax, jmax) = acc # left variables are bound to left sub-solutions + (d, i, j) = acc_right # right variables are bound to right sub-solutions + if isless(dmax, d) + dmax = d + imax = i + jmax = j + end + acc = (dmax, imax, jmax) + return acc +end +``` -## Control flow syntaxes +Note that variables left to `;` and the variables right to `;` in the original +`@reduce() do` syntax are grouped into the left argument `acc` and the right +argument `acc_right`, respectively. This is why the `@reduce() do` syntax uses +the nonstandard delimiter `;` for separating the arguments. That is to say, +`@reduce() do` syntax "transposes" (or "unzips") the arguments to clarify the +correspondence of the left and the right arguments. In general, the expression + +```julia +@reduce() do (acc₁; x₁), (acc₂; x₂), ..., (accₙ; xₙ) + $expression_updates_accs +end +``` -Control flow syntaxes such as `continue`, `break`, `return`, and -`@goto` work with parallel loops: +generates the combine function -```jldoctest -julia> using FLoops - -julia> @floop for x in 1:10 - y = 2x - @reduce() do (s; y) - s = y - end - x == 3 && break - end - s -6 +```julia +function combine((acc₁, acc₂, ..., accₙ), (x₁, x₂, ..., xₙ)) + $expression_updates_accs + return (acc₁, acc₂, ..., accₙ) +end ``` -`@reduce` can be used multiple times in a loop body +(Aside: This also clarifies why `@reduce() do` doesn't use the standard argument +ordering `@reduce() do (acc₁, acc₂, ..., accₙ), (x₁, x₂, ..., xₙ)`. From this +expression, it is very hard to tell `accᵢ` corresponds to `xᵢ`.) + +Like other `@reduce` expressions, `@reduce() do` syntax can be used multiple +times in a loop body: ```jldoctest julia> using FLoops @@ -195,6 +559,84 @@ julia> @floop for (i, v) in pairs([0, 1, 3, 2]) ((6, 3), (0, 1)) ``` +Since the variables left to `;` (i.e., `ymax`, `imax`, `ymin`, and `imin` in the +above example) are the "output" variables, they must be unique (otherwise, the +computation result is not available outside the loop). However, the variables +right to `;` (i.e., `y` and `i` in the above example) do not have to be unique +because multiple reductions can be computed using the same intermediate +computation done in the loop body. + +Similar to `@reduce() do` syntax, there is `@combine() do` syntax. This is +useful when it is more straightforward to use different code for the basecase +and combine steps. 
+ +```jldoctest +julia> using FLoops + +julia> function maybe_zero_extend_right!(hist, n) + l = length(hist) + if l < n + resize!(hist, n) + fill!(view(hist, l+1:n), 0) + end + end; + +julia> function count_positive_ints(ints, ex = ThreadedEx()) + @floop ex begin + @init hist = Int[] + for n in ints + n > 0 || continue # filter out non-positive integers + maybe_zero_extend_right!(hist, n) + @inbounds hist[n] += 1 + end + @combine() do (hist; other) + n = length(other) + maybe_zero_extend_right!(hist, n) + @views hist[1:n] .+= other + end + end + return hist + end; + +julia> count_positive_ints([7, 5, 3, 3, 8, 6, 0, 6, 5, 2, 6, 6, 5, 0, 8, 3, 4, 2, 5, 2]) +8-element Vector{Int64}: + 0 + 3 + 3 + 1 + 4 + 4 + 1 + 2 +``` + +## Control flow syntaxes + +Control flow syntaxes such as `continue`, `break`, `return`, and `@goto` work +with parallel loops, provided that they are used outside the `@reduce` syntax: + +```jldoctest +julia> using FLoops + +julia> function firstmatch(p, xs; ex = ThreadedEx()) + @floop ex for ix in pairs(xs) + _i, x = ix + if p(x) + @reduce() do (found = nothing; ix) + found = ix + end + break + end + end + return found # the *first* pair `i => x` s.t. `p(x)` + end; + +julia> firstmatch(==(42), 1:10) # finds nothing + +julia> firstmatch(isodd, [0, 2, 1, 1, 1]) +3 => 1 +``` + ## [Executors](@id tutorials-executor) `@floop` takes optional executor argument to specify an execution strategies @@ -230,3 +672,37 @@ JuliaFolds provides additional executors: rich set of thread-based executors. * [FoldsCUDA.jl](https://github.com/JuliaFolds/FoldsCUDA.jl) provides `CUDAEx` for executing the parallel loop on GPU. + +## [Advanced: Relation to `mapreduce`](@id floop-and-mapreduce) + +If you know are familar with functional style data parallel API and already know +`mapreduce(f, op, xs; init)` works, it is worth noting that `@floop` is, *as a +very rough approximation*, a way to invoke `acc = mapreduce(f, op, xs; init)` +with a custom syntax + +```julia +@floop for x in xs + y = f(x) + @reduce acc = op(init, y) +end +``` + +or + +```julia +@floop begin + @init acc = init + for x in xs + y = f(x) + acc = op(acc, y) + end + @combine acc = op(acc, _) +end +``` + +However, as explained above, `@floop` supports various constructs that are not +directly supported by `mapreduce`. To fully cover the semantics of `@floop` in +a functional manner, the extended reduction ("fold") protocol of +[Transducers.jl](https://github.com/JuliaFolds/Transducers.jl) is required. In +fact, FLoops.jl is simply a syntax sugar for invoking the reductions defined in +Transducers.jl. diff --git a/src/FLoops.jl b/src/FLoops.jl index 4358b886c..048326a39 100644 --- a/src/FLoops.jl +++ b/src/FLoops.jl @@ -20,7 +20,16 @@ module FLoops doc end FLoops -export @floop, @init, @reduce, DistributedEx, SequentialEx, ThreadedEx +#! format: off +export @floop, + @init, + @combine, + @reduce, + @completebasecase, + DistributedEx, + SequentialEx, + ThreadedEx +#! format: on using BangBang.Extras: broadcast_inplace!! using BangBang: materialize!!, push!! @@ -62,6 +71,7 @@ using Transducers: transduce, unreduced, whencombine, + whencompletebasecase, wheninit if isdefined(JuliaVariables, :solve!) 
@@ -82,6 +92,7 @@ end include("utils.jl") include("macro.jl") include("reduce.jl") +include("combine.jl") include("scratchspace.jl") include("checkboxes.jl") diff --git a/src/combine.jl b/src/combine.jl new file mode 100644 index 000000000..1170cef02 --- /dev/null +++ b/src/combine.jl @@ -0,0 +1,406 @@ +""" + @combine acc ⊗= _ + @combine acc = acc ⊗ _ + @combine acc = op(acc, _) + @combine acc .⊗= _ + @combine acc .= acc .⊗ _ + @combine acc .= op.(acc, _) + @combine() do (acc₁; acc₁′), ..., (accₙ; accₙ′) + ... + end + +Declare how accumulators from two basecases are combined. Unlike `@reduce`, the +reduction for the basecase is not defined by this macro. +""" +macro combine(ex) + :(throw($(CombineOpSpec(Any[ex])))) +end + +macro combine(ex1, ex2, exprs...) + error(""" + Unlike `@reduce`, `@combine` only supports single expression. + Use: + @combine a += _ + @combine b += _ + Instead of: + @combine(a += _, b += _) + """) +end + +struct CombineOpSpec <: OpSpec + args::Vector{Any} + visible::Vector{Symbol} +end + +CombineOpSpec(args::Vector{Any}) = CombineOpSpec(args, Symbol[]) +macroname(::CombineOpSpec) = Symbol("@combine") + +# Without a macro like `@completebasecase`, it'd be confusing to have an +# expression such as +# +# @floop begin +# ... +# for x in xs +# ... # executed in parallel loop body +# end +# for y in ys # executed in completebasecase hook +# ... +# end +# ... +# end +# +# i.e., two similar loops have drastically different semantics. The difference +# can be clarified by using the syntax: +# +# @floop begin +# ... +# for x in xs +# ... # executed in parallel loop body +# end +# @completebasecase begin +# for y in ys # executed in completebasecase hook +# ... +# end +# end +# ... +# end +""" + @completebasecase ex + +Evaluate expression `ex` at the end of each basecase. The expression `ex` can +only refer to the variables declared by `@init`. + +`@completebasecase` can be omitted if `ex` does not contain a `for` loop. + +# Examples +```jldoctest +julia> using FLoops + +julia> pidigits = string(BigFloat(π; precision = 2^20))[3:end]; + +julia> @floop begin + @init hist = zeros(Int, 10) + for c in pidigits + i = c - '0' + 1 + hist[i] += 1 + end + @completebasecase begin + j = 0 + y = 0 + for (i, x) in pairs(hist) # pretending we don't have `argmax` + if x > y + j = i + y = x + end + end + peaks = [j] + nchunks = [sum(hist)] + end + @combine hist .+= _ + @combine peaks = append!(peaks, _) + @combine nchunks = append!(nchunks, _) + end +``` +""" +macro completebasecase(ex) + ex = Expr(:block, __source__, ex) + :(throw($(CompleteBasecaseOp(ex)))) +end + +struct CompleteBasecaseOp + ex::Expr +end + +function extract_spec(ex) + @match ex begin + Expr(:call, throw′, spec::ReduceOpSpec) => spec + Expr(:call, throw′, spec::CombineOpSpec) => spec + Expr(:call, throw′, spec::InitSpec) => spec + Expr(:call, throw′, spec::CompleteBasecaseOp) => spec + _ => nothing + end +end + +isa_spec(::Type{T}) where {T} = x -> extract_spec(x) isa T + +function combine_parallel_loop(ctx::MacroContext, ex::Expr, simd, executor = nothing) + iterspec, body, ansvar, pre, post = destructure_loop_pre_post( + ex; + multiple_loop_note = string( + " Wrap the expressions after the first loop (parallel loop) with", + " `@completebasecase`.", + ), + ) + @assert ansvar == :_ + + parallel_loop_ex = @match iterspec begin + Expr(:block, loop_axes...) 
=> begin + rf_arg, coll = transform_multi_loop(loop_axes) + as_parallel_combine_loop(ctx, pre, post, rf_arg, coll, body, simd, executor) + end + Expr(:(=), rf_arg, coll) => begin + as_parallel_combine_loop(ctx, pre, post, rf_arg, coll, body, simd, executor) + end + end + return parallel_loop_ex +end + +function as_parallel_combine_loop( + ctx::MacroContext, + pre::Vector, + post::Vector, + rf_arg, + coll, + body0::Expr, + simd, + executor, +) + @assert simd in (false, true, :ivdep) + foreach(disalow_raw_for_loop_without_completebasecase, post) + + init_exprs = [] + all_rf_accs = [] + + for ex in pre + ex isa LineNumberNode && continue + spec = extract_spec(ex) + spec isa InitSpec || error("non-`@init` expression before `for` loop: ", ex) + + accs = spec.lhs + push!(all_rf_accs, accs) + + # The expression from `@init $initializer`; sets `accs`: + push!(init_exprs, spec.expr) + end + # Accumulator for the basecase reduction; i.e., the first argument to the + # `next` reducing step function: + base_accs = mapcat(identity, all_rf_accs) + + firstcombine = something(findfirst(isa_spec(CombineOpSpec), post), lastindex(post) + 1) + + completebasecase_exprs = post[firstindex(post):firstcombine-1] + if any(isa_spec(CompleteBasecaseOp), completebasecase_exprs) + # If `CompleteBasecaseOp` is used, this must be the only expression: + let exprs = [x for x in completebasecase_exprs if !(x isa LineNumberNode)], + spec = extract_spec(exprs[1]) + + if spec isa CompleteBasecaseOp && length(exprs) == 1 + completebasecase_exprs = Any[spec.ex] + elseif all(isa_spec(CompleteBasecaseOp), exprs) + error("Only one `@completebasecase` can be used. got:\n", join(exprs, "\n")) + else + error( + "`@completebasecase` cannot be mixed with other expressions.", + " Put everything in `@completebasecase begin ... end`. got:\n", + join(exprs, "\n"), + ) + end + end + end + + left_accs = [] + right_accs = [] + combine_bodies = [] + for i in firstcombine:lastindex(post) + ex = post[i] + ex isa LineNumberNode && continue + spec = extract_spec(ex) + if !(spec isa CombineOpSpec) + error( + "non-`@combine` expressions must be placed between `for` loop and the", + " first `@combine` expression: ", + spec, + ) + end + left, right, combine_body = process_combine_op_spec(spec) + append!(left_accs, left) + append!(right_accs, right) + push!(combine_bodies, combine_body) + end + + # TODO: handle `@reduce` in the loop body + @gensym result + # See also: `asfoldl`: + body, info = transform_loop_body(body0, base_accs) + pack_state = info.pack_state + unpack_state = :(($(left_accs...),) = $result) + gotos = gotos_for(info.external_labels, unpack_state, result) + base_accs_declarations = [:(local $v) for v in base_accs] + left_accs_declarations = [:(local $v) for v in left_accs] + right_accs_declarations = [:(local $v) for v in right_accs] + + @gensym( + oninit_function, + reducing_function, + completebasecase_function, + combine_function, + context_function, + ) + return quote + $Base.@inline function $oninit_function() + $(base_accs_declarations...) + return tuple($(init_exprs...)) + end + $Base.@inline function $reducing_function(($(base_accs...),), $rf_arg) + $(base_accs_declarations...) + $body + return ($(base_accs...),) + end + function $completebasecase_function(($(base_accs...),)) + $(base_accs_declarations...) + $(left_accs_declarations...) + $(completebasecase_exprs...) 
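+            # pass only the `@combine` ("left") accumulators on to the cross-basecase combine step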
+ return ($(left_accs...),) + end + $combine_function(_, b::$(Union{Goto,Return})) = b + function $combine_function(($(left_accs...),), ($(right_accs...),)) + $(left_accs_declarations...) + $(right_accs_declarations...) + $(combine_bodies...) + return ($(left_accs...),) + end + $context_function() = (; ctx = $ctx, id = $(QuoteNode(gensym(:floop_id)))) + $_verify_no_boxes($reducing_function, $context_function) + $result = $_fold( + $wheninit( + $oninit_function, + $whencompletebasecase( + $completebasecase_function, + $whencombine($combine_function, $reducing_function), + ), + ), + $coll, + $executor, + $(Val(simd)), + ) + $result isa $Return && return $result.value + $(gotos...) + $unpack_state + nothing + end +end + +function process_combine_op_spec( + spec::CombineOpSpec, +)::NamedTuple{(:left, :right, :combine_body)} + @assert length(spec.args) == 1 + ex, = spec.args::Vector{Any} + + if is_function(ex) + # handle: @combine() do ... + rf_ex = ex + # rf_ex = :(((left1; right1), ..., (leftN; rightN)) -> rf_body) + left, inits, right = analyze_rf_args(rf_ex.args[1]) + if inits !== nothing + error("`@combine() do` syntax does not support initalization; got:\n", spec) + end + combine_body = rf_ex.args[2] + return (; left = left, right = right, combine_body = combine_body) + end + + if is_dot_update(ex) + # handle: @combine left .⊗= _ + op = Symbol(String(ex.head)[2:end-1]) + lhs = ex.args[1] + if ex.args[2] !== :_ + error( + "expected expression of form `@combine lhs .⊗= _`; the rhs is not `_`: ", + ex, + ) + end + true + elseif isexpr(ex, :(.=), 2) + if !is_dotcall(ex.args[2], 2) + error( + "`@combine lhs .= rhs` syntax requires a binary dot call", + " (e.g., `a .+ b` or `f.(a, b)`) on the rhs; got:\n", + ex, + ) + end + # handle: @combine left .= op.(_, _) + lhs, rhs = ex.args + if isexpr(rhs, :call, 3) + dotop, l, r = rhs.args + str = String(dotop) + @assert startswith(str, ".") + op = Symbol(str[2:end]) + else + @assert rhs.head == :. && + length(rhs.args) == 2 && + isexpr(rhs.args[2], :tuple, 2) + op = rhs.args[1] + l, r = rhs.args[2].args + end + if l === r === :_ # allowing :_ on both hand side; TODO: maybe not? + elseif !((l, r) == (lhs, :_) || (r, l) == (lhs, :_)) + error( + "`@combine lhs .= rhs` syntax expects that the arguments", + " of the rhs are lhs and `_`; got: ", + ex, + ) + end + true + else + false + end && begin + left = Any[lhs] + rightarg = lhs isa Symbol ? gensym(Symbol(lhs, :_right)) : gensym(:right) + right = Any[rightarg] + broadcast_inplace!! = GlobalRef(@__MODULE__, :broadcast_inplace!!) + combine_body = :($lhs = $broadcast_inplace!!($op, $lhs, $rightarg)) + # ^- mutate-or-widen version of `$lhs .= ($op).($lhs, _)` + # TODO: use accurate line number from `@combine` + return (; left = left, right = right, combine_body = combine_body) + end + + if is_rebinding_update(ex) + # handle: @combine left ⊗= _ + op = Symbol(String(ex.head)[1:end-1]) + lhs, rhs = ex.args + if rhs !== :_ + error( + "expected expression of form `@combine lhs ⊗= _`; the rhs is not `_`: ", + ex, + ) + end + elseif isexpr(ex, :(=), 2) && isexpr(ex.args[2], :call, 3) + # handle: @combine left = op(_, _) + lhs, rhs = ex.args + op, l, r = rhs.args + if l === r === :_ # allowing :_ on both hand side; TODO: maybe not? + elseif !((l, r) == (lhs, :_) || (r, l) == (lhs, :_)) + error( + "`@combine lhs = rhs` syntax expects that the arguments", + " of the rhs are lhs and `_`; got: ", + ex, + ) + end + else + error("unsupported: ", spec) + end + left = Any[lhs] + rightarg = lhs isa Symbol ? 
gensym(Symbol(lhs, :_right)) : gensym(:right) + right = Any[rightarg] + combine_body = :($lhs = $op($lhs, $rightarg)) + # TODO: use accurate line number from `@combine` + return (; left = left, right = right, combine_body = combine_body) +end + +function disalow_raw_for_loop_without_completebasecase(@nospecialize(ex)) + ex isa Expr || return + extract_spec(ex) === nothing || return + _disalow_raw_for_loop(ex) +end + +function _disalow_raw_for_loop(@nospecialize(ex)) + ex isa Expr || return + if isexpr(ex, :for) + error( + "`@floop begin ... end` can only contain one `for` loop.", + " Use `@completebasecase begin ... end` to wrap the code after the parallel", + " loop, including the `for` loop. Got:\n", + ex, + ) + end + foreach(_disalow_raw_for_loop, ex.args) +end diff --git a/src/macro.jl b/src/macro.jl index 62f1ffe81..649d57fc9 100644 --- a/src/macro.jl +++ b/src/macro.jl @@ -41,6 +41,9 @@ macro floop(ex) ex, simd = remove_at_simd(__module__, ex) exx = macroexpand(__module__, ex) isexpr(exx, :for) && return esc(floop_parallel(ctx, exx, simd)) + isexpr(exx, :block) && + any(x -> extract_spec(x) isa Union{CombineOpSpec,InitSpec}, exx.args) && + return esc(combine_parallel_loop(ctx, exx, simd)) esc(floop(exx, simd)) end @@ -48,7 +51,11 @@ macro floop(executor, ex) ctx = MacroContext(__source__, __module__) ex, simd = remove_at_simd(__module__, ex) exx = macroexpand(__module__, ex) - esc(floop_parallel(ctx, exx, simd, executor)) + if isexpr(ex, :for, 2) + esc(floop_parallel(ctx, exx, simd, executor)) + else + esc(combine_parallel_loop(ctx, exx, simd, executor)) + end end struct Return{T} @@ -62,7 +69,7 @@ end Goto{label}(acc::T) where {label,T} = Goto{label,T}(acc) gotoexpr(label::Symbol) = :($Goto{$(QuoteNode(label))}) -function floop(ex, simd) +function destructure_loop_pre_post(ex; multiple_loop_note = "") pre = post = Union{}[] ansvar = :_ if isexpr(ex, :for) @@ -76,7 +83,13 @@ function floop(ex, simd) pre = args[1:i-1] post = args[i+1:end] if find_first_for_loop(post) !== nothing - throw(ArgumentError("Multiple top-level `for` loop found in:\n$ex")) + msg = string( + "Multiple top-level `for` loops found.", + multiple_loop_note, + " Given expression:\n", + ex, + ) + throw(ArgumentError(msg)) end else throw(ArgumentError("Unsupported expression:\n$ex")) @@ -84,6 +97,11 @@ function floop(ex, simd) if ansvar !== :_ post = vcat(post, ansvar) end + return loops, body, ansvar, pre, post +end + +function floop(ex, simd) + loops, body, ansvar, pre, post = destructure_loop_pre_post(ex) pre = vcat(pre, something(EXTRA_STATE_VARIABLES[], Union{}[])) init_vars = mapcat(assigned_vars, pre) diff --git a/src/reduce.jl b/src/reduce.jl index 807a7efe0..9afad270d 100644 --- a/src/reduce.jl +++ b/src/reduce.jl @@ -86,12 +86,15 @@ macro reduce(args...) 
end # TODO: detect free variables in `do` blocks -struct ReduceOpSpec +abstract type OpSpec end + +struct ReduceOpSpec <: OpSpec args::Vector{Any} visible::Vector{Symbol} end ReduceOpSpec(args::Vector{Any}) = ReduceOpSpec(args, Symbol[]) +macroname(::ReduceOpSpec) = Symbol("@reduce") """ @init begin @@ -852,13 +855,20 @@ struct _FLoopInit end transduce(IdentityTransducer(), rf, DefaultInit, coll, maybe_set_simd(exc, simd)), ) -function Base.showerror(io::IO, opspecs::ReduceOpSpec) - print(io, "`@reduce(") - join(io, opspecs.args, ", ") - print(io, ")` used outside `@floop`") +function Base.print(io::IO, spec::OpSpec) + # TODO: print as `do` block + print(io, macroname(spec), "(") + join(io, spec.args, ", ") + print(io, ")") +end + +Base.show(io::IO, ::MIME"text/plain", spec::OpSpec) = print(io, spec) + +function Base.showerror(io::IO, spec::OpSpec) + print(io, "`", spec, "` used outside `@floop`") end function Base.showerror(io::IO, spec::InitSpec) ex = spec.expr - print(io, "`@init", ex, "` used outside `@floop`") + print(io, "`@init ", ex, "` used outside `@floop`") end diff --git a/test/FLoopsTests/Project.toml b/test/FLoopsTests/Project.toml index 57e8dc3a7..5d591a0a8 100644 --- a/test/FLoopsTests/Project.toml +++ b/test/FLoopsTests/Project.toml @@ -14,5 +14,6 @@ MicroCollections = "128add7d-3638-4c79-886c-908ea0c25c34" PerformanceTestTools = "dc46b164-d16f-48ec-a853-60448fc869fe" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" diff --git a/test/FLoopsTests/src/FLoopsTests.jl b/test/FLoopsTests/src/FLoopsTests.jl index f9cfeea92..4b8af7663 100644 --- a/test/FLoopsTests/src/FLoopsTests.jl +++ b/test/FLoopsTests/src/FLoopsTests.jl @@ -2,6 +2,8 @@ module FLoopsTests using Test +include("utils.jl") + for file in sort([file for file in readdir(@__DIR__) if match(r"^test_.*\.jl$", file) !== nothing]) include(file) diff --git a/test/FLoopsTests/src/test_combine.jl b/test/FLoopsTests/src/test_combine.jl new file mode 100644 index 000000000..1dfa263b0 --- /dev/null +++ b/test/FLoopsTests/src/test_combine.jl @@ -0,0 +1,172 @@ +module TestCombine + +using FLoops +using MicroCollections +using StaticArrays +using Test + +using ..Utils: @macroexpand_error + +function count_ints_two_pass(indices, ex = nothing) + l, h = extrema(indices) + n = h - l + 1 + @floop ex begin + @init hist = zeros(Int, n) + for i in indices + hist[i-l+1] += 1 + end + @combine hist .+= _ + end + return hist +end + +valueof(::Val{x}) where {x} = x + +function count_ints_two_pass2(indices, ex = nothing) + l, h = extrema(indices) + n = Val(h - l + 1) + @floop ex begin + @init hist = zero(MVector{valueof(n),Int32}) + for i in indices + hist[i-l+1] += 1 + end + @completebasecase hist = SVector(hist) + @combine hist .+= _ + end + return hist +end + +function test_count_ints_two_pass() + @testset "$(repr(ex))" for ex in [SequentialEx(), nothing, ThreadedEx(basesize = 1)] + @test count_ints_two_pass(1:3, ex) == [1, 1, 1] + @test count_ints_two_pass([1, 2, 4, 1], ex) == [2, 1, 0, 1] + @test count_ints_two_pass2(1:3, ex) == [1, 1, 1] + @test count_ints_two_pass2([1, 2, 4, 1], ex) == [2, 1, 0, 1] + end +end + +function count_ints4(ints; nbins::Val{n} = Val(10), ex = nothing) where {n} + @floop ex begin + @init b1 = zero(MVector{n,Int32}) + @init b2 = zero(MVector{n,Int32}) + @init b3 = zero(MVector{n,Int32}) + @init b4 = 
zero(MVector{n,Int32}) + for (i1, i2, i3, i4) in ints + @inbounds b1[max(1, min(i1, n))] += 1 + @inbounds b2[max(1, min(i2, n))] += 1 + @inbounds b3[max(1, min(i3, n))] += 1 + @inbounds b4[max(1, min(i4, n))] += 1 + end + h1 = SVector(b1) + h2 = SVector(b2) + h3 = SVector(b3) + h4 = SVector(b4) + + @combine h1 .+= _ + @combine h2 .= _ .+ _ + @combine h3 += _ + @combine h4 = _ + _ + end + return (h1, h2, h3, h4) +end + +function test_count_ints4() + @testset "$(repr(ex))" for ex in [SequentialEx(), nothing, ThreadedEx(basesize = 1)] + @test count_ints4(zip(1:3, 2:4, 3:5, 4:6); ex = ex) == ( + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + ) + end +end + +function count_positive_ints(ints; ex = nothing) + @floop ex begin + @init hist = Int[] + + for i in ints + n = length(hist) + if i > n + resize!(hist, i) + hist[n+1:end] .= 0 + end + @inbounds hist[max(1, i)] += 1 + end + + @combine() do (hist; hist2) + n = length(hist) + m = length(hist2) + if m > n + n, m = m, n + hist, hist2 = hist2, hist + end + hist[1:m] .+= hist2 + end + end + return hist +end + +function test_count_positive_ints() + @testset "$(repr(ex))" for ex in [SequentialEx(), nothing, ThreadedEx(basesize = 1)] + @test count_positive_ints(1:3; ex = ex) == [1, 1, 1] + @test count_positive_ints([1, 2, 4, 1]; ex = ex) == [2, 1, 0, 1] + end +end + +function test_error_one_for_loop1() + err = @macroexpand_error @floop begin + @init a = nothing + for x in xs + end + for y in ys + end + end + @test err isa Exception + msg = sprint(showerror, err) + @test occursin("Wrap the expressions after the first loop", msg) +end + +function test_error_one_for_loop2() + err = @macroexpand_error @floop begin + @init a = nothing + for x in xs + end + function f() + for y in ys + end + end + end + @test err isa Exception + msg = sprint(showerror, err) + @test occursin("can only contain one `for` loop", msg) +end + +function test_error_mixing_plain_expr_and_completebasecase() + err = @macroexpand_error @floop begin + @init a = nothing + for x in xs + end + @completebasecase for y in ys + end + f(ys) + end + @test err isa Exception + msg = sprint(showerror, err) + @test occursin("cannot be mixed with other expressions", msg) +end + +function test_error_two_completebasecase_macro_calls() + err = @macroexpand_error @floop begin + @init a = nothing + for x in xs + end + @completebasecase nothing + @completebasecase nothing + end + @test err isa Exception + msg = sprint(showerror, err) + @test occursin("Only one `@completebasecase` can be used", msg) +end + +end # module diff --git a/test/FLoopsTests/src/test_doctest.jl b/test/FLoopsTests/src/test_doctest.jl index ed904ab57..9f169f769 100644 --- a/test/FLoopsTests/src/test_doctest.jl +++ b/test/FLoopsTests/src/test_doctest.jl @@ -13,6 +13,11 @@ function test(; skip = true) @info "Skipping doctests on Julia $VERSION." @test_skip nothing return + elseif VERSION ≥ v"1.8-" + # https://github.com/JuliaArrays/StaticArrays.jl/pull/989 + @info "Skipping doctests on Julia $VERSION." 
+ @test_skip nothing + return end end PerformanceTestTools.@include_foreach("__test_doctest.jl", [[]]) diff --git a/test/FLoopsTests/src/utils.jl b/test/FLoopsTests/src/utils.jl new file mode 100644 index 000000000..b55e8e672 --- /dev/null +++ b/test/FLoopsTests/src/utils.jl @@ -0,0 +1,17 @@ +module Utils + +struct NoError end + +macro macroexpand_error(ex) + @gensym err + quote + try + $Base.@eval $Base.@macroexpand $ex + $NoError() + catch $err + $err + end + end |> esc +end + +end # module diff --git a/test/environments/main/Manifest.toml b/test/environments/main/Manifest.toml index fc5598e7a..4a641c70e 100644 --- a/test/environments/main/Manifest.toml +++ b/test/environments/main/Manifest.toml @@ -115,7 +115,7 @@ uuid = "b9860ae5-e623-471e-878b-f6a53c775ea6" version = "0.1.1" [[FLoopsTests]] -deps = ["BangBang", "Documenter", "FLoops", "FLoopsBase", "Future", "LiterateTest", "MicroCollections", "PerformanceTestTools", "Random", "Serialization", "Test", "Transducers"] +deps = ["BangBang", "Documenter", "FLoops", "FLoopsBase", "Future", "LiterateTest", "MicroCollections", "PerformanceTestTools", "Random", "Serialization", "StaticArrays", "Test", "Transducers"] path = "../../FLoopsTests" uuid = "1c45e723-db3a-42a1-a87c-e007e67bfd87" version = "0.1.0" @@ -302,6 +302,12 @@ git-tree-sha1 = "39c9f91521de844bad65049efd4f9223e7ed43f9" uuid = "171d559e-b47b-412a-8079-5efa626c420e" version = "0.1.14" +[[StaticArrays]] +deps = ["LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "2884859916598f974858ff01df7dfc6c708dd895" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.3.3" + [[Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"