Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make propagate more like a monadic bind by supporting stochastic triple creating functions #132

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions src/propagate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,16 @@ function strip_Δs(arg; use_dual = Val(true))
end

"""
propagate(f, args...; keep_deltas = Val(false))
propagate(f, args...; keep_deltas = Val(false), keep_triples = Val(false))

Propagates `args` through a function `f`, handling stochastic triples by independently running `f` on the primal
and the alternatives, rather than by inspecting the internals of `f` (which may possibly be unsupported by `StochasticAD`).
Currently handles deterministic functions `f` with any input and output that is `fmap`-able by `Functors.jl`.
Handles functions `f` with any input and output that is `fmap`-able by `Functors.jl`, including functions `f` that are closures
over stochastic triples.
If `f` has a continuously differentiable component, provide `keep_deltas = Val(true)`.
If continuous perturbations to `args` can cause discrete pertubations to be created within `f`,
then provide `keep_triples = Val(true)`.


This functionality is orthogonal to dispatch: the idea is for this function to be the "backend" for operator
overloading rules based on dispatch. For example:
Expand Down Expand Up @@ -89,6 +93,7 @@ StochasticTriple of Int64:
function propagate(f,
args...;
keep_deltas = Val(false),
keep_triples = Val(false),
provided_st_rep = nothing,
deriv = nothing)
# TODO: support kwargs to f (or just use kwfunc in macro)
Expand Down Expand Up @@ -118,15 +123,16 @@ function propagate(f,
end

primal_args = structural_map(get_value, args)
input_args = keep_deltas isa Val{false} ? primal_args : structural_map(strip_Δs, args)
#=
TODO: the below is dangerous is general.
It should be safe so long as f does not close over stochastic triples.
(If f is a closure, the parameters of f should be treated like any other parameters;
if they are stochastic triples and we are ignoring them, dangerous in general.)
=#
input_args = if keep_triples isa Val{true}
structural_map(x -> strip_Δs(x; use_dual = Val(false)), args)
elseif keep_deltas isa Val{true}
structural_map(strip_Δs, args)
else
primal_args
end
out = f(input_args...)
val = structural_map(value, out)
Δs1 = structural_map(Base.Fix2(get_Δs, backendtype(st_rep)), out)
# TODO: what does the only_vals do in the below and why?
Δs_all = structural_map(Base.Fix2(get_Δs, backendtype(st_rep)), args;
only_vals = Val{true}())
Expand All @@ -142,9 +148,12 @@ function propagate(f,
alt = f(perturbed_args...)
return structural_map((x, y) -> value(x) - y, alt, val)
end
Δs = map(map_func, Δs_coupled; out_rep = val, deriv)
Δs2 = map(map_func, Δs_coupled; out_rep = val, deriv)

# TODO: make sure all FI backends support interface needed below
new_out = structural_map(out, scalarize(Δs; out_rep = val)) do leaf_out, leaf_Δs
new_out = structural_map(
out, Δs1, scalarize(Δs2; out_rep = val)) do leaf_out, leaf_Δs1, leaf_Δs2
leaf_Δs = combine(backendtype(st_rep), (leaf_Δs1, leaf_Δs2))
StochasticAD.StochasticTriple{tag(st_rep)}(value(leaf_out), delta(leaf_out),
leaf_Δs)
end
Expand Down
Loading