Skip to content

Commit

Permalink
add adapt ext
Browse files Browse the repository at this point in the history
  • Loading branch information
longemen3000 authored Jul 16, 2024
1 parent d264b62 commit 543a73f
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 20 deletions.
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ SplittablesBase = "171d559e-b47b-412a-8079-5efa626c420e"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338"
Referenceables = "42d2dcc6-99eb-4e98-b66c-637b7d73030e"

[extensions]
TransducersAdaptExt = "Adapt"
TransducersBlockArraysExt = "BlockArrays"
TransducersDataFramesExt = "DataFrames"
TransducersLazyArraysExt = "LazyArrays"
Expand Down Expand Up @@ -60,6 +62,7 @@ Tables = "0.2, 1.0"
julia = "1.6"

[extras]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down Expand Up @@ -88,4 +91,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"

[targets]
test = ["Aqua", "BlockArrays", "Compat", "DataFrames", "DataTools", "Dates", "Distributed", "Documenter", "Folds", "InteractiveUtils", "LazyArrays", "LiterateTest", "LoadAllPackages", "Maybe", "OnlineStats", "OnlineStatsBase", "PerformanceTestTools", "Pkg", "Random", "Referenceables", "SparseArrays", "StaticArrays", "Statistics", "StructArrays", "Test", "TypedTables"]
test = ["Adapt", "Aqua", "BlockArrays", "Compat", "DataFrames", "DataTools", "Dates", "Distributed", "Documenter", "Folds", "InteractiveUtils", "LazyArrays", "LiterateTest", "LoadAllPackages", "Maybe", "OnlineStats", "OnlineStatsBase", "PerformanceTestTools", "Pkg", "Random", "Referenceables", "SparseArrays", "StaticArrays", "Statistics", "StructArrays", "Test", "TypedTables"]
34 changes: 34 additions & 0 deletions ext/TransducersAdaptExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
module TransducersAdaptExt

if isdefined(Base,:get_extension)
import Transducers
import Adapt
else
import ..Transducers
import ..Adapt
end

Adapt.adapt_structure(to, rf::R) where {R <: Transducers.Reduction} =
Transducers.Reduction(Adapt.adapt(to, Transducers.xform(rf)), Adapt.adapt(to, Transducers.inner(rf)))

Adapt.adapt_structure(to, xf::Transducers.Map) = Transducers.Map(Adapt.adapt(to, xf.f))

Adapt.adapt_structure(to, xf::Transducers.MapSplat) = Transducers.MapSplat(Adapt.adapt(to, xf.f))

Adapt.adapt_structure(to, xf::Transducers.Filter) = Transducers.Filter(Adapt.adapt(to, xf.pred))

Adapt.adapt_structure(to, xf::Transducers.GetIndex{inbounds}) where {inbounds} =
Transducers.GetIndex{inbounds}(Adapt.adapt(to, xf.array))

Adapt.adapt_structure(to, xf::Transducers.SetIndex{inbounds}) where {inbounds} =
Transducers.SetIndex{inbounds}(Adapt.adapt(to, xf.array))

Adapt.adapt_structure(to, xf::Transducers.ReducePartitionBy) = Transducers.ReducePartitionBy(
Adapt.adapt(to, xf.f),
Adapt.adapt(to, xf.rf),
Adapt.adapt(to, xf.init),
)
end #module



2 changes: 2 additions & 0 deletions src/Transducers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ const OSNonZeroNObsError = ArgumentError(

if !isdefined(Base,:get_extension)
using Requires
using Adapt
include("../ext/TransducersAdaptExt.jl")
function __init__()
@require BlockArrays="8e7c35d0-a365-5155-bbbb-fb81a777f24e" include("../ext/TransducersBlockArraysExt.jl")
@require LazyArrays="5078a376-72f3-5289-bfd5-ec5146d43c02" include("../ext/TransducersLazyArraysExt.jl")
Expand Down
2 changes: 0 additions & 2 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,6 @@ Transducer(rf::Reduction) =
# `Reduction` to `AbstractReduction`.
Reduction(::IdentityTransducer, inner) = ensurerf(inner)

Adapt.adapt_structure(to, rf::R) where {R <: Reduction} =
Reduction(Adapt.adapt(to, xform(rf)), Adapt.adapt(to, inner(rf)))

"""
Transducers.R_{X}
Expand Down
12 changes: 0 additions & 12 deletions src/library.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ OutputSize(::Type{<:Map}) = SizeStable()
isexpansive(::Map) = false
@inline next(rf::R_{Map}, result, input) = next(inner(rf), result, xform(rf).f(input))

Adapt.adapt_structure(to, xf::Map) = Map(Adapt.adapt(to, xf.f))

"""
MapSplat(f)
Expand Down Expand Up @@ -83,8 +81,6 @@ isexpansive(::MapSplat) = false
@inline next(rf::R_{MapSplat}, result, input) =
next(inner(rf), result, xform(rf).f(input...))

Adapt.adapt_structure(to, xf::MapSplat) = MapSplat(Adapt.adapt(to, xf.f))

# https://clojure.github.io/clojure/clojure.core-api.html#clojure.core/replace
# https://clojuredocs.org/clojure.core/replace
"""
Expand Down Expand Up @@ -290,8 +286,6 @@ end
@inline next(rf::R_{Filter}, result, input) =
xform(rf).pred(input) ? next(inner(rf), result, input) : result

Adapt.adapt_structure(to, xf::Filter) = Filter(Adapt.adapt(to, xf.pred))

"""
NotA(T)
Expand Down Expand Up @@ -1562,9 +1556,6 @@ Base.:(==)(xf1::GetIndex{inbounds,A},
xf2::GetIndex{inbounds,A}) where {inbounds,A} =
xf1.array == xf2.array

Adapt.adapt_structure(to, xf::GetIndex{inbounds}) where {inbounds} =
GetIndex{inbounds}(Adapt.adapt(to, xf.array))

"""
SetIndex(array)
SetIndex{inbounds}(array)
Expand Down Expand Up @@ -1610,9 +1601,6 @@ Base.:(==)(xf1::SetIndex{inbounds,A},
xf2::SetIndex{inbounds,A}) where {inbounds,A} =
xf1.array == xf2.array

Adapt.adapt_structure(to, xf::SetIndex{inbounds}) where {inbounds} =
SetIndex{inbounds}(Adapt.adapt(to, xf.array))

"""
Inject(iterator)
Expand Down
5 changes: 0 additions & 5 deletions src/partitionby.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,6 @@ struct ReducePartitionBy{F,RF,Init} <: Transducer
end
ReducePartitionBy(f, rf) = ReducePartitionBy(f, rf, Init)

Adapt.adapt_structure(to, xf::ReducePartitionBy) = ReducePartitionBy(
Adapt.adapt(to, xf.f),
Adapt.adapt(to, xf.rf),
Adapt.adapt(to, xf.init),
)

struct PartitionChunk{K,V}
kr::K
Expand Down

0 comments on commit 543a73f

Please sign in to comment.