Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make propagate more like a monadic bind by supporting stochastic triple creating functions #132

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

gaurav-arya
Copy link
Owner

@gaurav-arya gaurav-arya commented Aug 3, 2024

x-ref #128. @GuusAvis example:

using StochasticAD
using Distributions

function f(value_1, value_2, rand_var)
    if value_1 < value_2
        return (value_1 + rand(rand_var), value_2)
    else
        return (value_1, value_2 + rand(rand_var))
    end
end

propagate_f(value_1, value_2, rand_var) = StochasticAD.propagate((v1, v2) -> f(v1, v2, rand_var), value_1, value_2)

f(value_1::StochasticTriple, value_2, rand_var) = propagate_f(value_1, value_2, rand_var)
f(value_1, value_2::StochasticTriple, rand_var) = propagate_f(value_1, value_2, rand_var)
f(value_1::StochasticTriple, value_2::StochasticTriple, rand_var) = propagate_f(value_1, value_2, rand_var)

function g(p)
    rand_var = Bernoulli(p)
    value_1 = 0
    value_2 = 2
    for i in 1:10
        value_1, value_2 = f(value_1, value_2, rand_var)
    end
    return value_1, value_2
end

@show g(0.5)
@show mean((sum(g(0.6)) - sum(g(0.5))) / 0.1 for i in 1:1000) # 9.59
@show mean(derivative_estimate(p -> sum(g(p)), 0.5) for i in 1:100) # 8.84

@GuusAvis let me know if you have any issues, and if things work out adding the above as a test to triples.jl would be most welcome:)

@GuusAvis
Copy link

GuusAvis commented Aug 5, 2024

Many thanks for this PR @gaurav-arya, I think the code is working great now (see also my comment in the issue where I shared a test).

The code you pushed broke an existing test, I managed to solve one problem but there appear to me more.. In one of the tests you are passing keep_deltas = Val{test_deltas} instead of keep_deltas = Val(test_deltas). I think this problem is very related to #126. I opened a PR into this branch to fix the issue, see #133. There are other issues also that I'm happy to help with but am unsure about (I'm not sure I entirely understand the dynamic of keep_deltas and keep_triples, was keep_triples required because of the other changes in this PR?).

Moreover, I have added another PR (#134) into this branch to add more tests. I have added a statistical test based on your code here, and another simpler test of adding two numbers together (more or less testing the original example I gave in #128 ). Let me know if you think this makes sense and feel free to suggest or make any chances.

@gaurav-arya
Copy link
Owner Author

gaurav-arya commented Aug 7, 2024

Hey! Thank you so much for testing it:) I am going to have to wait until after my thesis deadline Friday midnight, but I'll circle circle back then! If you have an MWE of the 600x slowdown (which is definitely too slow) that would be great.

Edit: making a @profview of the MWE can also be very informative

@GuusAvis
Copy link

GuusAvis commented Aug 10, 2024

No worries!

MWE of the 600x slowdown is the step game as also posted in #128. Let me reproduce it here completely for clarify and consistency.

The code that I use for the step game:

using StochasticAD, Distributions

"""
    update_values(value_1, rand_var_1, value_2, rand_var_2, tolerance)

Update the values of by sampling the random variables.

Note: this function may return stochastic triples even if `value_1` and `value_2` are normal
numbers, as the samples taken from `rand_var_1` and `rand_var_2` may be stochastic triples.
This is currently not correctly handled by `StochasticAD.propagate`.
"""
function update_values(value_1, rand_var_1, value_2, rand_var_2, tolerance)
    within_tolerance(value_1, value_2, tolerance) && return value_1, value_2
    if value_1 < value_2
        value_1 += rand(rand_var_1)
    else
        value_2 += rand(rand_var_2)
    end
    value_1, value_2
end
function update_values(value_1::StochasticAD.StochasticTriple,
        rand_var_1, value_2, rand_var_2, tolerance)
    f = v1 -> update_values(v1, rand_var_1, value_2, rand_var_2, tolerance)
    StochasticAD.propagate(f, value_1)
end
function update_values(value_1, rand_var_1, value_2::StochasticAD.StochasticTriple,
        rand_var_2, tolerance)
    f = v2 -> update_values(value_1, rand_var_1, v2, rand_var_2, tolerance)
    StochasticAD.propagate(f, value_2)
end
function update_values(value_1::StochasticAD.StochasticTriple,
        rand_var_1, value_2::StochasticAD.StochasticTriple, rand_var_2, tolerance)
    f = (v1, v2) -> update_values(v1, rand_var_1, v2, rand_var_2, tolerance)
    StochasticAD.propagate(f, value_1, value_2)
end

