Skip to content

Commit

Permalink
Converted some other samplers over to use accessor functions on dense…
Browse files Browse the repository at this point in the history
… and diagonal metrics (Issue #2881)
  • Loading branch information
bbbales2 committed Oct 28, 2020
1 parent d4dd953 commit 1da6121
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 15 deletions.
7 changes: 5 additions & 2 deletions src/stan/mcmc/hmc/nuts_classic/adapt_dense_e_nuts_classic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,15 @@ class adapt_dense_e_nuts_classic : public dense_e_nuts_classic<Model, BaseRNG>,
this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_,
s.accept_stat());

bool update = this->covar_adaptation_.learn_covariance(
this->z_.inv_e_metric_, this->z_.q);
Eigen::MatrixXd inv_metric;

bool update = this->covar_adaptation_.learn_covariance(inv_metric, this->z_.q);

if (update) {
this->init_stepsize(logger);

this->z_.set_inv_metric(inv_metric);

this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_));
this->stepsize_adaptation_.restart();
}
Expand Down
6 changes: 5 additions & 1 deletion src/stan/mcmc/hmc/nuts_classic/adapt_diag_e_nuts_classic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,16 @@ class adapt_diag_e_nuts_classic : public diag_e_nuts_classic<Model, BaseRNG>,
this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_,
s.accept_stat());

bool update = this->var_adaptation_.learn_variance(this->z_.inv_e_metric_,
Eigen::VectorXd inv_metric;

bool update = this->var_adaptation_.learn_variance(inv_metric,
this->z_.q);

if (update) {
this->init_stepsize(logger);

this->z_.set_inv_metric(inv_metric);

this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_));
this->stepsize_adaptation_.restart();
}
Expand Down
4 changes: 2 additions & 2 deletions src/stan/mcmc/hmc/nuts_classic/dense_e_nuts_classic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class dense_e_nuts_classic
// here since start.inv_e_metric_ = finish.inv_e_metric_
bool compute_criterion(ps_point& start, dense_e_point& finish,
Eigen::VectorXd& rho) {
return finish.p.transpose() * finish.inv_e_metric_ * (rho - finish.p) > 0
&& start.p.transpose() * finish.inv_e_metric_ * (rho - start.p) > 0;
return finish.p.transpose() * finish.get_inv_metric() * (rho - finish.p) > 0
&& start.p.transpose() * finish.get_inv_metric() * (rho - start.p) > 0;
}
};

Expand Down
4 changes: 2 additions & 2 deletions src/stan/mcmc/hmc/nuts_classic/diag_e_nuts_classic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class diag_e_nuts_classic
// since start.inv_e_metric_ = finish.inv_e_metric_
bool compute_criterion(ps_point& start, diag_e_point& finish,
Eigen::VectorXd& rho) {
return finish.inv_e_metric_.cwiseProduct(finish.p).dot(rho - finish.p) > 0
&& finish.inv_e_metric_.cwiseProduct(start.p).dot(rho - start.p) > 0;
return finish.get_inv_metric().cwiseProduct(finish.p).dot(rho - finish.p) > 0
&& finish.get_inv_metric().cwiseProduct(start.p).dot(rho - start.p) > 0;
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,15 @@ class adapt_dense_e_static_uniform
this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_,
s.accept_stat());

bool update = this->covar_adaptation_.learn_covariance(
this->z_.inv_e_metric_, this->z_.q);
Eigen::MatrixXd inv_metric;

bool update = this->covar_adaptation_.learn_covariance(inv_metric, this->z_.q);

if (update) {
this->init_stepsize(logger);

this->z_.set_inv_metric(inv_metric);

this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_));
this->stepsize_adaptation_.restart();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,14 @@ class adapt_diag_e_static_uniform
this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_,
s.accept_stat());

bool update = this->var_adaptation_.learn_variance(this->z_.inv_e_metric_,
this->z_.q);
Eigen::VectorXd inv_metric;

bool update = this->var_adaptation_.learn_variance(inv_metric, this->z_.q);
if (update) {
this->init_stepsize(logger);

this->z_.set_inv_metric(inv_metric);

this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_));
this->stepsize_adaptation_.restart();
}
Expand Down
7 changes: 5 additions & 2 deletions src/stan/mcmc/hmc/xhmc/adapt_dense_e_xhmc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,15 @@ class adapt_dense_e_xhmc : public dense_e_xhmc<Model, BaseRNG>,
this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_,
s.accept_stat());

bool update = this->covar_adaptation_.learn_covariance(
this->z_.inv_e_metric_, this->z_.q);
Eigen::MatrixXd inv_metric;

bool update = this->covar_adaptation_.learn_covariance(inv_metric, this->z_.q);

if (update) {
this->init_stepsize(logger);

this->z_.set_inv_metric(inv_metric);

this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_));
this->stepsize_adaptation_.restart();
}
Expand Down
7 changes: 5 additions & 2 deletions src/stan/mcmc/hmc/xhmc/adapt_diag_e_xhmc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,15 @@ class adapt_diag_e_xhmc : public diag_e_xhmc<Model, BaseRNG>,
this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_,
s.accept_stat());

bool update = this->var_adaptation_.learn_variance(this->z_.inv_e_metric_,
this->z_.q);
Eigen::VectorXd inv_metric;

bool update = this->var_adaptation_.learn_variance(inv_metric, this->z_.q);

if (update) {
this->init_stepsize(logger);

this->z_.set_inv_metric(inv_metric);

this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_));
this->stepsize_adaptation_.restart();
}
Expand Down

0 comments on commit 1da6121

Please sign in to comment.