-
-
Notifications
You must be signed in to change notification settings - Fork 189
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Adds boost root finders with reverse mode specializations #2720
base: develop
Are you sure you want to change the base?
Changes from all commits
c33ef74
b015dae
bd9fc69
9b2b477
d729001
f2f8c51
04dbda8
c0d4a81
b08f3ad
33e0b50
f468582
78f84c5
5bff38b
f5c762b
26ca3fc
4c6956f
91a748c
ee38928
344a450
ef8d4e0
4932068
19588e4
11c4a03
13a8886
8cd506b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#ifndef STAN_MATH_FWD_FUN_FREXP_HPP | ||
#define STAN_MATH_FWD_FUN_FREXP_HPP | ||
|
||
#include <stan/math/fwd/meta.hpp> | ||
#include <stan/math/fwd/core.hpp> | ||
#include <stan/math/prim/fun/value_of_rec.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
template <typename T> | ||
inline auto frexp(const fvar<T>& x, int* exponent) noexcept { | ||
return std::frexp(value_of_rec(x), exponent); | ||
} | ||
} // namespace math | ||
} // namespace stan | ||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#ifndef STAN_MATH_FWD_FUN_SIGN_HPP | ||
#define STAN_MATH_FWD_FUN_SIGN_HPP | ||
|
||
#include <stan/math/fwd/meta.hpp> | ||
#include <stan/math/fwd/core.hpp> | ||
#include <stan/math/prim/fun/value_of_rec.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
template <typename T> | ||
inline auto sign(const fvar<T>& x) { | ||
double z = value_of_rec(x); | ||
return (z == 0) ? 0 : z < 0 ? -1 : 1; | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,195 @@ | ||
#ifndef STAN_MATH_PRIM_FUNCTOR_ROOT_FINDER_HPP | ||
#define STAN_MATH_PRIM_FUNCTOR_ROOT_FINDER_HPP | ||
|
||
#include <stan/math/prim/fun/Eigen.hpp> | ||
#include <stan/math/prim/meta.hpp> | ||
#include <stan/math/prim/err/check_bounded.hpp> | ||
#include <stan/math/prim/err/check_positive.hpp> | ||
#include <stan/math/prim/functor/apply.hpp> | ||
#include <boost/math/tools/roots.hpp> | ||
#include <tuple> | ||
#include <utility> | ||
|
||
namespace stan { | ||
namespace math { | ||
namespace internal { | ||
template <bool ReturnDerivs, typename FRootFunc, typename... Args, | ||
std::enable_if_t<ReturnDerivs>* = nullptr> | ||
inline auto make_root_func(Args&&... args) { | ||
return [&args...](auto&& x) { | ||
return std::decay_t<FRootFunc>::template run<ReturnDerivs>(x, args...); | ||
}; | ||
} | ||
|
||
template <bool ReturnDerivs, typename FRootFunc, | ||
std::enable_if_t<!ReturnDerivs>* = nullptr> | ||
inline auto make_root_func() { | ||
return [](auto&&... args) { | ||
return std::decay_t<FRootFunc>::template run<ReturnDerivs>(args...); | ||
}; | ||
} | ||
|
||
struct NewtonRootSolver { | ||
template <typename... Types> | ||
static inline auto run(Types&&... args) { | ||
return boost::math::tools::newton_raphson_iterate( | ||
std::forward<Types>(args)...); | ||
} | ||
}; | ||
|
||
struct HalleyRootSolver { | ||
template <typename... Types> | ||
static inline auto run(Types&&... args) { | ||
return boost::math::tools::halley_iterate(std::forward<Types>(args)...); | ||
} | ||
}; | ||
|
||
struct SchroderRootSolver { | ||
template <typename... Types> | ||
static inline auto run(Types&&... args) { | ||
return boost::math::tools::schroder_iterate(std::forward<Types>(args)...); | ||
} | ||
}; | ||
|
||
} // namespace internal | ||
|
||
/** | ||
* Solve for root using Boost's Halley method | ||
* @tparam FRootFunc A struct or class with a static function called `run`. | ||
* The structs `run` function must have a boolean template parameter that | ||
* when `true` returns a tuple containing the function result and the | ||
* derivatives needed for the root finder. When the boolean template parameter | ||
* is `false` the function should return a single value containing the function | ||
* result. | ||
* @tparam SolverFun One of the three struct types used to call the root solver. | ||
* (`NewtonRootSolver`, `HalleyRootSolver`, `SchroderRootSolver`). | ||
* @tparam GuessScalar Scalar type | ||
* @tparam MinScalar Scalar type | ||
* @tparam MaxScalar Scalar type | ||
* @tparam Types Arg types to pass to functors in `f_tuple` | ||
* @param guess An initial guess at the root value | ||
* @param min The minimum possible value for the result, this is used as an | ||
* initial lower bracket | ||
* @param max The maximum possible value for the result, this is used as an | ||
* initial upper bracket | ||
* @param digits The desired number of binary digits precision | ||
* @param max_iter An optional maximum number of iterations to perform. On exit, | ||
* this is updated to the actual number of iterations performed | ||
* @param args Parameter pack of arguments to pass the the functors in `f_tuple` | ||
*/ | ||
template <typename FRootFunc, typename SolverFun, typename GuessScalar, | ||
typename MinScalar, typename MaxScalar, typename... Types, | ||
require_all_not_st_var<GuessScalar, MinScalar, MaxScalar, | ||
Types...>* = nullptr> | ||
inline auto root_finder_tol(const GuessScalar guess, const MinScalar min, | ||
const MaxScalar max, const int digits, | ||
std::uintmax_t& max_iter, Types&&... args) { | ||
check_bounded("root_finder", "initial guess", guess, min, max); | ||
check_positive("root_finder", "digits", digits); | ||
check_positive("root_finder", "max_iter", max_iter); | ||
using ret_t = return_type_t<GuessScalar, MinScalar, MaxScalar, Types...>; | ||
ret_t ret = 0; | ||
auto f_plus_div | ||
= internal::make_root_func<true, FRootFunc>(std::forward<Types>(args)...); | ||
try { | ||
ret = std::decay_t<SolverFun>::run(f_plus_div, ret_t(guess), ret_t(min), | ||
ret_t(max), digits, max_iter); | ||
} catch (const std::exception& e) { | ||
throw e; | ||
} | ||
return ret; | ||
} | ||
|
||
template <typename FRootFunc, typename GuessScalar, typename MinScalar, | ||
typename MaxScalar, typename... Types> | ||
inline auto root_finder_halley_tol(const GuessScalar guess, const MinScalar min, | ||
const MaxScalar max, const int digits, | ||
std::uintmax_t& max_iter, Types&&... args) { | ||
return root_finder_tol<FRootFunc, internal::HalleyRootSolver>( | ||
guess, min, max, digits, max_iter, std::forward<Types>(args)...); | ||
} | ||
|
||
template <typename FRootFunc, typename GuessScalar, typename MinScalar, | ||
typename MaxScalar, typename... Types> | ||
inline auto root_finder_newton_raphson_tol( | ||
const GuessScalar guess, const MinScalar min, const MaxScalar max, | ||
const int digits, std::uintmax_t& max_iter, Types&&... args) { | ||
return root_finder_tol<FRootFunc, internal::NewtonRootSolver>( | ||
guess, min, max, digits, max_iter, std::forward<Types>(args)...); | ||
} | ||
|
||
template <typename FRootFunc, typename GuessScalar, typename MinScalar, | ||
typename MaxScalar, typename... Types> | ||
inline auto root_finder_schroder_tol(const GuessScalar guess, | ||
const MinScalar min, const MaxScalar max, | ||
const int digits, std::uintmax_t& max_iter, | ||
Types&&... args) { | ||
return root_finder_tol<FRootFunc, internal::SchroderRootSolver>( | ||
guess, min, max, digits, max_iter, std::forward<Types>(args)...); | ||
} | ||
|
||
/** | ||
* Solve for root with default values for the tolerances | ||
* @tparam FRootFunc A struct or class with a static function called `run`. | ||
* The structs `run` function must have a boolean template parameter that | ||
* when `true` returns a tuple containing the function result and the | ||
* derivatives needed for the root finder. When the boolean template parameter | ||
* is `false` the function should return a single value containing the function | ||
* result. | ||
* @tparam SolverFun One of the three struct types used to call the root solver. | ||
* (`NewtonRootSolver`, `HalleyRootSolver`, `SchroderRootSolver`). | ||
* @tparam GuessScalar Scalar type | ||
* @tparam MinScalar Scalar type | ||
* @tparam MaxScalar Scalar type | ||
* @tparam Types Arg types to pass to functors in `f_tuple` | ||
* @param guess An initial guess at the root value | ||
* @param min The minimum possible value for the result, this is used as an | ||
* initial lower bracket | ||
* @param max The maximum possible value for the result, this is used as an | ||
* initial upper bracket | ||
* @param args Parameter pack of arguments to pass the the functors in `f_tuple` | ||
*/ | ||
template <typename FRootFunc, typename SolverFun, typename GuessScalar, | ||
typename MinScalar, typename MaxScalar, typename... Types> | ||
inline auto root_finder(const GuessScalar guess, const MinScalar min, | ||
const MaxScalar max, Types&&... args) { | ||
constexpr int digits = 16; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how was this default chosen? Maybe add one sentence about this choice in the doxygen doc. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I need to update this to be in line with what boost's docs say
https://www.boost.org/doc/libs/1_62_0/libs/math/doc/html/math_toolkit/roots/roots_deriv.html |
||
std::uintmax_t max_iter = std::numeric_limits<std::uintmax_t>::max(); | ||
return root_finder_tol<FRootFunc, SolverFun>( | ||
guess, min, max, digits, max_iter, std::forward<Types>(args)...); | ||
} | ||
|
||
template <typename FRootFunc, typename GuessScalar, typename MinScalar, | ||
typename MaxScalar, typename... Types> | ||
inline auto root_finder_hailey(const GuessScalar guess, const MinScalar min, | ||
const MaxScalar max, Types&&... args) { | ||
constexpr int digits = 16; | ||
std::uintmax_t max_iter = std::numeric_limits<std::uintmax_t>::max(); | ||
return root_finder_halley_tol<FRootFunc>(guess, min, max, digits, max_iter, | ||
std::forward<Types>(args)...); | ||
} | ||
|
||
template <typename FRootFunc, typename GuessScalar, typename MinScalar, | ||
typename MaxScalar, typename... Types> | ||
inline auto root_finder_newton_raphson(const GuessScalar guess, | ||
const MinScalar min, const MaxScalar max, | ||
Types&&... args) { | ||
constexpr int digits = 16; | ||
std::uintmax_t max_iter = std::numeric_limits<std::uintmax_t>::max(); | ||
return root_finder_newton_raphson_tol<FRootFunc>( | ||
guess, min, max, digits, max_iter, std::forward<Types>(args)...); | ||
} | ||
|
||
template <typename FRootFunc, typename GuessScalar, typename MinScalar, | ||
typename MaxScalar, typename... Types> | ||
inline auto root_finder_schroder(const GuessScalar guess, const MinScalar min, | ||
const MaxScalar max, Types&&... args) { | ||
constexpr int digits = 16; | ||
std::uintmax_t max_iter = std::numeric_limits<std::uintmax_t>::max(); | ||
return root_finder_schroder_tol<FRootFunc>(guess, min, max, digits, max_iter, | ||
std::forward<Types>(args)...); | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
#ifndef STAN_MATH_REV_FUN_FREXP_HPP | ||
#define STAN_MATH_REV_FUN_FREXP_HPP | ||
|
||
#include <stan/math/rev/meta.hpp> | ||
#include <stan/math/rev/core.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
inline auto frexp(stan::math::var x, int* exponent) noexcept { | ||
return std::frexp(x.val(), exponent); | ||
} | ||
} // namespace math | ||
} // namespace stan | ||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
#ifndef STAN_MATH_REV_FUN_SIGN_HPP | ||
#define STAN_MATH_REV_FUN_SIGN_HPP | ||
|
||
#include <stan/math/rev/meta.hpp> | ||
#include <stan/math/rev/core.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
inline int sign(stan::math::var z) { return (z == 0) ? 0 : z < 0 ? -1 : 1; } | ||
} // namespace math | ||
} // namespace stan | ||
#endif |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indicate that digits cannot exceed the precision of
f_tuple
.