diff --git a/stan/math/fwd/fun.hpp b/stan/math/fwd/fun.hpp index 0237d1a8af4..e529fcb6cd3 100644 --- a/stan/math/fwd/fun.hpp +++ b/stan/math/fwd/fun.hpp @@ -38,6 +38,7 @@ #include #include #include +#include #include #include #include @@ -103,6 +104,7 @@ #include #include #include +#include #include #include #include diff --git a/stan/math/fwd/fun/abs.hpp b/stan/math/fwd/fun/abs.hpp index 1f50b149fd6..5817ca62795 100644 --- a/stan/math/fwd/fun/abs.hpp +++ b/stan/math/fwd/fun/abs.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include namespace stan { diff --git a/stan/math/fwd/fun/frexp.hpp b/stan/math/fwd/fun/frexp.hpp new file mode 100644 index 00000000000..aab3f108e28 --- /dev/null +++ b/stan/math/fwd/fun/frexp.hpp @@ -0,0 +1,17 @@ +#ifndef STAN_MATH_FWD_FUN_FREXP_HPP +#define STAN_MATH_FWD_FUN_FREXP_HPP + +#include +#include +#include + +namespace stan { +namespace math { + +template +inline auto frexp(const fvar& x, int* exponent) noexcept { + return std::frexp(value_of_rec(x), exponent); +} +} // namespace math +} // namespace stan +#endif diff --git a/stan/math/fwd/fun/sign.hpp b/stan/math/fwd/fun/sign.hpp new file mode 100644 index 00000000000..fcdfb1805a8 --- /dev/null +++ b/stan/math/fwd/fun/sign.hpp @@ -0,0 +1,19 @@ +#ifndef STAN_MATH_FWD_FUN_SIGN_HPP +#define STAN_MATH_FWD_FUN_SIGN_HPP + +#include +#include +#include + +namespace stan { +namespace math { + +template +inline auto sign(const fvar& x) { + double z = value_of_rec(x); + return (z == 0) ? 0 : z < 0 ? -1 : 1; +} + +} // namespace math +} // namespace stan +#endif diff --git a/stan/math/prim/functor.hpp b/stan/math/prim/functor.hpp index 0ec5c343ff7..47c7116d257 100644 --- a/stan/math/prim/functor.hpp +++ b/stan/math/prim/functor.hpp @@ -26,5 +26,5 @@ #include #include #include - +#include #endif diff --git a/stan/math/prim/functor/root_finder.hpp b/stan/math/prim/functor/root_finder.hpp new file mode 100644 index 00000000000..81e14456305 --- /dev/null +++ b/stan/math/prim/functor/root_finder.hpp @@ -0,0 +1,195 @@ +#ifndef STAN_MATH_PRIM_FUNCTOR_ROOT_FINDER_HPP +#define STAN_MATH_PRIM_FUNCTOR_ROOT_FINDER_HPP + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace stan { +namespace math { +namespace internal { +template * = nullptr> +inline auto make_root_func(Args&&... args) { + return [&args...](auto&& x) { + return std::decay_t::template run(x, args...); + }; +} + +template * = nullptr> +inline auto make_root_func() { + return [](auto&&... args) { + return std::decay_t::template run(args...); + }; +} + +struct NewtonRootSolver { + template + static inline auto run(Types&&... args) { + return boost::math::tools::newton_raphson_iterate( + std::forward(args)...); + } +}; + +struct HalleyRootSolver { + template + static inline auto run(Types&&... args) { + return boost::math::tools::halley_iterate(std::forward(args)...); + } +}; + +struct SchroderRootSolver { + template + static inline auto run(Types&&... args) { + return boost::math::tools::schroder_iterate(std::forward(args)...); + } +}; + +} // namespace internal + +/** + * Solve for root using Boost's Halley method + * @tparam FRootFunc A struct or class with a static function called `run`. + * The structs `run` function must have a boolean template parameter that + * when `true` returns a tuple containing the function result and the + * derivatives needed for the root finder. When the boolean template parameter + * is `false` the function should return a single value containing the function + * result. + * @tparam SolverFun One of the three struct types used to call the root solver. + * (`NewtonRootSolver`, `HalleyRootSolver`, `SchroderRootSolver`). + * @tparam GuessScalar Scalar type + * @tparam MinScalar Scalar type + * @tparam MaxScalar Scalar type + * @tparam Types Arg types to pass to functors in `f_tuple` + * @param guess An initial guess at the root value + * @param min The minimum possible value for the result, this is used as an + * initial lower bracket + * @param max The maximum possible value for the result, this is used as an + * initial upper bracket + * @param digits The desired number of binary digits precision + * @param max_iter An optional maximum number of iterations to perform. On exit, + * this is updated to the actual number of iterations performed + * @param args Parameter pack of arguments to pass the the functors in `f_tuple` + */ +template * = nullptr> +inline auto root_finder_tol(const GuessScalar guess, const MinScalar min, + const MaxScalar max, const int digits, + std::uintmax_t& max_iter, Types&&... args) { + check_bounded("root_finder", "initial guess", guess, min, max); + check_positive("root_finder", "digits", digits); + check_positive("root_finder", "max_iter", max_iter); + using ret_t = return_type_t; + ret_t ret = 0; + auto f_plus_div + = internal::make_root_func(std::forward(args)...); + try { + ret = std::decay_t::run(f_plus_div, ret_t(guess), ret_t(min), + ret_t(max), digits, max_iter); + } catch (const std::exception& e) { + throw e; + } + return ret; +} + +template +inline auto root_finder_halley_tol(const GuessScalar guess, const MinScalar min, + const MaxScalar max, const int digits, + std::uintmax_t& max_iter, Types&&... args) { + return root_finder_tol( + guess, min, max, digits, max_iter, std::forward(args)...); +} + +template +inline auto root_finder_newton_raphson_tol( + const GuessScalar guess, const MinScalar min, const MaxScalar max, + const int digits, std::uintmax_t& max_iter, Types&&... args) { + return root_finder_tol( + guess, min, max, digits, max_iter, std::forward(args)...); +} + +template +inline auto root_finder_schroder_tol(const GuessScalar guess, + const MinScalar min, const MaxScalar max, + const int digits, std::uintmax_t& max_iter, + Types&&... args) { + return root_finder_tol( + guess, min, max, digits, max_iter, std::forward(args)...); +} + +/** + * Solve for root with default values for the tolerances + * @tparam FRootFunc A struct or class with a static function called `run`. + * The structs `run` function must have a boolean template parameter that + * when `true` returns a tuple containing the function result and the + * derivatives needed for the root finder. When the boolean template parameter + * is `false` the function should return a single value containing the function + * result. + * @tparam SolverFun One of the three struct types used to call the root solver. + * (`NewtonRootSolver`, `HalleyRootSolver`, `SchroderRootSolver`). + * @tparam GuessScalar Scalar type + * @tparam MinScalar Scalar type + * @tparam MaxScalar Scalar type + * @tparam Types Arg types to pass to functors in `f_tuple` + * @param guess An initial guess at the root value + * @param min The minimum possible value for the result, this is used as an + * initial lower bracket + * @param max The maximum possible value for the result, this is used as an + * initial upper bracket + * @param args Parameter pack of arguments to pass the the functors in `f_tuple` + */ +template +inline auto root_finder(const GuessScalar guess, const MinScalar min, + const MaxScalar max, Types&&... args) { + constexpr int digits = 16; + std::uintmax_t max_iter = std::numeric_limits::max(); + return root_finder_tol( + guess, min, max, digits, max_iter, std::forward(args)...); +} + +template +inline auto root_finder_hailey(const GuessScalar guess, const MinScalar min, + const MaxScalar max, Types&&... args) { + constexpr int digits = 16; + std::uintmax_t max_iter = std::numeric_limits::max(); + return root_finder_halley_tol(guess, min, max, digits, max_iter, + std::forward(args)...); +} + +template +inline auto root_finder_newton_raphson(const GuessScalar guess, + const MinScalar min, const MaxScalar max, + Types&&... args) { + constexpr int digits = 16; + std::uintmax_t max_iter = std::numeric_limits::max(); + return root_finder_newton_raphson_tol( + guess, min, max, digits, max_iter, std::forward(args)...); +} + +template +inline auto root_finder_schroder(const GuessScalar guess, const MinScalar min, + const MaxScalar max, Types&&... args) { + constexpr int digits = 16; + std::uintmax_t max_iter = std::numeric_limits::max(); + return root_finder_schroder_tol(guess, min, max, digits, max_iter, + std::forward(args)...); +} + +} // namespace math +} // namespace stan +#endif diff --git a/stan/math/rev/fun.hpp b/stan/math/rev/fun.hpp index de47503372f..0cb81cf2903 100644 --- a/stan/math/rev/fun.hpp +++ b/stan/math/rev/fun.hpp @@ -67,6 +67,7 @@ #include #include #include +#include #include #include #include @@ -161,6 +162,7 @@ #include #include #include +#include #include #include #include diff --git a/stan/math/rev/fun/frexp.hpp b/stan/math/rev/fun/frexp.hpp new file mode 100644 index 00000000000..300e4a67914 --- /dev/null +++ b/stan/math/rev/fun/frexp.hpp @@ -0,0 +1,14 @@ +#ifndef STAN_MATH_REV_FUN_FREXP_HPP +#define STAN_MATH_REV_FUN_FREXP_HPP + +#include +#include + +namespace stan { +namespace math { +inline auto frexp(stan::math::var x, int* exponent) noexcept { + return std::frexp(x.val(), exponent); +} +} // namespace math +} // namespace stan +#endif diff --git a/stan/math/rev/fun/sign.hpp b/stan/math/rev/fun/sign.hpp new file mode 100644 index 00000000000..ad7e8ff3737 --- /dev/null +++ b/stan/math/rev/fun/sign.hpp @@ -0,0 +1,12 @@ +#ifndef STAN_MATH_REV_FUN_SIGN_HPP +#define STAN_MATH_REV_FUN_SIGN_HPP + +#include +#include + +namespace stan { +namespace math { +inline int sign(stan::math::var z) { return (z == 0) ? 0 : z < 0 ? -1 : 1; } +} // namespace math +} // namespace stan +#endif diff --git a/stan/math/rev/functor.hpp b/stan/math/rev/functor.hpp index d494bfa5c6b..5e55c3a3506 100644 --- a/stan/math/rev/functor.hpp +++ b/stan/math/rev/functor.hpp @@ -27,6 +27,7 @@ #include #include #include +#include #include #endif diff --git a/stan/math/rev/functor/root_finder.hpp b/stan/math/rev/functor/root_finder.hpp new file mode 100644 index 00000000000..7de5124df05 --- /dev/null +++ b/stan/math/rev/functor/root_finder.hpp @@ -0,0 +1,175 @@ +#ifndef STAN_MATH_REV_FUNCTOR_ROOT_FINDER_HPP +#define STAN_MATH_REV_FUNCTOR_ROOT_FINDER_HPP + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace stan { +namespace math { + +/** + * var specialization for root solving using Boost's Halley method + * @tparam FRootFunc A struct or class with a static function called `run`. + * The structs `run` function must have a boolean template parameter that + * when `true` returns a tuple containing the function result and the + * derivatives needed for the root finder. When the boolean template parameter + * is `false` the function should return a single value containing the function + * result. + * @tparam SolverFun One of the three struct types used to call the root solver. + * (`NewtonRootSolver`, `HalleyRootSolver`, `SchroderRootSolver`). + * @tparam GuessScalar Scalar type + * @tparam MinScalar Scalar type + * @tparam MaxScalar Scalar type + * @tparam Types Arg types to pass to functors in `f_tuple` + * @param guess An initial guess at the root value + * @param min The minimum possible value for the result, this is used as an + * initial lower bracket + * @param max The maximum possible value for the result, this is used as an + * initial upper bracket + * @param digits The desired number of binary digits precision + * @param max_iter An optional maximum number of iterations to perform. On exit, + * this is updated to the actual number of iterations performed + * @param args Parameter pack of arguments to pass the the functors in `f_tuple` + */ +template < + typename FRootFunc, typename SolverFun, typename GuessScalar, + typename MinScalar, typename MaxScalar, typename... Types, + require_any_st_var* = nullptr, + require_all_stan_scalar_t* = nullptr, + require_any_not_stan_scalar_t* = nullptr> +inline auto root_finder_tol(const GuessScalar guess, const MinScalar min, + const MaxScalar max, const int digits, + std::uintmax_t& max_iter, Types&&... args) { + check_bounded("root_finder", "initial guess", guess, min, max); + check_positive("root_finder", "digits", digits); + check_positive("root_finder", "max_iter", max_iter); + auto arena_args_tuple + = make_chainable_ptr(std::make_tuple(eval(std::forward(args))...)); + auto args_vals_tuple = apply( + [&](const auto&... args) { + return std::make_tuple(to_ref(value_of(args))...); + }, + *arena_args_tuple); + // Solve the system + double theta_dbl = apply( + [&max_iter, digits, guess_val = value_of(guess), min_val = value_of(min), + max_val = value_of(max)](auto&&... vals) { + return root_finder_tol( + guess_val, min_val, max_val, digits > 20 ? digits : 21, max_iter, + vals...); + }, + args_vals_tuple); + double Jf_x; + { + nested_rev_autodiff nested; + stan::math::var x_var(theta_dbl); + stan::math::var fx_var = apply( + [&x_var](auto&&... args) { + return std::decay_t::template run( + x_var, std::move(args)...); + }, + std::move(args_vals_tuple)); + fx_var.grad(); + Jf_x = x_var.adj(); + } + + /* + * Note: Because we put this on the callback stack, if `f` is a lambda + * its captures must be in Stan's arena memory or trivially destructable. + */ + return make_callback_var( + theta_dbl, [arena_args_tuple, Jf_x](auto& ret) mutable { + { + nested_rev_autodiff rev; + double eta = -(ret.adj() / Jf_x); + double ret_val = ret.val(); + auto x_nrad_ = apply( + [&ret_val](const auto&... args) { + auto f = internal::make_root_func(); + return eval(std::decay_t::template run( + ret_val, args...)); + }, + *arena_args_tuple); + x_nrad_.adj() = eta; + grad(); + } + }); +} + +/** + * var specialization for root solving using Boost's Halley method + * @tparam FRootFunc A struct or class with a static function called `run`. + * The structs `run` function must have a boolean template parameter that + * when `true` returns a tuple containing the function result and the + * derivatives needed for the root finder. When the boolean template parameter + * is `false` the function should return a single value containing the function + * result. + * @tparam SolverFun One of the three struct types used to call the root solver. + * (`NewtonRootSolver`, `HalleyRootSolver`, `SchroderRootSolver`). + * @tparam GuessScalar Scalar type + * @tparam MinScalar Scalar type + * @tparam MaxScalar Scalar type + * @tparam Types Arg types to pass to functors in `f_tuple` + * @param guess An initial guess at the root value + * @param min The minimum possible value for the result, this is used as an + * initial lower bracket + * @param max The maximum possible value for the result, this is used as an + * initial upper bracket + * @param digits The desired number of binary digits precision + * @param max_iter An optional maximum number of iterations to perform. On exit, + * this is updated to the actual number of iterations performed + * @param args Parameter pack of arguments to pass the the functors in `f_tuple` + */ +template < + typename FRootFunc, typename SolverFun, typename GuessScalar, + typename MinScalar, typename MaxScalar, typename... Types, + require_any_st_var* = nullptr, + require_all_stan_scalar_t* = nullptr, + require_all_stan_scalar_t* = nullptr> +inline auto root_finder_tol(const GuessScalar guess, const MinScalar min, + const MaxScalar max, const int digits, + std::uintmax_t& max_iter, Types... args) { + check_bounded("root_finder", "initial guess", guess, min, max); + check_positive("root_finder", "digits", digits); + check_positive("root_finder", "max_iter", max_iter); + // Solve the system + double theta_dbl = root_finder_tol( + value_of(guess), value_of(min), value_of(max), digits > 20 ? digits : 21, + max_iter, value_of(args)...); + double Jf_x; + { + nested_rev_autodiff nested; + stan::math::var x_var(theta_dbl); + stan::math::var fx_var = std::decay_t::template run( + x_var, value_of(args)...); + fx_var.grad(); + Jf_x = x_var.adj(); + } + + /* + * Note: Because we put this on the callback stack, if `f` is a lambda + * its captures must be in Stan's arena memory or trivially destructable. + */ + return make_callback_var(theta_dbl, [Jf_x, args...](auto& ret) mutable { + { + nested_rev_autodiff rev; + double eta = -(ret.adj() / Jf_x); + double ret_val = ret.val(); + auto f = internal::make_root_func(); + auto x_nrad + = std::decay_t::template run(ret_val, args...); + x_nrad.adj() = eta; + grad(); + } + }); +} + +} // namespace math +} // namespace stan +#endif diff --git a/test/unit/math/mix/functor/root_finder_test.cpp b/test/unit/math/mix/functor/root_finder_test.cpp new file mode 100644 index 00000000000..a87f491072e --- /dev/null +++ b/test/unit/math/mix/functor/root_finder_test.cpp @@ -0,0 +1,126 @@ +#include + +struct CubedRootFinder { + template * = nullptr> + static auto run(T1&& g, T2&& x) { + auto g_pow = g; + auto second_deriv = 20 * g_pow; + g_pow *= g; + auto first_deriv = 3 * g_pow; + g_pow *= g; + auto func_val = g_pow - x; + return std::make_tuple(func_val, first_deriv, second_deriv); + } + template * = nullptr> + static auto run(T1&& g, T2&& x) { + return g * g * g - x; + } +}; + +TEST(MixFun, root_finder_cubed) { + using stan::math::root_finder; + // return cube root of x using 1st and 2nd derivatives and Halley. + // using namespace std; // Help ADL of std functions. + double x = 27; + int exponent; + // Get exponent of z (ignore mantissa). + std::frexp(x, &exponent); + // Rough guess is to divide the exponent by three. + double guess = ldexp(1., exponent / 3); + // Minimum possible value is half our guess. + double min = ldexp(0.5, exponent / 3); + // Maximum possible value is twice our guess. + double max = ldexp(2., exponent / 3); + // Maximum possible binary digits accuracy for type T. + const int digits = std::numeric_limits::digits; + int get_digits = static_cast(digits * 0.4); + std::uintmax_t maxit = 20; + auto full_f = [guess, min, max](auto&& xx) { + return stan::math::root_finder_hailey(guess, min, max, xx); + }; + stan::test::expect_ad(full_f, x); +} + +struct FifthRootFinder { + template * = nullptr> + static auto run(T1&& g, T2&& x) { + auto g_pow = g * g * g; + auto second_deriv = 20 * g_pow; + g_pow *= g; + auto first_deriv = 5 * g_pow; + g_pow *= g; + auto func_val = g_pow - x; + return std::make_tuple(func_val, first_deriv, second_deriv); + } + template * = nullptr> + static auto run(T1&& g, T2&& x) { + return g * g * g * g * g - x; + } +}; + +TEST(MixFun, root_finder_fifth) { + using stan::math::root_finder; + // return cube root of x using 1st and 2nd derivatives and Halley. + // using namespace std; // Help ADL of std functions. + double x = 27; + int exponent; + // Get exponent of z (ignore mantissa). + std::frexp(x, &exponent); + // Rough guess is to divide the exponent by three. + double guess = ldexp(1., exponent / 5); + // Minimum possible value is half our guess. + double min = ldexp(0.5, exponent / 5); + // Maximum possible value is twice our guess. + double max = ldexp(2., exponent / 5); + // Maximum possible binary digits accuracy for type T. + const int digits = std::numeric_limits::digits; + int get_digits = static_cast(digits * 0.4); + std::uintmax_t maxit = 20; + auto f_hailey = [guess, min, max](auto&& x) { + return stan::math::root_finder_hailey(guess, min, max, x); + }; + stan::test::expect_ad(f_hailey, x); +} + +struct BetaCdfRoot { + template * = nullptr> + static auto run(T1&& x, T2&& alpha, T3&& beta, T4&& p) { + auto f_val = stan::math::inc_beta(alpha, beta, x) - p; + auto beta_ab = stan::math::beta(alpha, beta); + auto f_val_d = (stan::math::pow(1.0 - x, -1.0 + beta) + * stan::math::pow(x, -1.0 + alpha)) + / beta_ab; + auto f_val_dd = (stan::math::pow(1 - x, -1 + beta) + * stan::math::pow(x, -2 + alpha) * (-1 + alpha)) + / beta_ab + - (stan::math::pow(1 - x, -2 + beta) + * stan::math::pow(x, -1 + alpha) * (-1 + beta)) + / beta_ab; + + return std::make_tuple(f_val, f_val_d, f_val_dd); + } + template * = nullptr> + static auto run(T1&& x, T2&& alpha, T3&& beta, T4&& p) { + return stan::math::inc_beta(alpha, beta, x) - p; + } +}; + +TEST(MixFun, root_finder_beta) { + constexpr double guess = .5; + constexpr double min = 0; + constexpr double max = 1; + auto f_hailey = [](auto&& alpha, auto&& beta, auto&& p) { + return stan::math::root_finder_hailey(p, min, max, alpha, beta, + p); + }; + double alpha = .5; + double beta = .5; + double p = .3; + stan::test::expect_ad(f_hailey, alpha, beta, p); +} diff --git a/test/unit/math/rev/functor/root_finder_test.cpp b/test/unit/math/rev/functor/root_finder_test.cpp new file mode 100644 index 00000000000..971fc9eefd1 --- /dev/null +++ b/test/unit/math/rev/functor/root_finder_test.cpp @@ -0,0 +1,285 @@ +#include +#include +#include +#include +#include +#include +#include + +struct BetaCdfRoot { + template * = nullptr> + static auto run(T1&& x, T2&& alpha, T3&& beta, T4&& p) { + auto f_val = boost::math::ibeta(alpha, beta, x) - p; + auto beta_ab = boost::math::beta(alpha, beta); + double f_val_d + = (std::pow(1.0 - x, -1.0 + beta) * std::pow(x, -1.0 + alpha)) + / beta_ab; + double f_val_dd + = (std::pow(1 - x, -1 + beta) * std::pow(x, -2 + alpha) * (-1 + alpha)) + / beta_ab + - (std::pow(1 - x, -2 + beta) * std::pow(x, -1 + alpha) * (-1 + beta)) + / beta_ab; + + return std::make_tuple(f_val, f_val_d, f_val_dd); + } + template * = nullptr> + static auto run(T1&& x, T2&& alpha, T3&& beta, T4&& p) { + return stan::math::inc_beta(alpha, beta, x) - p; + } +}; +TEST(RevFunctor, root_finder_beta_cdf) { + using stan::math::var; + auto func = [](auto&& vals) { + auto p = vals(0); + auto alpha = vals(1); + auto beta = vals(2); + boost::math::beta_distribution my_beta(alpha, beta); + return boost::math::quantile(my_beta, p); + }; + double p = 0.4; + double a = 0.5; + double b = 0.5; + Eigen::VectorXd vals(3); + // p, alpha, beta + vals << p, a, b; + double fx = 0; + Eigen::VectorXd finit_grad_fx(3); + stan::math::finite_diff_gradient(func, vals, fx, finit_grad_fx, 1e-3); + std::cout << "--- Finit Diff----\n"; + std::cout << "fx: " << fx; + std::cout << "\ngrads: \n" + << "p: " << finit_grad_fx(0) + << "\n" + "alpha: " + << finit_grad_fx(1) + << "\n" + "beta: " + << finit_grad_fx(2) << "\n"; + double guess = .3; + double min = 0; + double max = 1; + auto full_f = [guess, min, max](auto&& alpha, auto&& beta, auto&& p) { + std::uintmax_t max_its = 1000; + return stan::math::root_finder_hailey(guess, min, max, alpha, + beta, p); + }; + auto func2 = [&full_f](auto&& vals) { + auto p = vals(0); + auto alpha = vals(1); + auto beta = vals(2); + return full_f(alpha, beta, p); + }; + Eigen::VectorXd grad_fx(3); + Eigen::Matrix var_vec(vals); + stan::math::var fxvar = func2(var_vec); + fxvar.grad(); + grad_fx = var_vec.adj(); + fx = fxvar.val(); + std::cout << "fxvar adj:" << fxvar.adj() << "\n"; + // stan::math::gradient(func2, vals, fx, grad_fx); + std::cout << "--- Auto Diff----\n"; + std::cout << "fx: " << fx; + std::cout << "\ngrads: \n" + << "p: " << grad_fx(0) + << "\n" + "alpha: " + << grad_fx(1) + << "\n" + "beta: " + << grad_fx(2) << "\n"; + Eigen::VectorXd diff_grad_fx = finit_grad_fx - grad_fx; + std::cout << "--- grad diffs----\n"; + std::cout << "p: " << diff_grad_fx(0) + << "\n" + "alpha: " + << diff_grad_fx(1) + << "\n" + "beta: " + << diff_grad_fx(2) << "\n"; + auto deriv_p = [](auto& p, auto& a, auto& b) { + using std::pow; + using boost::math::ibeta_inv; + using boost::math::beta; + return beta(a, b) * pow(1 - ibeta_inv(a, b, p), (1 - b)) + * pow(ibeta_inv(a, b, p), (1 - a)); + }; + auto deriv_a = [](auto& p, auto& a, auto& b) { + using std::pow; + using boost::math::ibeta_inv; + using boost::math::ibeta; + using boost::math::beta; + using boost::math::tgamma; + using boost::math::hypergeometric_pFq; + using boost::math::beta; + using boost::math::polygamma; + using std::log; + double w = ibeta_inv(a, b, p); + return pow(1 - w, (1 - b)) * pow(w, (1 - a)) + * (pow(w, a) * pow(tgamma(a), 2) + * (hypergeometric_pFq({a, a, 1 - b}, {1 + a, 1 + a}, w) + / (tgamma(1 + a) * tgamma(1 + a))) + - beta(a, b) * ibeta(a, b, w) + * (log(w) - polygamma(0, a) + polygamma(0, a + b))); + }; + + auto deriv_b = [](auto& p, auto& a, auto& b) { + using std::pow; + using boost::math::ibeta_inv; + using boost::math::ibeta; + using boost::math::beta; + using boost::math::tgamma; + using boost::math::hypergeometric_pFq; + using boost::math::beta; + using boost::math::polygamma; + using std::log; + return pow(1 - ibeta_inv(a, b, p), -b) * (-1 + ibeta_inv(a, b, p)) + * pow(ibeta_inv(a, b, p), (1 - a)) + * (pow(tgamma(b), 2) + * (hypergeometric_pFq({b, b, 1 - a}, {1 + b, 1 + b}, + 1 - ibeta_inv(a, b, p)) + / (tgamma(1 + b) * tgamma(1 + b))) + * pow(1 - ibeta_inv(a, b, p), b) + - beta(b, a, 1 - ibeta_inv(a, b, p)) + * (log(1 - ibeta_inv(a, b, p)) - polygamma(0, b) + + polygamma(0, a + b))); + }; + double known_p_grad = deriv_p(p, a, b); + double known_alpha_grad = deriv_a(p, a, b); + double known_beta_grad = deriv_b(p, a, b); + std::cout << "--- Mathematica Calculate Grad----\n"; + std::cout << "p: " << known_p_grad << "\n" + << "alpha: " << known_alpha_grad << "\n" + << "beta: " << known_beta_grad << "\n"; + std::cout << "--- Mathematica Calculate Grad Diff----\n"; + std::cout << "p: " << grad_fx(0) - known_p_grad << "\n" + << "alpha: " << grad_fx(1) - known_alpha_grad << "\n" + << "beta: " << grad_fx(2) - known_beta_grad << "\n"; +} + +template +inline constexpr auto make_index_tuple(const std::index_sequence&) { + return std::make_tuple(I...); +} +template +void check_vs_known_grads(FTuple&& grad_tuple, FGrad&& f, double tolerance, + Args&&... args) { + try { + const stan::math::nested_rev_autodiff nested; + auto var_tuple = std::make_tuple(stan::math::var(args)...); + // For pretty printer index of incorrect values + auto arg_num_tuple = stan::math::apply( + [&args...](auto&&... nums) { + return std::make_tuple(std::make_tuple(args, nums)...); + }, + make_index_tuple(std::make_index_sequence())); + + auto ret = stan::math::apply( + [&f](auto&&... var_args) { return f(var_args...); }, var_tuple); + ret.grad(); + auto grad_val_tuple = stan::math::apply( + [&args...](auto&&... grad_funcs) { + return std::make_tuple(grad_funcs(args...)...); + }, + grad_tuple); + auto adj_tuple = stan::math::apply( + [](auto&&... var_arg) { return std::make_tuple(var_arg.adj()...); }, + var_tuple); + stan::math::for_each( + [tolerance](auto&& grad, auto&& adj, auto&& num_helper) { + EXPECT_NEAR(grad, adj, tolerance) + << "Diff: (" << grad - adj << ")\nWith arg #" + << std::get<1>(num_helper) << " (" << std::get<0>(num_helper) + << ")"; + }, + grad_val_tuple, adj_tuple, arg_num_tuple); + } catch (const std::exception& e) { + stan::math::recover_memory(); + } + stan::math::recover_memory(); +} +TEST(RevFunctor, root_finder_beta_cdf2) { + constexpr double guess = .5; + constexpr double min = 0; + constexpr double max = 1; + + auto deriv_a = [](auto&& a, auto&& b, auto&& p) { + using std::pow; + using boost::math::ibeta_inv; + using boost::math::ibeta; + using boost::math::beta; + using boost::math::tgamma; + using boost::math::hypergeometric_pFq; + using boost::math::beta; + using boost::math::polygamma; + using std::log; + double w = ibeta_inv(a, b, p); + return pow(1 - w, (1 - b)) * pow(w, (1 - a)) + * (pow(w, a) * pow(tgamma(a), 2) + * (hypergeometric_pFq({a, a, 1 - b}, {1 + a, 1 + a}, w) + / (tgamma(1 + a) * tgamma(1 + a))) + - beta(a, b) * ibeta(a, b, w) + * (log(w) - polygamma(0, a) + polygamma(0, a + b))); + }; + auto deriv_b = [](auto&& a, auto&& b, auto&& p) { + using std::pow; + using boost::math::ibeta_inv; + using boost::math::ibeta; + using boost::math::beta; + using boost::math::tgamma; + using boost::math::hypergeometric_pFq; + using boost::math::beta; + using boost::math::polygamma; + using std::log; + return pow(1 - ibeta_inv(a, b, p), -b) * (-1 + ibeta_inv(a, b, p)) + * pow(ibeta_inv(a, b, p), (1 - a)) + * (pow(tgamma(b), 2) + * (hypergeometric_pFq({b, b, 1 - a}, {1 + b, 1 + b}, + 1 - ibeta_inv(a, b, p)) + / (tgamma(1 + b) * tgamma(1 + b))) + * pow(1 - ibeta_inv(a, b, p), b) + - beta(b, a, 1 - ibeta_inv(a, b, p)) + * (log(1 - ibeta_inv(a, b, p)) - polygamma(0, b) + + polygamma(0, a + b))); + }; + auto deriv_p = [](auto&& a, auto&& b, auto&& p) { + using std::pow; + using boost::math::ibeta_inv; + using boost::math::beta; + return beta(a, b) * pow(1 - ibeta_inv(a, b, p), (1 - b)) + * pow(ibeta_inv(a, b, p), (1 - a)); + }; + auto f_newton = [guess, min, max](auto&& alpha, auto&& beta, auto&& p) { + return stan::math::root_finder_newton_raphson(guess, min, max, + alpha, beta, p); + }; + auto f_schroder = [guess, min, max](auto&& alpha, auto&& beta, auto&& p) { + constexpr int digits = 16; + std::uintmax_t max_iter = std::numeric_limits::max(); + return stan::math::root_finder_schroder(guess, min, max, alpha, + beta, p); + }; + auto f_hailey = [](auto&& alpha, auto&& beta, auto&& p) { + return stan::math::root_finder_hailey(p, min, max, alpha, beta, + p); + }; + check_vs_known_grads(std::make_tuple(deriv_a, deriv_b, deriv_p), f_schroder, + 5e-2, .5, .5, .3); + + check_vs_known_grads(std::make_tuple(deriv_a, deriv_b, deriv_p), f_schroder, + 5e-2, .5, .6, .5); + + // For some reason this fails after a while?? + for (double p = .1; p < .9; p += .1) { + for (double a = .1; a < .9; a += .1) { + for (double b = .1; b < .9; b += .1) { + check_vs_known_grads(std::make_tuple(deriv_a, deriv_b, deriv_p), + f_schroder, 1e-9, a, b, p); + if (::testing::Test::HasFailure()) { + std::cout << "--\na: " << a << "\nb: " << b << "\np: " << p << "\n"; + } + } + } + } +}