diff --git a/stan/math/prim/prob/multi_normal_cholesky_lpdf.hpp b/stan/math/prim/prob/multi_normal_cholesky_lpdf.hpp index bda51f702b6..47fecc0682e 100644 --- a/stan/math/prim/prob/multi_normal_cholesky_lpdf.hpp +++ b/stan/math/prim/prob/multi_normal_cholesky_lpdf.hpp @@ -53,6 +53,9 @@ return_type_t multi_normal_cholesky_lpdf( using T_partials_return = partials_return_t; using matrix_partials_t = Eigen::Matrix; + using vector_partials_t = Eigen::Matrix; + using row_vector_partials_t + = Eigen::Matrix; using T_y_ref = ref_type_t; using T_mu_ref = ref_type_t; using T_L_ref = ref_type_t; @@ -119,28 +122,39 @@ return_type_t multi_normal_cholesky_lpdf( } if (include_summand::value) { - Eigen::Matrix - y_val_minus_mu_val(size_y, size_vec); + row_vector_partials_t half(size_vec); + vector_partials_t y_val_minus_mu_val(size_vec); + vector_partials_t scaled_diff(size_vec); + matrix_partials_t L_val = value_of(L_ref); + + T_partials_return sum_lp_vec(0.0); for (size_t i = 0; i < size_vec; i++) { decltype(auto) y_val = as_value_column_vector_or_scalar(y_vec[i]); decltype(auto) mu_val = as_value_column_vector_or_scalar(mu_vec[i]); - y_val_minus_mu_val.col(i) = y_val - mu_val; + y_val_minus_mu_val = y_val - mu_val; + half = mdivide_left_tri(L_val, y_val_minus_mu_val) + .transpose(); + scaled_diff = mdivide_right_tri(half, L_val).transpose(); + + sum_lp_vec += dot_self(half); + + if (!is_constant_all::value) { + partials_vec<0>(ops_partials)[i] += -scaled_diff; + } + if (!is_constant_all::value) { + partials_vec<1>(ops_partials)[i] += scaled_diff; + } + if (!is_constant::value) { + partials_vec<2>(ops_partials)[i] += scaled_diff * half; + } } - matrix_partials_t half; - matrix_partials_t scaled_diff; + logp += -0.5 * sum_lp_vec; // If the covariance is not autodiff, we can avoid computing a matrix // inverse if (is_constant::value) { - matrix_partials_t L_val = value_of(L_ref); - - half = mdivide_left_tri(L_val, y_val_minus_mu_val) - .transpose(); - - scaled_diff = mdivide_right_tri(half, L_val).transpose(); - if (include_summand::value) { logp -= sum(log(L_val.diagonal())) * size_vec; } @@ -148,30 +162,9 @@ return_type_t multi_normal_cholesky_lpdf( matrix_partials_t inv_L_val = mdivide_left_tri(value_of(L_ref)); - half = (inv_L_val.template triangularView() - * y_val_minus_mu_val) - .transpose(); - - scaled_diff = (half * inv_L_val.template triangularView()) - .transpose(); - logp += sum(log(inv_L_val.diagonal())) * size_vec; - partials<2>(ops_partials) -= size_vec * inv_L_val.transpose(); - for (size_t i = 0; i < size_vec; i++) { - partials_vec<2>(ops_partials)[i] += scaled_diff.col(i) * half.row(i); - } - } - - logp -= 0.5 * sum(columns_dot_self(half)); - - for (size_t i = 0; i < size_vec; i++) { - if (!is_constant_all::value) { - partials_vec<0>(ops_partials)[i] -= scaled_diff.col(i); - } - if (!is_constant_all::value) { - partials_vec<1>(ops_partials)[i] += scaled_diff.col(i); - } + partials<2>(ops_partials) -= size_vec * inv_L_val.transpose(); } }