"""
    within_tolerance(value_1, value_2, tolerance)

Check if the values satisfy the tolerance level.
"""
within_tolerance(value_1, value_2, tolerance) = abs(value_1 - value_2) < tolerance
function within_tolerance(value_1::StochasticAD.StochasticTriple, value_2, tolerance)
    StochasticAD.propagate(x -> within_tolerance(x, value_2, tolerance), value_1)
end
function within_tolerance(value_1, value_2::StochasticAD.StochasticTriple, tolerance)
    StochasticAD.propagate(x -> within_tolerance(value_1, x, tolerance), value_2)
end
function within_tolerance(value_1::StochasticAD.StochasticTriple,
        value_2::StochasticAD.StochasticTriple, tolerance)
    StochasticAD.propagate((x, y) -> within_tolerance(x, y, tolerance), value_1, value_2)
end

"""
    istrue(x::Bool) =  x

Check if a number is unambiguously `true` in a way that works for stochastic triples.

If any of the branches of the stochastic triple are `false`, then the result is `false`.
"""
istrue(x::Bool) =  x
function istrue(x::StochasticAD.StochasticTriple)
    primary = isone(StochasticAD.value(x))
    perts = StochasticAD.alltrue(iszero, x.Δs)
    primary && perts
end

function step_game(rand_var_1, rand_var_2, tolerance)
    value_1 = rand(rand_var_1)
    value_2 = rand(rand_var_2)
    finished = false
    while !istrue(finished)
        value_1, value_2 = update_values(value_1, rand_var_1, value_2, rand_var_2,
            tolerance)
        finished = within_tolerance(value_1, value_2, tolerance)
    end
    value_1, value_2
end

It looks like a bit much code for an MWE but a lot of it is just calling propagate.

Now, let's do some benchmarking! First without derivatives:

rand_var_1 = Geometric(0.1)
rand_var_2 = Geometric(0.1)
tolerance = 5.

@elapsed step_game(rand_var_1, rand_var_2, tolerance)  # ~ 5E-6 seconds
@allocated step_game(rand_var_1, rand_var_2, tolerance)  # always 32
@elapsed [step_game(rand_var_1, rand_var_2, tolerance) for _ in 1:1E7]  # ~ 1.25 seconds

We see that we can run 1E7 samples in about a second, not too bad. Allocations also looking good I think. Now if we make one of the random variables produce stochastic triples:

rand_var_1_triple = Geometric(stochastic_triple(0.01))
rand_var_2 = Geometric(0.01)
tolerance = 5.

@elapsed step_game(rand_var_1_triple, rand_var_2, tolerance)  # ~ 2E-3 seconds
@allocated step_game(rand_var_1_triple, rand_var_2, tolerance)  # random, ~ 4E5
@elapsed [step_game(rand_var_1_triple, rand_var_2, tolerance) for _ in 1:1E3]  # ~ 1.15 seconds
@profview [step_game(rand_var_1_triple, rand_var_2, tolerance) for _ in 1:1E3]  # lots of propagate

Now we can only run 1E3 samples in a second, suggesting an even bigger slowdown then I had initially reported. Also the number of allocations is all over the place. There are a lot of allocations (tens of thousands for a single run), and moreover it appears to be a random number while it was constant for the primal evaluation above.

