-
-
Notifications
You must be signed in to change notification settings - Fork 189
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/issue 2966 add 7 parameter ddm cdf and ccdf #3042
base: develop
Are you sure you want to change the base?
Changes from 58 commits
26ed582
a080ab8
eee7267
8483c39
5ecf546
8f085b8
f274c18
f68fe16
10feb1f
9758d68
363eb42
2e1febf
287472c
1138571
4fc8fdb
721f181
bcc361e
d14a5d8
cda2f6e
22146d7
22bbb18
466b2d2
b5bcb06
71e7825
401c7a7
f597049
b65af9c
1b14f78
6cd53f2
d6c54c8
142e017
72617b1
601e4e1
20d9689
59ea8bc
80d4962
98756a6
aa93836
3e607c6
dba1b36
ae909ae
632a0b8
30f6c85
8a44586
13e4fbf
1bb3bd5
d756c4a
f105313
6862514
a4549f2
70bc628
44ebbcf
037cca3
c14459b
49a14eb
4293e45
979b17c
46c7a1e
e5f5048
a4d1895
572dcd4
00a930e
b9b1839
443e25b
afd9330
20cba0c
6c8688f
97ab7a9
f53cda6
cbd2c80
be67633
cc70169
b09bd1b
a280b12
bdd000e
0d05c83
44a6c2b
2489d88
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,375 @@ | ||||||||||||||||||||||||||||||
#ifndef STAN_MATH_PRIM_PROB_WIENER4_LCCDF_HPP | ||||||||||||||||||||||||||||||
#define STAN_MATH_PRIM_PROB_WIENER4_LCCDF_HPP | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
#include <stan/math/prim/prob/wiener4_lcdf.hpp> | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
namespace stan { | ||||||||||||||||||||||||||||||
namespace math { | ||||||||||||||||||||||||||||||
namespace internal { | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
/** | ||||||||||||||||||||||||||||||
* Log of probability of reaching the upper bound in diffusion process | ||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||
* @tparam T_a type of boundary | ||||||||||||||||||||||||||||||
* @tparam T_w type of relative starting point | ||||||||||||||||||||||||||||||
* @tparam T_v type of drift rate | ||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||
* @param a The boundary separation | ||||||||||||||||||||||||||||||
* @param w_value The relative starting point | ||||||||||||||||||||||||||||||
* @param v_value The drift rate | ||||||||||||||||||||||||||||||
* @return log probability to reach the upper bound | ||||||||||||||||||||||||||||||
*/ | ||||||||||||||||||||||||||||||
template <typename T_a, typename T_w, typename T_v> | ||||||||||||||||||||||||||||||
inline auto wiener_prob(const T_a& a, const T_v& v_value, const T_w& w_value) { | ||||||||||||||||||||||||||||||
using ret_t = return_type_t<T_a, T_w, T_v>; | ||||||||||||||||||||||||||||||
const auto v = -v_value; | ||||||||||||||||||||||||||||||
const auto w = 1 - w_value; | ||||||||||||||||||||||||||||||
if (fabs(v) == 0.0) { | ||||||||||||||||||||||||||||||
return ret_t(log1p(-w)); | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
const auto exponent = -2.0 * v * a * (1.0 - w); | ||||||||||||||||||||||||||||||
if (exponent < 0) { | ||||||||||||||||||||||||||||||
return ret_t(log1m_exp(exponent) - log_diff_exp(2 * v * a * w, exponent)); | ||||||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||||||
return ret_t(log1m_exp(-exponent) - log1m_exp(2 * v * a)); | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
/** | ||||||||||||||||||||||||||||||
* Calculate parts of the partial derivatives for wiener_prob_grad_a and | ||||||||||||||||||||||||||||||
* wiener_prob_grad_v (on log-scale) | ||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||
* @tparam T_a type of boundary | ||||||||||||||||||||||||||||||
* @tparam T_w type of relative starting point | ||||||||||||||||||||||||||||||
* @tparam T_v type of drift rate | ||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||
* @param a The boundary separation | ||||||||||||||||||||||||||||||
* @param w_value The relative starting point | ||||||||||||||||||||||||||||||
* @param v_value The drift rate | ||||||||||||||||||||||||||||||
* @return 'ans' term | ||||||||||||||||||||||||||||||
*/ | ||||||||||||||||||||||||||||||
template <typename T_a, typename T_w, typename T_v> | ||||||||||||||||||||||||||||||
inline auto wiener_prob_derivative_term(const T_a& a, const T_v& v_value, | ||||||||||||||||||||||||||||||
const T_w& w_value) noexcept { | ||||||||||||||||||||||||||||||
using ret_t = return_type_t<T_a, T_w, T_v>; | ||||||||||||||||||||||||||||||
const auto exponent_m1 = log1p(-1.1 * 1.0e-8); | ||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Where does this hard coded value come from? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This hard coded value is connected to the internal precision of this computation. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does it have 1e-8 precision? I'm asking where that number comes from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here, I have to correct myself. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Going to have to tag @bob-carpenter here as I'm unsure of how we usually handle things like this. Bob the value
Sorry where in this code is the divide happening? Everything is on the log scale here so division is just turning into subtraction |
||||||||||||||||||||||||||||||
ret_t ans; | ||||||||||||||||||||||||||||||
const auto v = -v_value; | ||||||||||||||||||||||||||||||
const auto w = 1 - w_value; | ||||||||||||||||||||||||||||||
int sign_v = v < 0 ? 1 : -1; | ||||||||||||||||||||||||||||||
const auto exponent_with_1mw = sign_v * 2.0 * v * a * (1.0 - w); | ||||||||||||||||||||||||||||||
const auto exponent = (sign_v * 2 * a * v); | ||||||||||||||||||||||||||||||
const auto exponent_with_w = 2 * a * v * w; | ||||||||||||||||||||||||||||||
if (unlikely((exponent_with_1mw >= exponent_m1) | ||||||||||||||||||||||||||||||
|| ((exponent_with_w >= exponent_m1) && (sign_v == 1)) | ||||||||||||||||||||||||||||||
|| (exponent >= exponent_m1) || v == 0)) { | ||||||||||||||||||||||||||||||
return ret_t(-w); | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
ret_t diff_term; | ||||||||||||||||||||||||||||||
const auto log_w = log(w); | ||||||||||||||||||||||||||||||
if (v < 0) { | ||||||||||||||||||||||||||||||
ans = LOG_TWO + exponent_with_1mw - log1m_exp(exponent_with_1mw); | ||||||||||||||||||||||||||||||
diff_term = log1m_exp(exponent_with_w) - log1m_exp(exponent); | ||||||||||||||||||||||||||||||
} else if (v > 0) { | ||||||||||||||||||||||||||||||
ans = LOG_TWO - log1m_exp(exponent_with_1mw); | ||||||||||||||||||||||||||||||
diff_term = log_diff_exp(exponent_with_1mw, exponent) - log1m_exp(exponent); | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
if (log_w > diff_term) { | ||||||||||||||||||||||||||||||
ans += log_diff_exp(log_w, diff_term); | ||||||||||||||||||||||||||||||
ans = sign_v * exp(ans); | ||||||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||||||
ans += log_diff_exp(diff_term, log_w); | ||||||||||||||||||||||||||||||
ans = -sign_v * exp(ans); | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
if (unlikely(!is_scal_finite(ans))) { | ||||||||||||||||||||||||||||||
return ret_t(NEGATIVE_INFTY); | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
return ans; | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
/** | ||||||||||||||||||||||||||||||
* Calculate wiener4 ccdf (natural-scale) | ||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||
* @param y A scalar variable; the reaction time in seconds | ||||||||||||||||||||||||||||||
* @param a The boundary separation | ||||||||||||||||||||||||||||||
* @param v The relative starting point | ||||||||||||||||||||||||||||||
* @param w The drift rate | ||||||||||||||||||||||||||||||
* @param wildcard This parameter space is needed for a functor. Could be | ||||||||||||||||||||||||||||||
* deleted when another solution is found | ||||||||||||||||||||||||||||||
* @param err The log error tolerance | ||||||||||||||||||||||||||||||
* @return ccdf | ||||||||||||||||||||||||||||||
*/ | ||||||||||||||||||||||||||||||
template <typename T_y, typename T_a, typename T_w, typename T_v, | ||||||||||||||||||||||||||||||
typename T_wildcard, typename T_err> | ||||||||||||||||||||||||||||||
inline auto wiener4_ccdf(const T_y& y, const T_a& a, const T_v& v, const T_w& w, | ||||||||||||||||||||||||||||||
T_wildcard&& wildcard = 0.0, | ||||||||||||||||||||||||||||||
T_err&& err = log(1e-12)) noexcept { | ||||||||||||||||||||||||||||||
const auto prob = exp(wiener_prob(a, v, w)); | ||||||||||||||||||||||||||||||
const auto cdf | ||||||||||||||||||||||||||||||
= internal::wiener4_distribution<GradientCalc::ON>(y, a, v, w, 0, err); | ||||||||||||||||||||||||||||||
return prob - cdf; | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
/** | ||||||||||||||||||||||||||||||
* Calculate derivative of the wiener4 ccdf w.r.t. 'a' (natural-scale) | ||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||
* @param y A scalar variable; the reaction time in seconds | ||||||||||||||||||||||||||||||
* @param a The boundary separation | ||||||||||||||||||||||||||||||
* @param v The relative starting point | ||||||||||||||||||||||||||||||
* @param w The drift rate | ||||||||||||||||||||||||||||||
* @param cdf The CDF value | ||||||||||||||||||||||||||||||
* @param err The log error tolerance | ||||||||||||||||||||||||||||||
* @return Gradient w.r.t. a | ||||||||||||||||||||||||||||||
*/ | ||||||||||||||||||||||||||||||
template <typename T_y, typename T_a, typename T_w, typename T_v, | ||||||||||||||||||||||||||||||
typename T_cdf, typename T_err> | ||||||||||||||||||||||||||||||
inline auto wiener4_ccdf_grad_a(const T_y& y, const T_a& a, const T_v& v, | ||||||||||||||||||||||||||||||
const T_w& w, T_cdf&& cdf, | ||||||||||||||||||||||||||||||
T_err&& err = log(1e-12)) noexcept { | ||||||||||||||||||||||||||||||
using ret_t = return_type_t<T_a, T_w, T_v>; | ||||||||||||||||||||||||||||||
const auto prob = wiener_prob(a, v, w); | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
// derivative of the wiener probability w.r.t. 'a' (on log-scale) | ||||||||||||||||||||||||||||||
auto prob_grad_a = -1 * wiener_prob_derivative_term(a, v, w) * v; | ||||||||||||||||||||||||||||||
if (!is_scal_finite(prob_grad_a)) { | ||||||||||||||||||||||||||||||
prob_grad_a = ret_t(NEGATIVE_INFTY); | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
const auto cdf_grad_a = wiener4_cdf_grad_a(y, a, v, w, cdf, err); | ||||||||||||||||||||||||||||||
return prob_grad_a * exp(prob) - cdf_grad_a; | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
/** | ||||||||||||||||||||||||||||||
* Calculate derivative of the wiener4 ccdf w.r.t. 'v' (natural-scale) | ||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||
* @param y A scalar variable; the reaction time in seconds | ||||||||||||||||||||||||||||||
* @param a The boundary separation | ||||||||||||||||||||||||||||||
* @param v The relative starting point | ||||||||||||||||||||||||||||||
* @param w The drift rate | ||||||||||||||||||||||||||||||
* @param cdf The CDF value | ||||||||||||||||||||||||||||||
* @param err The log error tolerance | ||||||||||||||||||||||||||||||
* @return Gradient w.r.t. v | ||||||||||||||||||||||||||||||
*/ | ||||||||||||||||||||||||||||||
template <typename T_y, typename T_a, typename T_w, typename T_v, | ||||||||||||||||||||||||||||||
typename T_cdf, typename T_err> | ||||||||||||||||||||||||||||||
inline auto wiener4_ccdf_grad_v(const T_y& y, const T_a& a, const T_v& v, | ||||||||||||||||||||||||||||||
const T_w& w, T_cdf&& cdf, | ||||||||||||||||||||||||||||||
T_err&& err = log(1e-12)) noexcept { | ||||||||||||||||||||||||||||||
using ret_t = return_type_t<T_a, T_w, T_v>; | ||||||||||||||||||||||||||||||
const auto prob | ||||||||||||||||||||||||||||||
= wiener_prob(a, v, w); // maybe hand over to this function, but then | ||||||||||||||||||||||||||||||
// wiener7_integrate_cdf has problems | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
// derivative of the wiener probability w.r.t. 'v' (on log-scale) | ||||||||||||||||||||||||||||||
auto prob_grad_v = -1 * wiener_prob_derivative_term(a, v, w) * a; | ||||||||||||||||||||||||||||||
if (fabs(prob_grad_v) == INFTY) { | ||||||||||||||||||||||||||||||
prob_grad_v = ret_t(NEGATIVE_INFTY); | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
const auto cdf_grad_v = wiener4_cdf_grad_v(y, a, v, w, cdf, err); | ||||||||||||||||||||||||||||||
return prob_grad_v * exp(prob) - cdf_grad_v; | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
/** | ||||||||||||||||||||||||||||||
* Calculate derivative of the wiener4 ccdf w.r.t. 'w' (natural-scale) | ||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||
* @param y A scalar variable; the reaction time in seconds | ||||||||||||||||||||||||||||||
* @param a The boundary separation | ||||||||||||||||||||||||||||||
* @param v The relative starting point | ||||||||||||||||||||||||||||||
* @param w The drift rate | ||||||||||||||||||||||||||||||
* @param cdf The CDF value | ||||||||||||||||||||||||||||||
* @param err The log error tolerance | ||||||||||||||||||||||||||||||
* @return Gradient w.r.t. w | ||||||||||||||||||||||||||||||
*/ | ||||||||||||||||||||||||||||||
template <typename T_y, typename T_a, typename T_w, typename T_v, | ||||||||||||||||||||||||||||||
typename T_cdf, typename T_err> | ||||||||||||||||||||||||||||||
inline auto wiener4_ccdf_grad_w(const T_y& y, const T_a& a, const T_v& v, | ||||||||||||||||||||||||||||||
const T_w& w, T_cdf&& cdf, | ||||||||||||||||||||||||||||||
T_err&& err = log(1e-12)) noexcept { | ||||||||||||||||||||||||||||||
using ret_t = return_type_t<T_a, T_w, T_v>; | ||||||||||||||||||||||||||||||
const auto prob | ||||||||||||||||||||||||||||||
= wiener_prob(a, v, w); // maybe hand over to this function, but then | ||||||||||||||||||||||||||||||
// wiener7_integrate_cdf has problems | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
// derivative of the wiener probability w.r.t. 'v' (on log-scale) | ||||||||||||||||||||||||||||||
auto prob_grad_w = ret_t(1 / w); | ||||||||||||||||||||||||||||||
if (v > 0) { | ||||||||||||||||||||||||||||||
const auto exponent = -2.0 * v * a * w; | ||||||||||||||||||||||||||||||
prob_grad_w | ||||||||||||||||||||||||||||||
= exp(LOG_TWO + exponent + log(fabs(v)) + log(a) - log1m_exp(exponent)); | ||||||||||||||||||||||||||||||
} else if (v < 0) { | ||||||||||||||||||||||||||||||
const auto exponent = 2.0 * v * a * w; | ||||||||||||||||||||||||||||||
prob_grad_w = exp(LOG_TWO + log(fabs(v)) + log(a) - log1m_exp(exponent)); | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can this just be rewritten as
Suggested change
Also for places with if statements like this can you make a comment on why this split has to happen? |
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
const auto cdf_grad_w = wiener4_cdf_grad_w(y, a, v, w, cdf, err); | ||||||||||||||||||||||||||||||
return prob_grad_w * exp(prob) - cdf_grad_w; | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
} // namespace internal | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
/** | ||||||||||||||||||||||||||||||
* Log-CCDF for the 4-parameter Wiener distribution. | ||||||||||||||||||||||||||||||
* See 'wiener_full_lpdf' for more comprehensive documentation | ||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||
* @tparam T_y type of scalar | ||||||||||||||||||||||||||||||
* @tparam T_a type of boundary | ||||||||||||||||||||||||||||||
* @tparam T_t0 type of non-decision time | ||||||||||||||||||||||||||||||
* @tparam T_w type of relative starting point | ||||||||||||||||||||||||||||||
* @tparam T_v type of drift rate | ||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||
* @param y A scalar variable; the reaction time in seconds | ||||||||||||||||||||||||||||||
* @param a The boundary separation | ||||||||||||||||||||||||||||||
* @param t0 The non-decision time | ||||||||||||||||||||||||||||||
* @param w The relative starting point | ||||||||||||||||||||||||||||||
* @param v The drift rate | ||||||||||||||||||||||||||||||
* @param precision_derivatives Level of precision in estimation | ||||||||||||||||||||||||||||||
* @return The log of the Wiener first passage time distribution with | ||||||||||||||||||||||||||||||
* the specified arguments for upper boundary responses | ||||||||||||||||||||||||||||||
*/ | ||||||||||||||||||||||||||||||
template <bool propto = false, typename T_y, typename T_a, typename T_t0, | ||||||||||||||||||||||||||||||
typename T_w, typename T_v> | ||||||||||||||||||||||||||||||
inline auto wiener_lccdf(const T_y& y, const T_a& a, const T_t0& t0, | ||||||||||||||||||||||||||||||
const T_w& w, const T_v& v, | ||||||||||||||||||||||||||||||
const double& precision_derivatives) { | ||||||||||||||||||||||||||||||
using T_partials_return = partials_return_t<T_y, T_a, T_t0, T_w, T_v>; | ||||||||||||||||||||||||||||||
using ret_t = return_type_t<T_y, T_a, T_t0, T_w, T_v>; | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
if (!include_summand<propto, T_y, T_a, T_t0, T_w, T_v>::value) { | ||||||||||||||||||||||||||||||
return ret_t(0.0); | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
using T_y_ref = ref_type_if_t<!is_constant<T_y>::value, T_y>; | ||||||||||||||||||||||||||||||
using T_a_ref = ref_type_if_t<!is_constant<T_a>::value, T_a>; | ||||||||||||||||||||||||||||||
using T_t0_ref = ref_type_if_t<!is_constant<T_t0>::value, T_t0>; | ||||||||||||||||||||||||||||||
using T_w_ref = ref_type_if_t<!is_constant<T_w>::value, T_w>; | ||||||||||||||||||||||||||||||
using T_v_ref = ref_type_if_t<!is_constant<T_v>::value, T_v>; | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
static constexpr const char* function_name = "wiener4_lccdf"; | ||||||||||||||||||||||||||||||
if (size_zero(y, a, t0, w, v)) { | ||||||||||||||||||||||||||||||
return ret_t(0.0); | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
check_consistent_sizes(function_name, "Random variable", y, | ||||||||||||||||||||||||||||||
"Boundary separation", a, "Drift rate", v, | ||||||||||||||||||||||||||||||
"A-priori bias", w, "Nondecision time", t0); | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
T_y_ref y_ref = y; | ||||||||||||||||||||||||||||||
T_a_ref a_ref = a; | ||||||||||||||||||||||||||||||
T_t0_ref t0_ref = t0; | ||||||||||||||||||||||||||||||
T_w_ref w_ref = w; | ||||||||||||||||||||||||||||||
T_v_ref v_ref = v; | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
decltype(auto) y_val = to_ref(as_value_column_array_or_scalar(y_ref)); | ||||||||||||||||||||||||||||||
decltype(auto) a_val = to_ref(as_value_column_array_or_scalar(a_ref)); | ||||||||||||||||||||||||||||||
decltype(auto) v_val = to_ref(as_value_column_array_or_scalar(v_ref)); | ||||||||||||||||||||||||||||||
decltype(auto) w_val = to_ref(as_value_column_array_or_scalar(w_ref)); | ||||||||||||||||||||||||||||||
decltype(auto) t0_val = to_ref(as_value_column_array_or_scalar(t0_ref)); | ||||||||||||||||||||||||||||||
check_positive_finite(function_name, "Random variable", y_val); | ||||||||||||||||||||||||||||||
check_positive_finite(function_name, "Boundary separation", a_val); | ||||||||||||||||||||||||||||||
check_finite(function_name, "Drift rate", v_val); | ||||||||||||||||||||||||||||||
check_less(function_name, "A-priori bias", w_val, 1); | ||||||||||||||||||||||||||||||
check_greater(function_name, "A-priori bias", w_val, 0); | ||||||||||||||||||||||||||||||
check_nonnegative(function_name, "Nondecision time", t0_val); | ||||||||||||||||||||||||||||||
check_finite(function_name, "Nondecision time", t0_val); | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
const size_t N = max_size(y, a, t0, w, v); | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
scalar_seq_view<T_y_ref> y_vec(y_ref); | ||||||||||||||||||||||||||||||
scalar_seq_view<T_a_ref> a_vec(a_ref); | ||||||||||||||||||||||||||||||
scalar_seq_view<T_t0_ref> t0_vec(t0_ref); | ||||||||||||||||||||||||||||||
scalar_seq_view<T_w_ref> w_vec(w_ref); | ||||||||||||||||||||||||||||||
scalar_seq_view<T_v_ref> v_vec(v_ref); | ||||||||||||||||||||||||||||||
const size_t N_y_t0 = max_size(y, t0); | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
for (size_t i = 0; i < N_y_t0; ++i) { | ||||||||||||||||||||||||||||||
if (y_vec[i] <= t0_vec[i]) { | ||||||||||||||||||||||||||||||
std::stringstream msg; | ||||||||||||||||||||||||||||||
msg << ", but must be greater than nondecision time = " << t0_vec[i]; | ||||||||||||||||||||||||||||||
std::string msg_str(msg.str()); | ||||||||||||||||||||||||||||||
throw_domain_error(function_name, "Random variable", y_vec[i], " = ", | ||||||||||||||||||||||||||||||
msg_str.c_str()); | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
const auto log_error_cdf = log(1e-6); | ||||||||||||||||||||||||||||||
const auto log_error_derivative = log(precision_derivatives); | ||||||||||||||||||||||||||||||
const T_partials_return log_error_absolute = log(1e-12); | ||||||||||||||||||||||||||||||
T_partials_return lccdf = 0.0; | ||||||||||||||||||||||||||||||
auto ops_partials | ||||||||||||||||||||||||||||||
= make_partials_propagator(y_ref, a_ref, t0_ref, w_ref, v_ref); | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
static constexpr double LOG_FOUR = LOG_TWO + LOG_TWO; | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
// calculate distribution and partials | ||||||||||||||||||||||||||||||
for (size_t i = 0; i < N; i++) { | ||||||||||||||||||||||||||||||
const auto y_value = y_vec.val(i); | ||||||||||||||||||||||||||||||
const auto a_value = a_vec.val(i); | ||||||||||||||||||||||||||||||
const auto t0_value = t0_vec.val(i); | ||||||||||||||||||||||||||||||
const auto w_value = w_vec.val(i); | ||||||||||||||||||||||||||||||
const auto v_value = v_vec.val(i); | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
using internal::GradientCalc; | ||||||||||||||||||||||||||||||
const T_partials_return cdf | ||||||||||||||||||||||||||||||
= internal::estimate_with_err_check<5, 0, GradientCalc::OFF, | ||||||||||||||||||||||||||||||
GradientCalc::OFF>( | ||||||||||||||||||||||||||||||
[](auto&&... args) { | ||||||||||||||||||||||||||||||
return internal::wiener4_distribution<GradientCalc::ON>(args...); | ||||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||||
log_error_cdf - LOG_TWO, y_value - t0_value, a_value, v_value, | ||||||||||||||||||||||||||||||
w_value, 0.0, log_error_absolute); | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
const auto prob = exp(internal::wiener_prob(a_value, v_value, w_value)); | ||||||||||||||||||||||||||||||
const auto ccdf = prob - cdf; | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
lccdf += log(ccdf); | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
const auto new_est_err = log(ccdf) + log_error_derivative - LOG_FOUR; | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
if (!is_constant_all<T_y>::value || !is_constant_all<T_t0>::value) { | ||||||||||||||||||||||||||||||
const auto deriv_y = internal::estimate_with_err_check<5, 0>( | ||||||||||||||||||||||||||||||
[](auto&&... args) { | ||||||||||||||||||||||||||||||
return internal::wiener5_density<GradientCalc::ON>(args...); | ||||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||||
new_est_err, y_value - t0_value, a_value, v_value, w_value, 0.0, | ||||||||||||||||||||||||||||||
log_error_absolute); | ||||||||||||||||||||||||||||||
if (!is_constant_all<T_y>::value) { | ||||||||||||||||||||||||||||||
partials<0>(ops_partials)[i] = -deriv_y / ccdf; | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
if (!is_constant_all<T_t0>::value) { | ||||||||||||||||||||||||||||||
partials<2>(ops_partials)[i] = deriv_y / ccdf; | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
if (!is_constant_all<T_a>::value) { | ||||||||||||||||||||||||||||||
partials<1>(ops_partials)[i] | ||||||||||||||||||||||||||||||
= internal::estimate_with_err_check<5, 0>( | ||||||||||||||||||||||||||||||
[](auto&&... args) { | ||||||||||||||||||||||||||||||
return internal::wiener4_ccdf_grad_a(args...); | ||||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||||
new_est_err, y_value - t0_value, a_value, v_value, w_value, cdf, | ||||||||||||||||||||||||||||||
log_error_absolute) | ||||||||||||||||||||||||||||||
/ ccdf; | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
if (!is_constant_all<T_w>::value) { | ||||||||||||||||||||||||||||||
partials<3>(ops_partials)[i] | ||||||||||||||||||||||||||||||
= internal::estimate_with_err_check<5, 0>( | ||||||||||||||||||||||||||||||
[](auto&&... args) { | ||||||||||||||||||||||||||||||
return internal::wiener4_ccdf_grad_w(args...); | ||||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||||
new_est_err, y_value - t0_value, a_value, v_value, w_value, cdf, | ||||||||||||||||||||||||||||||
log_error_absolute) | ||||||||||||||||||||||||||||||
/ ccdf; | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
if (!is_constant_all<T_v>::value) { | ||||||||||||||||||||||||||||||
partials<4>(ops_partials)[i] | ||||||||||||||||||||||||||||||
= internal::wiener4_ccdf_grad_v(y_value - t0_value, a_value, v_value, | ||||||||||||||||||||||||||||||
w_value, cdf, log_error_absolute) | ||||||||||||||||||||||||||||||
/ ccdf; | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
} // for loop | ||||||||||||||||||||||||||||||
return ops_partials.build(lccdf); | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
} // namespace math | ||||||||||||||||||||||||||||||
} // namespace stan | ||||||||||||||||||||||||||||||
#endif |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm kind of confused by these cutpoints. Is this because the derivative is ill defined at certain areas or is this a math optimization
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this is a math optimization and should stay.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The if statement here could be more more expensive than the ops that are saved. I'd remove all of these and just keep things simple. It also just becomes really hard to read and maintain with all of these if statements in the code
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Results are more robust when we have this case distinction. They both shall compute the same result, but when
exponent < 0
then the upper case is more robust and whenexponent >=0
the lower case is more robust. We could insert a comment on this to make the case distinction clear.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you mean by robust here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Numerically robust.