-
-
Notifications
You must be signed in to change notification settings - Fork 371
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactored analyze/mcmc fns and unit tests
- Loading branch information
1 parent
8d0bd5b
commit 99e8f71
Showing
6 changed files
with
266 additions
and
137 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
#include <stan/analyze/mcmc/mcse.hpp> | ||
#include <stan/io/stan_csv_reader.hpp> | ||
#include <gtest/gtest.h> | ||
#include <fstream> | ||
#include <sstream> | ||
#include <string> | ||
#include <cmath> | ||
|
||
class MonteCarloStandardError : public testing::Test { | ||
public: | ||
void SetUp() { | ||
chains_lp.resize(1000, 4); | ||
chains_theta.resize(1000, 4); | ||
chains_divergent.resize(1000, 4); | ||
for (size_t i = 0; i < 4; ++i) { | ||
std::stringstream fname; | ||
fname << "src/test/unit/analyze/mcmc/test_csv_files/bern" << (i + 1) | ||
<< ".csv"; | ||
std::ifstream bern_stream(fname.str(), std::ifstream::in); | ||
stan::io::stan_csv bern_csv | ||
= stan::io::stan_csv_reader::parse(bern_stream, &out); | ||
bern_stream.close(); | ||
chains_lp.col(i) = bern_csv.samples.col(0); | ||
chains_theta.col(i) = bern_csv.samples.col(7); | ||
chains_divergent.col(i) = bern_csv.samples.col(5); | ||
} | ||
} | ||
|
||
void TearDown() { | ||
} | ||
|
||
std::stringstream out; | ||
Eigen::MatrixXd chains_lp; | ||
Eigen::MatrixXd chains_theta; | ||
Eigen::MatrixXd chains_divergent; | ||
}; | ||
|
||
TEST_F(MonteCarloStandardError, test_mcse) { | ||
double mcse_mean_lp_expect = 0.020164778; | ||
double mcse_mean_theta_expect = 0.0032339916; | ||
|
||
double mcse_sd_lp_expect = 0.0355305; | ||
double mcse_sd_theta_expect = 0.0021642137; | ||
EXPECT_NEAR(mcse_mean_lp_expect, stan::analyze::mcse_mean(chains_lp), 0.0001); | ||
EXPECT_NEAR(mcse_mean_theta_expect, stan::analyze::mcse_mean(chains_theta), 0.0001); | ||
|
||
EXPECT_NEAR(mcse_sd_lp_expect, stan::analyze::mcse_sd(chains_lp), 0.0001); | ||
EXPECT_NEAR(mcse_sd_theta_expect, stan::analyze::mcse_sd(chains_theta), 0.0001); | ||
} | ||
|
||
TEST_F(MonteCarloStandardError, const_fail) { | ||
auto mcse_mean = stan::analyze::mcse_mean(chains_divergent); | ||
auto mcse_sd = stan::analyze::mcse_sd(chains_divergent); | ||
EXPECT_TRUE(std::isnan(mcse_mean)); | ||
EXPECT_TRUE(std::isnan(mcse_sd)); | ||
} | ||
|
||
TEST_F(MonteCarloStandardError, inf_fail) { | ||
chains_theta(0,0) = std::numeric_limits<double>::infinity(); | ||
auto mcse_mean = stan::analyze::mcse_mean(chains_theta); | ||
auto mcse_sd = stan::analyze::mcse_sd(chains_theta); | ||
EXPECT_TRUE(std::isnan(mcse_mean)); | ||
EXPECT_TRUE(std::isnan(mcse_sd)); | ||
} | ||
|
||
TEST_F(MonteCarloStandardError, short_chains_fail) { | ||
chains_theta.resize(3, 4); | ||
auto mcse_mean = stan::analyze::mcse_mean(chains_theta); | ||
auto mcse_sd = stan::analyze::mcse_sd(chains_theta); | ||
EXPECT_TRUE(std::isnan(mcse_mean)); | ||
EXPECT_TRUE(std::isnan(mcse_sd)); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
#include <stan/analyze/mcmc/compute_potential_scale_reduction.hpp> | ||
#include <stan/analyze/mcmc/rhat.hpp> | ||
#include <stan/io/stan_csv_reader.hpp> | ||
#include <gtest/gtest.h> | ||
#include <fstream> | ||
#include <sstream> | ||
|
||
class RhatBasic : public testing::Test { | ||
public: | ||
void SetUp() { | ||
chains_lp.resize(1000, 4); | ||
chains_theta.resize(1000, 4); | ||
draws_theta.resize(4); | ||
draws_lp.resize(4); | ||
sizes.resize(4); | ||
for (size_t i = 0; i < 4; ++i) { | ||
std::stringstream fname; | ||
fname << "src/test/unit/analyze/mcmc/test_csv_files/bern" << (i + 1) | ||
<< ".csv"; | ||
std::ifstream bern_stream(fname.str(), std::ifstream::in); | ||
stan::io::stan_csv bern_csv | ||
= stan::io::stan_csv_reader::parse(bern_stream, &out); | ||
bern_stream.close(); | ||
chains_lp.col(i) = bern_csv.samples.col(0); | ||
chains_theta.col(i) = bern_csv.samples.col(7); | ||
draws_lp[i] = chains_lp.col(i).data(); | ||
draws_theta[i] = chains_theta.col(i).data(); | ||
sizes[i] = 1000; | ||
} | ||
} | ||
|
||
void TearDown() { | ||
} | ||
|
||
std::stringstream out; | ||
Eigen::MatrixXd chains_lp; | ||
Eigen::MatrixXd chains_theta; | ||
std::vector<const double*> draws_theta; | ||
std::vector<const double*> draws_lp; | ||
std::vector<size_t> sizes; | ||
}; | ||
|
||
TEST_F(RhatBasic, test_basic_rhat) { | ||
double rhat_lp_basic_expect = 1.0001296; | ||
double rhat_theta_basic_expect = 1.0029197; | ||
|
||
auto rhat_basic_lp = stan::analyze::rhat(chains_lp); | ||
auto old_rhat_basic_lp | ||
= stan::analyze::compute_potential_scale_reduction(draws_lp, sizes); | ||
|
||
auto rhat_basic_theta = stan::analyze::rhat(chains_theta); | ||
auto old_rhat_basic_theta | ||
= stan::analyze::compute_potential_scale_reduction(draws_theta, sizes); | ||
|
||
EXPECT_NEAR(rhat_lp_basic_expect, rhat_basic_lp, 0.00001); | ||
EXPECT_NEAR(rhat_theta_basic_expect, rhat_basic_theta, 0.00001); | ||
|
||
EXPECT_NEAR(old_rhat_basic_lp, rhat_basic_lp, 0.00001); | ||
EXPECT_NEAR(old_rhat_basic_theta, rhat_basic_theta, 0.00001); | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.