I had indeed also used @profview, it suggests (unsurprisingly) that most time is spent inside of propagates. Underneath that are mostly mapping functions (structural_map etc) and it was not very easy for me to interpret it directly.
For your reference, I will share the flamegraph I produced here (had to zip it because github doesn't like html, also attached screenshot for convenience).
step_game_profile.zip
Screenshot from 2024-08-09 20-20-30

@GuusAvis
Copy link

@gaurav-arya Did you have a chance to look at the slowdown yet?

@gaurav-arya
Copy link
Owner Author

Hi @GuusAvis -- not yet. Thank you for the ping -- I'll take a look this weekend!

@gaurav-arya
Copy link
Owner Author

Partial debug (made JET.jl happy on the code by tweaking a few things, and set the geometric parameter to 0.1 in both cases for consistency. but still a large performance gap, and profview is still pointing to runtime dispatches that I somehow cannot see on JET.jl or Cthulhu.jl...)

using StochasticAD, Distributions

"""
    update_values(value_1, rand_var_1, value_2, rand_var_2, tolerance)

Update the values of by sampling the random variables.

Note: this function may return stochastic triples even if `value_1` and `value_2` are normal
numbers, as the samples taken from `rand_var_1` and `rand_var_2` may be stochastic triples.
This is currently not correctly handled by `StochasticAD.propagate`.
"""
function update_values(value_1, rand_var_1, value_2, rand_var_2, tolerance)
    value_1_ret = value_1 + zero(rand(rand_var_1))
    value_2_ret = value_2 + zero(rand(rand_var_2))
    within_tolerance(value_1, value_2, tolerance) && return value_1_ret, value_2_ret
    if value_1 < value_2
        value_1_ret += rand(rand_var_1) 
    else
        value_2_ret += rand(rand_var_2) 
    end
    return value_1_ret, value_2_ret
end
function update_values(value_1::StochasticAD.StochasticTriple,
        rand_var_1, value_2, rand_var_2, tolerance)
    f = v1 -> update_values(v1, rand_var_1, value_2, rand_var_2, tolerance)
    StochasticAD.propagate(f, value_1)
end
function update_values(value_1, rand_var_1, value_2::StochasticAD.StochasticTriple,
        rand_var_2, tolerance)
    f = v2 -> update_values(value_1, rand_var_1, v2, rand_var_2, tolerance)
    StochasticAD.propagate(f, value_2)
end
function update_values(value_1::StochasticAD.StochasticTriple,
        rand_var_1, value_2::StochasticAD.StochasticTriple, rand_var_2, tolerance)
    f = (v1, v2) -> update_values(v1, rand_var_1, v2, rand_var_2, tolerance)
    StochasticAD.propagate(f, value_1, value_2)
end

"""
    within_tolerance(value_1, value_2, tolerance)

Check if the values satisfy the tolerance level.
"""
within_tolerance(value_1, value_2, tolerance) = abs(value_1 - value_2) < tolerance
function within_tolerance(value_1::StochasticAD.StochasticTriple, value_2, tolerance)
    StochasticAD.propagate(x -> within_tolerance(x, value_2, tolerance), value_1)
end
function within_tolerance(value_1, value_2::StochasticAD.StochasticTriple, tolerance)
    StochasticAD.propagate(x -> within_tolerance(value_1, x, tolerance), value_2)
end
function within_tolerance(value_1::StochasticAD.StochasticTriple,
        value_2::StochasticAD.StochasticTriple, tolerance)
    StochasticAD.propagate((x, y) -> within_tolerance(x, y, tolerance), value_1, value_2)
end

"""
    istrue(x::Bool) =  x

Check if a number is unambiguously `true` in a way that works for stochastic triples.

If any of the branches of the stochastic triple are `false`, then the result is `false`.
"""
istrue(x::Bool) =  x
function istrue(x::StochasticAD.StochasticTriple)
    primary = isone(StochasticAD.value(x))
    perts = StochasticAD.alltrue(iszero, x.Δs)
    primary && perts
end

function step_game(rand_var_1, rand_var_2, tolerance)
    value_1 = rand(rand_var_1)
    value_2 = rand(rand_var_2)
    finished = if (value_1 isa StochasticTriple) || (value_2 isa StochasticTriple)
        StochasticAD.propagate(_ -> false, value_1 + value_2)
    else
        false
    end
    while !istrue(finished)
        @time value_1, value_2 = update_values(value_1, rand_var_1, value_2, rand_var_2,
            tolerance)
        finished = within_tolerance(value_1, value_2, tolerance)
    end
    value_1, value_2
end

##

rand_var_1 = Geometric(0.1)
rand_var_2 = Geometric(0.1)
tolerance = 5.

step_game(rand_var_1, rand_var_2, tolerance) 
@elapsed step_game(rand_var_1, rand_var_2, tolerance)  # ~ 5E-6 seconds
@allocated step_game(rand_var_1, rand_var_2, tolerance)  # always 32
@elapsed [step_game(rand_var_1, rand_var_2, tolerance) for _ in 1:1E6]  # ~ 0.4 seconds

##

rand_var_1_triple = Geometric(stochastic_triple(0.1; backend = PrunedFIsBackend()))
rand_var_2 = Geometric(0.1)
tolerance = 5.

step_game(rand_var_1_triple, rand_var_2, tolerance) 
@elapsed step_game(rand_var_1_triple, rand_var_2, tolerance)  # ~ 2E-4 seconds
@allocated step_game(rand_var_1_triple, rand_var_2, tolerance)  # random, ~ 4E5
@elapsed [step_game(rand_var_1_triple, rand_var_2, tolerance) for _ in 1:1E4]  # ~ 1.3 seconds
@profview [step_game(rand_var_1_triple, rand_var_2, tolerance) for _ in 1:1E4]


##

let value_1 = rand(rand_var_1), value_2 = rand(rand_var_2)
    @time [update_values(value_1, rand_var_1, value_2, rand_var_2, tolerance) for i in 1:10000000]
    nothing
end

let value_1 = rand(rand_var_1_triple), value_2 = rand(rand_var_2)
    @time [update_values(value_1, rand_var_1_triple, value_2, rand_var_2, tolerance) for i in 1:100000]
    nothing
end

let value_1 = rand(rand_var_1_triple), value_2 = rand(rand_var_2)
    @profview [update_values(value_1, rand_var_1_triple, value_2, rand_var_2, tolerance) for i in 1:100000]
    nothing
end

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants