Skip to content

Commit

Permalink
refactored analyze/mcmc fns and unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mitzimorris committed Oct 19, 2024
1 parent 8d0bd5b commit 99e8f71
Show file tree
Hide file tree
Showing 6 changed files with 266 additions and 137 deletions.
22 changes: 17 additions & 5 deletions src/stan/analyze/mcmc/mcse.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define STAN_ANALYZE_MCMC_MCSE_HPP

#include <stan/analyze/mcmc/check_chains.hpp>
#include <stan/analyze/mcmc/split_chains.hpp>
#include <stan/analyze/mcmc/ess.hpp>
#include <stan/math/prim.hpp>
#include <cmath>
Expand Down Expand Up @@ -42,11 +43,22 @@ inline double mcse_sd(const Eigen::MatrixXd& chains) {
if (chains.rows() < 4 || !is_finite_and_varies(chains))
return std::numeric_limits<double>::quiet_NaN();

Eigen::MatrixXd diffs = (chains.array() - chains.mean()).matrix();
double Evar = diffs.array().square().mean();
double varvar = (math::mean(diffs.array().pow(4) - Evar * Evar))
/ ess(diffs.array().abs().matrix());
return std::sqrt(varvar / Evar / 4);
// center the data, take abs value
Eigen::MatrixXd draws_ctr = (chains.array() - chains.mean()).abs().matrix();

// posterior pkg fn `ess_mean` computes on split chains
double ess_mean = ess(split_chains(draws_ctr));

// estimated variance (2nd moment)
double Evar = draws_ctr.array().square().mean();

// variance of variance, adjusted for ESS
double fourth_moment = draws_ctr.array().pow(4).mean();
double varvar = (fourth_moment - std::pow(Evar, 2)) / ess_mean;

// variance of standard deviation - use Taylor series approximation
double varsd = varvar / Evar / 4.0;
return std::sqrt(varsd);
}

} // namespace analyze
Expand Down
55 changes: 34 additions & 21 deletions src/test/unit/analyze/mcmc/ess_basic_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,42 @@
#include <string>
#include <cmath>

TEST(RankNormalizedEss, test_basic_ess) {
std::stringstream out;
Eigen::MatrixXd chains_lp(1000, 4);
Eigen::MatrixXd chains_theta(1000, 4);

std::vector<const double*> draws_theta(4);
std::vector<const double*> draws_lp(4);
std::vector<size_t> sizes(4);

for (size_t i = 0; i < 4; ++i) {
std::stringstream fname;
fname << "src/test/unit/analyze/mcmc/test_csv_files/bern" << (i + 1)
<< ".csv";
std::ifstream bern_stream(fname.str(), std::ifstream::in);
stan::io::stan_csv bern_csv
class EssBasic : public testing::Test {
public:
void SetUp() {
chains_lp.resize(1000, 4);
chains_theta.resize(1000, 4);
draws_theta.resize(4);
draws_lp.resize(4);
sizes.resize(4);
for (size_t i = 0; i < 4; ++i) {
std::stringstream fname;
fname << "src/test/unit/analyze/mcmc/test_csv_files/bern" << (i + 1)
<< ".csv";
std::ifstream bern_stream(fname.str(), std::ifstream::in);
stan::io::stan_csv bern_csv
= stan::io::stan_csv_reader::parse(bern_stream, &out);
bern_stream.close();
chains_lp.col(i) = bern_csv.samples.col(0);
chains_theta.col(i) = bern_csv.samples.col(7);
draws_lp[i] = chains_lp.col(i).data();
draws_theta[i] = chains_theta.col(i).data();
sizes[i] = 1000;
bern_stream.close();
chains_lp.col(i) = bern_csv.samples.col(0);
chains_theta.col(i) = bern_csv.samples.col(7);
draws_lp[i] = chains_lp.col(i).data();
draws_theta[i] = chains_theta.col(i).data();
sizes[i] = 1000;
}
}

void TearDown() {
}

std::stringstream out;
Eigen::MatrixXd chains_lp;
Eigen::MatrixXd chains_theta;
std::vector<const double*> draws_theta;
std::vector<const double*> draws_lp;
std::vector<size_t> sizes;
};

TEST_F(EssBasic, test_basic_ess) {
double ess_lp_expect = 1335.4137;
double ess_theta_expect = 1377.503;

Expand Down
72 changes: 72 additions & 0 deletions src/test/unit/analyze/mcmc/mcse_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#include <stan/analyze/mcmc/mcse.hpp>
#include <stan/io/stan_csv_reader.hpp>
#include <gtest/gtest.h>
#include <fstream>
#include <sstream>
#include <string>
#include <cmath>

class MonteCarloStandardError : public testing::Test {
public:
void SetUp() {
chains_lp.resize(1000, 4);
chains_theta.resize(1000, 4);
chains_divergent.resize(1000, 4);
for (size_t i = 0; i < 4; ++i) {
std::stringstream fname;
fname << "src/test/unit/analyze/mcmc/test_csv_files/bern" << (i + 1)
<< ".csv";
std::ifstream bern_stream(fname.str(), std::ifstream::in);
stan::io::stan_csv bern_csv
= stan::io::stan_csv_reader::parse(bern_stream, &out);
bern_stream.close();
chains_lp.col(i) = bern_csv.samples.col(0);
chains_theta.col(i) = bern_csv.samples.col(7);
chains_divergent.col(i) = bern_csv.samples.col(5);
}
}

void TearDown() {
}

std::stringstream out;
Eigen::MatrixXd chains_lp;
Eigen::MatrixXd chains_theta;
Eigen::MatrixXd chains_divergent;
};

// Compares mcse_mean and mcse_sd for the lp and theta draws against fixed
// reference values, using one absolute tolerance for all four checks.
TEST_F(MonteCarloStandardError, test_mcse) {
  const double tolerance = 0.0001;

  // reference values: standard error of the mean
  const double expected_mean_lp = 0.020164778;
  const double expected_mean_theta = 0.0032339916;

  // reference values: standard error of the standard deviation
  const double expected_sd_lp = 0.0355305;
  const double expected_sd_theta = 0.0021642137;

  EXPECT_NEAR(expected_mean_lp, stan::analyze::mcse_mean(chains_lp),
              tolerance);
  EXPECT_NEAR(expected_mean_theta, stan::analyze::mcse_mean(chains_theta),
              tolerance);

  EXPECT_NEAR(expected_sd_lp, stan::analyze::mcse_sd(chains_lp), tolerance);
  EXPECT_NEAR(expected_sd_theta, stan::analyze::mcse_sd(chains_theta),
              tolerance);
}

// Both estimators must return NaN for the `chains_divergent` draws.
// NOTE(review): the test name implies those draws are constant - confirm
// against the CSV fixtures.
TEST_F(MonteCarloStandardError, const_fail) {
  const auto mean_error = stan::analyze::mcse_mean(chains_divergent);
  const auto sd_error = stan::analyze::mcse_sd(chains_divergent);

  EXPECT_TRUE(std::isnan(mean_error));
  EXPECT_TRUE(std::isnan(sd_error));
}

// A single infinite draw must make both estimators return NaN.
TEST_F(MonteCarloStandardError, inf_fail) {
  // overwrite the first draw of the first chain
  chains_theta(0, 0) = std::numeric_limits<double>::infinity();

  const auto mean_error = stan::analyze::mcse_mean(chains_theta);
  const auto sd_error = stan::analyze::mcse_sd(chains_theta);

  EXPECT_TRUE(std::isnan(mean_error));
  EXPECT_TRUE(std::isnan(sd_error));
}

// Chains with fewer than 4 draws must make both estimators return NaN.
TEST_F(MonteCarloStandardError, short_chains_fail) {
  // Keep the first 3 draws of every chain.  A plain resize() to a different
  // element count discards the contents and leaves the coefficients
  // uninitialized, so the test would run on arbitrary values; with
  // conservativeResize() the data stays finite and varying, and a NaN
  // result can only come from the chains being too short.
  chains_theta.conservativeResize(3, 4);

  auto mcse_mean = stan::analyze::mcse_mean(chains_theta);
  auto mcse_sd = stan::analyze::mcse_sd(chains_theta);
  EXPECT_TRUE(std::isnan(mcse_mean));
  EXPECT_TRUE(std::isnan(mcse_sd));
}
61 changes: 61 additions & 0 deletions src/test/unit/analyze/mcmc/rhat_basic_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#include <stan/analyze/mcmc/compute_potential_scale_reduction.hpp>
#include <stan/analyze/mcmc/rhat.hpp>
#include <stan/io/stan_csv_reader.hpp>
#include <gtest/gtest.h>
#include <fstream>
#include <sstream>

class RhatBasic : public testing::Test {
public:
void SetUp() {
chains_lp.resize(1000, 4);
chains_theta.resize(1000, 4);
draws_theta.resize(4);
draws_lp.resize(4);
sizes.resize(4);
for (size_t i = 0; i < 4; ++i) {
std::stringstream fname;
fname << "src/test/unit/analyze/mcmc/test_csv_files/bern" << (i + 1)
<< ".csv";
std::ifstream bern_stream(fname.str(), std::ifstream::in);
stan::io::stan_csv bern_csv
= stan::io::stan_csv_reader::parse(bern_stream, &out);
bern_stream.close();
chains_lp.col(i) = bern_csv.samples.col(0);
chains_theta.col(i) = bern_csv.samples.col(7);
draws_lp[i] = chains_lp.col(i).data();
draws_theta[i] = chains_theta.col(i).data();
sizes[i] = 1000;
}
}

void TearDown() {
}

std::stringstream out;
Eigen::MatrixXd chains_lp;
Eigen::MatrixXd chains_theta;
std::vector<const double*> draws_theta;
std::vector<const double*> draws_lp;
std::vector<size_t> sizes;
};

// Checks `rhat` on the lp and theta draws against fixed reference values
// and against `compute_potential_scale_reduction` on the same draws.
TEST_F(RhatBasic, test_basic_rhat) {
  const double tolerance = 0.00001;

  // reference values
  const double expected_lp = 1.0001296;
  const double expected_theta = 1.0029197;

  // matrix-based and pointer-based results for lp
  const auto new_lp = stan::analyze::rhat(chains_lp);
  const auto legacy_lp
      = stan::analyze::compute_potential_scale_reduction(draws_lp, sizes);

  // matrix-based and pointer-based results for theta
  const auto new_theta = stan::analyze::rhat(chains_theta);
  const auto legacy_theta
      = stan::analyze::compute_potential_scale_reduction(draws_theta, sizes);

  // agreement with the reference values
  EXPECT_NEAR(expected_lp, new_lp, tolerance);
  EXPECT_NEAR(expected_theta, new_theta, tolerance);

  // agreement between the two implementations
  EXPECT_NEAR(legacy_lp, new_lp, tolerance);
  EXPECT_NEAR(legacy_theta, new_theta, tolerance);
}

90 changes: 42 additions & 48 deletions src/test/unit/analyze/mcmc/split_rank_normalized_ess_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,36 @@
#include <string>
#include <cmath>

TEST(RankNormalizedEss, test_bulk_tail_ess) {
std::stringstream out;
Eigen::MatrixXd chains_lp(1000, 4);
Eigen::MatrixXd chains_theta(1000, 4);

std::vector<const double*> draws_theta(4);
std::vector<const double*> draws_lp(4);
std::vector<size_t> sizes(4);

for (size_t i = 0; i < 4; ++i) {
std::stringstream fname;
fname << "src/test/unit/analyze/mcmc/test_csv_files/bern" << (i + 1)
<< ".csv";
std::ifstream bern_stream(fname.str(), std::ifstream::in);
stan::io::stan_csv bern_csv
class RankNormalizedEss : public testing::Test {
public:
void SetUp() {
chains_lp.resize(1000, 4);
chains_theta.resize(1000, 4);
chains_divergent.resize(1000, 4);
for (size_t i = 0; i < 4; ++i) {
std::stringstream fname;
fname << "src/test/unit/analyze/mcmc/test_csv_files/bern" << (i + 1)
<< ".csv";
std::ifstream bern_stream(fname.str(), std::ifstream::in);
stan::io::stan_csv bern_csv
= stan::io::stan_csv_reader::parse(bern_stream, &out);
bern_stream.close();
chains_lp.col(i) = bern_csv.samples.col(0);
chains_theta.col(i) = bern_csv.samples.col(7);
draws_lp[i] = chains_lp.col(i).data();
draws_theta[i] = chains_theta.col(i).data();
sizes[i] = 1000;
bern_stream.close();
chains_lp.col(i) = bern_csv.samples.col(0);
chains_theta.col(i) = bern_csv.samples.col(7);
chains_divergent.col(i) = bern_csv.samples.col(5);
}
}

void TearDown() {
}

std::stringstream out;
Eigen::MatrixXd chains_lp;
Eigen::MatrixXd chains_theta;
Eigen::MatrixXd chains_divergent;
};

TEST_F(RankNormalizedEss, test_bulk_tail_ess) {
double ess_lp_bulk_expect = 1512.7684;
double ess_lp_tail_expect = 1591.9707;

Expand All @@ -45,7 +52,20 @@ TEST(RankNormalizedEss, test_bulk_tail_ess) {
EXPECT_NEAR(ess_theta_tail_expect, ess_theta.second, 0.001);
}

TEST(RankNormalizedEss, short_chains_fail) {
// Both members of the returned pair must be NaN for the `chains_divergent`
// draws.
// NOTE(review): the test name implies those draws are constant - confirm
// against the CSV fixtures.
TEST_F(RankNormalizedEss, const_fail) {
  const auto ess_pair
      = stan::analyze::split_rank_normalized_ess(chains_divergent);

  EXPECT_TRUE(std::isnan(ess_pair.first));
  EXPECT_TRUE(std::isnan(ess_pair.second));
}

// A single infinite draw must make both members of the returned pair NaN.
TEST_F(RankNormalizedEss, inf_fail) {
  // overwrite the first draw of the first chain
  chains_theta(0, 0) = std::numeric_limits<double>::infinity();

  const auto ess_pair = stan::analyze::split_rank_normalized_ess(chains_theta);

  EXPECT_TRUE(std::isnan(ess_pair.first));
  EXPECT_TRUE(std::isnan(ess_pair.second));
}

TEST_F(RankNormalizedEss, short_chains_fail) {
std::stringstream out;
std::ifstream eight_schools_5iters_1_stream, eight_schools_5iters_2_stream;
stan::io::stan_csv eight_schools_5iters_1, eight_schools_5iters_2;
Expand All @@ -72,29 +92,3 @@ TEST(RankNormalizedEss, short_chains_fail) {
}
}

// Builds a two-chain matrix from the last sample column of the two
// `bernoulli_const_*.csv` files and expects both members of the returned
// pair to be NaN.
TEST(RankNormalizedEss, const_fail) {
  std::stringstream out;

  // Open, parse and close one Stan CSV file; parser messages go to `out`.
  const auto load_csv = [&out](const char* path) -> stan::io::stan_csv {
    std::ifstream csv_stream;
    csv_stream.open(path, std::ifstream::in);
    stan::io::stan_csv parsed
        = stan::io::stan_csv_reader::parse(csv_stream, &out);
    csv_stream.close();
    return parsed;
  };

  const stan::io::stan_csv first_csv
      = load_csv("src/test/unit/mcmc/test_csv_files/bernoulli_const_1.csv");
  const stan::io::stan_csv second_csv
      = load_csv("src/test/unit/mcmc/test_csv_files/bernoulli_const_2.csv");

  // one matrix column per chain, taken from each file's last sample column
  Eigen::MatrixXd chains(first_csv.samples.rows(), 2);
  chains.col(0) = first_csv.samples.col(first_csv.samples.cols() - 1);
  chains.col(1) = second_csv.samples.col(second_csv.samples.cols() - 1);

  const auto ess_pair = stan::analyze::split_rank_normalized_ess(chains);
  EXPECT_TRUE(std::isnan(ess_pair.first));
  EXPECT_TRUE(std::isnan(ess_pair.second));
}
Loading

0 comments on commit 99e8f71

Please sign in to comment.