Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add beta_neg_binomial_lccdf #3114

Merged
135 changes: 80 additions & 55 deletions stan/math/prim/fun/grad_F32.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,20 @@ namespace math {
* This power-series representation converges for all gradients
* under the same conditions as the 3F2 function itself.
*
* @tparam T type of arguments and result
* @tparam grad_a1 boolean indicating if gradient with respect to a1 is required
* @tparam grad_a2 boolean indicating if gradient with respect to a2 is required
* @tparam grad_a3 boolean indicating if gradient with respect to a3 is required
* @tparam grad_b1 boolean indicating if gradient with respect to b1 is required
* @tparam grad_b2 boolean indicating if gradient with respect to b2 is required
* @tparam grad_z boolean indicating if gradient with respect to z is required
* @tparam T1 a scalar type
* @tparam T2 a scalar type
* @tparam T3 a scalar type
* @tparam T4 a scalar type
* @tparam T5 a scalar type
* @tparam T6 a scalar type
* @tparam T7 a scalar type
* @tparam T8 a scalar type
 * @param[out] g pointer to array of six values of type T1 holding the result.
 * @param[in] a1 see generalized hypergeometric function definition.
 * @param[in] a2 see generalized hypergeometric function definition.
Expand All @@ -35,84 +48,96 @@ namespace math {
* @param[in] precision precision of the infinite sum
* @param[in] max_steps number of steps to take
*/
template <typename T>
void grad_F32(T* g, const T& a1, const T& a2, const T& a3, const T& b1,
const T& b2, const T& z, const T& precision = 1e-6,
template <bool grad_a1 = true, bool grad_a2 = true, bool grad_a3 = true,
bool grad_b1 = true, bool grad_b2 = true, bool grad_z = true,
typename T1, typename T2, typename T3, typename T4, typename T5,
typename T6, typename T7, typename T8 = double>
void grad_F32(T1* g, const T2& a1, const T3& a2, const T4& a3, const T5& b1,
const T6& b2, const T7& z, const T8& precision = 1e-6,
int max_steps = 1e5) {
check_3F2_converges("grad_F32", a1, a2, a3, b1, b2, z);

using std::exp;
using std::fabs;
using std::log;

for (int i = 0; i < 6; ++i) {
g[i] = 0.0;
}

T log_g_old[6];
T1 log_g_old[6];
for (auto& x : log_g_old) {
x = NEGATIVE_INFTY;
}

T log_t_old = 0.0;
T log_t_new = 0.0;
T1 log_t_old = 0.0;
T1 log_t_new = 0.0;

T log_z = log(z);
T7 log_z = log(z);

double log_t_new_sign = 1.0;
double log_t_old_sign = 1.0;
double log_g_old_sign[6];
T1 log_t_new_sign = 1.0;
T1 log_t_old_sign = 1.0;
T1 log_g_old_sign[6];
for (int i = 0; i < 6; ++i) {
log_g_old_sign[i] = 1.0;
}

std::array<T1, 6> term{0};
for (int k = 0; k <= max_steps; ++k) {
T p = (a1 + k) * (a2 + k) * (a3 + k) / ((b1 + k) * (b2 + k) * (1 + k));
T1 p = (a1 + k) * (a2 + k) * (a3 + k) / ((b1 + k) * (b2 + k) * (1 + k));
if (p == 0) {
return;
}

log_t_new += log(fabs(p)) + log_z;
log_t_new_sign = p >= 0.0 ? log_t_new_sign : -log_t_new_sign;
if constexpr (grad_a1) {
term[0]
= log_g_old_sign[0] * log_t_old_sign * exp(log_g_old[0] - log_t_old)
+ inv(a1 + k);
log_g_old[0] = log_t_new + log(fabs(term[0]));
log_g_old_sign[0] = term[0] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
g[0] += log_g_old_sign[0] * exp(log_g_old[0]);
}

if constexpr (grad_a2) {
term[1]
= log_g_old_sign[1] * log_t_old_sign * exp(log_g_old[1] - log_t_old)
+ inv(a2 + k);
log_g_old[1] = log_t_new + log(fabs(term[1]));
log_g_old_sign[1] = term[1] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
g[1] += log_g_old_sign[1] * exp(log_g_old[1]);
}

if constexpr (grad_a3) {
term[2]
= log_g_old_sign[2] * log_t_old_sign * exp(log_g_old[2] - log_t_old)
+ inv(a3 + k);
log_g_old[2] = log_t_new + log(fabs(term[2]));
log_g_old_sign[2] = term[2] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
g[2] += log_g_old_sign[2] * exp(log_g_old[2]);
}

if constexpr (grad_b1) {
term[3]
= log_g_old_sign[3] * log_t_old_sign * exp(log_g_old[3] - log_t_old)
- inv(b1 + k);
log_g_old[3] = log_t_new + log(fabs(term[3]));
log_g_old_sign[3] = term[3] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
g[3] += log_g_old_sign[3] * exp(log_g_old[3]);
}

if constexpr (grad_b2) {
term[4]
= log_g_old_sign[4] * log_t_old_sign * exp(log_g_old[4] - log_t_old)
- inv(b2 + k);
log_g_old[4] = log_t_new + log(fabs(term[4]));
log_g_old_sign[4] = term[4] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
g[4] += log_g_old_sign[4] * exp(log_g_old[4]);
}

// g_old[0] = t_new * (g_old[0] / t_old + 1.0 / (a1 + k));
T term = log_g_old_sign[0] * log_t_old_sign * exp(log_g_old[0] - log_t_old)
+ inv(a1 + k);
log_g_old[0] = log_t_new + log(fabs(term));
log_g_old_sign[0] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;

// g_old[1] = t_new * (g_old[1] / t_old + 1.0 / (a2 + k));
term = log_g_old_sign[1] * log_t_old_sign * exp(log_g_old[1] - log_t_old)
+ inv(a2 + k);
log_g_old[1] = log_t_new + log(fabs(term));
log_g_old_sign[1] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;

// g_old[2] = t_new * (g_old[2] / t_old + 1.0 / (a3 + k));
term = log_g_old_sign[2] * log_t_old_sign * exp(log_g_old[2] - log_t_old)
+ inv(a3 + k);
log_g_old[2] = log_t_new + log(fabs(term));
log_g_old_sign[2] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;

// g_old[3] = t_new * (g_old[3] / t_old - 1.0 / (b1 + k));
term = log_g_old_sign[3] * log_t_old_sign * exp(log_g_old[3] - log_t_old)
- inv(b1 + k);
log_g_old[3] = log_t_new + log(fabs(term));
log_g_old_sign[3] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;

// g_old[4] = t_new * (g_old[4] / t_old - 1.0 / (b2 + k));
term = log_g_old_sign[4] * log_t_old_sign * exp(log_g_old[4] - log_t_old)
- inv(b2 + k);
log_g_old[4] = log_t_new + log(fabs(term));
log_g_old_sign[4] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;

// g_old[5] = t_new * (g_old[5] / t_old + 1.0 / z);
term = log_g_old_sign[5] * log_t_old_sign * exp(log_g_old[5] - log_t_old)
+ inv(z);
log_g_old[5] = log_t_new + log(fabs(term));
log_g_old_sign[5] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;

for (int i = 0; i < 6; ++i) {
g[i] += log_g_old_sign[i] * exp(log_g_old[i]);
if constexpr (grad_z) {
term[5]
= log_g_old_sign[5] * log_t_old_sign * exp(log_g_old[5] - log_t_old)
+ inv(z);
log_g_old[5] = log_t_new + log(fabs(term[5]));
log_g_old_sign[5] = term[5] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
g[5] += log_g_old_sign[5] * exp(log_g_old[5]);
}

if (log_t_new <= log(precision)) {
Expand Down
9 changes: 5 additions & 4 deletions stan/math/prim/fun/grad_pFq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,11 @@ template <bool calc_a = true, bool calc_b = true, bool calc_z = true,
typename T_Rtn = return_type_t<Ta, Tb, Tz>,
typename Ta_Rtn = promote_scalar_t<T_Rtn, plain_type_t<Ta>>,
typename Tb_Rtn = promote_scalar_t<T_Rtn, plain_type_t<Tb>>>
std::tuple<Ta_Rtn, Tb_Rtn, T_Rtn> grad_pFq(const TpFq& pfq_val, const Ta& a,
const Tb& b, const Tz& z,
double precision = 1e-14,
int max_steps = 1e6) {
inline std::tuple<Ta_Rtn, Tb_Rtn, T_Rtn> grad_pFq(const TpFq& pfq_val,
const Ta& a, const Tb& b,
const Tz& z,
double precision = 1e-14,
int max_steps = 1e6) {
using std::max;
using Ta_Array = Eigen::Array<return_type_t<Ta>, -1, 1>;
using Tb_Array = Eigen::Array<return_type_t<Tb>, -1, 1>;
Expand Down
47 changes: 23 additions & 24 deletions stan/math/prim/fun/hypergeometric_3F2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,36 +17,34 @@ namespace stan {
namespace math {
namespace internal {
template <typename Ta, typename Tb, typename Tz,
typename T_return = return_type_t<Ta, Tb, Tz>,
typename ArrayAT = Eigen::Array<scalar_type_t<Ta>, 3, 1>,
typename ArrayBT = Eigen::Array<scalar_type_t<Ta>, 3, 1>,
require_all_vector_t<Ta, Tb>* = nullptr,
require_stan_scalar_t<Tz>* = nullptr>
T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
double precision = 1e-6,
int max_steps = 1e5) {
ArrayAT a_array = as_array_or_scalar(a);
ArrayBT b_array = append_row(as_array_or_scalar(b), 1.0);
inline return_type_t<Ta, Tb, Tz> hypergeometric_3F2_infsum(
const Ta& a, const Tb& b, const Tz& z, double precision = 1e-6,
int max_steps = 1e5) {
using T_return = return_type_t<Ta, Tb, Tz>;
Eigen::Array<scalar_type_t<Ta>, 3, 1> a_array = as_array_or_scalar(a);
Eigen::Array<scalar_type_t<Tb>, 3, 1> b_array
= append_row(as_array_or_scalar(b), 1.0);
check_3F2_converges("hypergeometric_3F2", a_array[0], a_array[1], a_array[2],
b_array[0], b_array[1], z);

T_return t_acc = 1.0;
T_return log_t = 0.0;
T_return log_z = log(fabs(z));
Eigen::ArrayXi a_signs = sign(value_of_rec(a_array));
Eigen::ArrayXi b_signs = sign(value_of_rec(b_array));
plain_type_t<decltype(a_array)> apk = a_array;
plain_type_t<decltype(b_array)> bpk = b_array;
auto log_z = log(fabs(z));
Eigen::Array<int, 3, 1> a_signs = sign(value_of_rec(a_array));
Eigen::Array<int, 3, 1> b_signs = sign(value_of_rec(b_array));
int z_sign = sign(value_of_rec(z));
int t_sign = z_sign * a_signs.prod() * b_signs.prod();

int k = 0;
while (k <= max_steps && log_t >= log(precision)) {
const double log_precision = log(precision);
while (k <= max_steps && log_t >= log_precision) {
// Replace zero values with 1 prior to taking the log so that we accumulate
// 0.0 rather than -inf
const auto& abs_apk = math::fabs((apk == 0).select(1.0, apk));
const auto& abs_bpk = math::fabs((bpk == 0).select(1.0, bpk));
T_return p = sum(log(abs_apk)) - sum(log(abs_bpk));
const auto& abs_apk = math::fabs((a_array == 0).select(1.0, a_array));
const auto& abs_bpk = math::fabs((b_array == 0).select(1.0, b_array));
auto p = sum(log(abs_apk)) - sum(log(abs_bpk));
if (p == NEGATIVE_INFTY) {
return t_acc;
}
Expand All @@ -59,10 +57,10 @@ T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
"overflow hypergeometric function did not converge.");
}
k++;
apk.array() += 1.0;
bpk.array() += 1.0;
a_signs = sign(value_of_rec(apk));
b_signs = sign(value_of_rec(bpk));
a_array += 1.0;
b_array += 1.0;
a_signs = sign(value_of_rec(a_array));
b_signs = sign(value_of_rec(b_array));
t_sign = a_signs.prod() * b_signs.prod() * t_sign;
}
if (k == max_steps) {
Expand Down Expand Up @@ -115,7 +113,7 @@ T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
template <typename Ta, typename Tb, typename Tz,
require_all_vector_t<Ta, Tb>* = nullptr,
require_stan_scalar_t<Tz>* = nullptr>
auto hypergeometric_3F2(const Ta& a, const Tb& b, const Tz& z) {
inline auto hypergeometric_3F2(const Ta& a, const Tb& b, const Tz& z) {
check_3F2_converges("hypergeometric_3F2", a[0], a[1], a[2], b[0], b[1], z);
// Boost's pFq throws convergence errors in some cases, fallback to naive
// infinite-sum approach (tests pass for these)
Expand Down Expand Up @@ -143,8 +141,9 @@ auto hypergeometric_3F2(const Ta& a, const Tb& b, const Tz& z) {
*/
template <typename Ta, typename Tb, typename Tz,
          require_all_stan_scalar_t<Ta, Tb, Tz>* = nullptr>
inline auto hypergeometric_3F2(const std::initializer_list<Ta>& a,
                               const std::initializer_list<Tb>& b,
                               const Tz& z) {
  // Convenience overload: materialize the brace-list arguments as vectors and
  // delegate to the vector overload, which performs the convergence checks
  // and the actual evaluation.
  return hypergeometric_3F2(std::vector<Ta>(a), std::vector<Tb>(b), z);
}

Expand Down
1 change: 1 addition & 0 deletions stan/math/prim/prob.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <stan/math/prim/prob/beta_lccdf.hpp>
#include <stan/math/prim/prob/beta_lcdf.hpp>
#include <stan/math/prim/prob/beta_lpdf.hpp>
#include <stan/math/prim/prob/beta_neg_binomial_lccdf.hpp>
#include <stan/math/prim/prob/beta_neg_binomial_lpmf.hpp>
#include <stan/math/prim/prob/beta_proportion_ccdf_log.hpp>
#include <stan/math/prim/prob/beta_proportion_cdf_log.hpp>
Expand Down
Loading