Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: multichain does not use any init file besides the first #1191

Merged
merged 2 commits into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/cmdstan/command.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ int command(int argc, const char *argv[]) {
}

std::vector<std::shared_ptr<stan::io::var_context>> init_contexts
= get_vec_var_context(init, num_chains);
= get_vec_var_context(init, num_chains, id);
std::vector<std::string> model_compile_info = model.model_compile_info();

for (int i = 0; i < num_chains; ++i) {
Expand Down Expand Up @@ -510,7 +510,7 @@ int command(int argc, const char *argv[]) {
dynamic_cast<string_argument *>(algo->arg("hmc")->arg("metric_file"))
->value());
context_vector metric_contexts
= get_vec_var_context(metric_filename, num_chains);
= get_vec_var_context(metric_filename, num_chains, id);
categorical_argument *adapt
= dynamic_cast<categorical_argument *>(sample_arg->arg("adapt"));
categorical_argument *hmc
Expand Down
59 changes: 30 additions & 29 deletions src/cmdstan/command_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,29 @@ inline shared_context_ptr get_var_context(const std::string &file) {
return std::make_shared<stan::io::dump>(var_context);
}

/**
 * Build one file name per chain from a single user-supplied file name.
 *
 * The name is passed to `get_basename_suffix`, whose result is used as a
 * (base, suffix) pair. When the suffix part is empty, `type` is used as the
 * suffix instead. Each resulting name has the form
 * `base + tag + chain_tag + suffix`, where `chain_tag` is empty when
 * `num_chains == 1` and `"_" + (i + id)` for the i-th chain (0-based)
 * otherwise.
 *
 * @param filename The file name to derive the per-chain names from
 * @param tag Text inserted between the base name and the chain tag
 * @param type Suffix to use when `filename` has none
 * @param num_chains The number of names to generate
 * @param id Offset added to the 0-based chain index in the chain tag
 * @return a std vector holding `num_chains` file names
 */
std::vector<std::string> make_filenames(const std::string &filename,
                                        const std::string &tag,
                                        const std::string &type,
                                        unsigned int num_chains,
                                        unsigned int id) {
  auto base_sfx = get_basename_suffix(filename);
  if (base_sfx.second.empty()) {
    base_sfx.second = type;
  }
  std::vector<std::string> names;
  names.reserve(num_chains);
  // The index is unsigned to match num_chains and id: no signed/unsigned
  // comparison in the loop condition and no implicit conversion in i + id.
  for (unsigned int i = 0; i < num_chains; ++i) {
    // A single chain keeps the plain name; multiple chains get "_<i + id>".
    const std::string chain_tag
        = (num_chains == 1) ? std::string() : "_" + std::to_string(i + id);
    names.push_back(base_sfx.first + tag + chain_tag + base_sfx.second);
  }
  return names;
}

using context_vector = std::vector<shared_context_ptr>;
/**
* Make a vector of shared pointers to contexts.
Expand All @@ -201,7 +224,8 @@ using context_vector = std::vector<shared_context_ptr>;
* @param num_chains The number of chains to run
* @return a std vector of shared pointers to var contexts
*/
context_vector get_vec_var_context(const std::string &file, size_t num_chains) {
context_vector get_vec_var_context(const std::string &file, size_t num_chains,
unsigned int id) {
using stan::io::var_context;
if (num_chains == 1) {
return context_vector(1, get_var_context(file));
Expand Down Expand Up @@ -249,8 +273,9 @@ context_vector get_vec_var_context(const std::string &file, size_t num_chains) {
"\tConsider saving your data in JSON format instead."
<< std::endl;
}
std::string file_1
= std::string(file_name + "_" + std::to_string(1) + file_ending);

auto filenames = make_filenames(file_name, "", file_ending, num_chains, id);
auto &file_1 = filenames[0];
std::fstream stream_1(file_1.c_str(), std::fstream::in);
// if file_1 exists we'll assume num_chains of these files exist
if (stream_1.rdstate() & std::ifstream::failbit) {
Expand All @@ -274,9 +299,8 @@ context_vector get_vec_var_context(const std::string &file, size_t num_chains) {
ret.reserve(num_chains);
ret.push_back(make_context(file_1, stream_1, file_ending));
for (size_t i = 1; i < num_chains; ++i) {
std::string file_i
= std::string(file_name + "_" + std::to_string(i) + file_ending);
std::fstream stream_i(file_1.c_str(), std::fstream::in);
auto &file_i = filenames[i];
std::fstream stream_i(file_i.c_str(), std::fstream::in);
// If any stream fails here something went wrong with file names
if (stream_i.rdstate() & std::ifstream::failbit) {
std::string file_name_err = std::string(
Expand Down Expand Up @@ -737,29 +761,6 @@ void check_file_config(argument_parser &parser) {
}
}

/**
 * Build one file name per chain from a single user-supplied file name.
 *
 * The name is passed to `get_basename_suffix`, whose result is used as a
 * (base, suffix) pair. When the suffix part is empty, `type` is used as the
 * suffix instead. Each resulting name has the form
 * `base + tag + chain_tag + suffix`, where `chain_tag` is empty when
 * `num_chains == 1` and `"_" + (i + id)` for the i-th chain (0-based)
 * otherwise.
 *
 * @param filename The file name to derive the per-chain names from
 * @param tag Text inserted between the base name and the chain tag
 * @param type Suffix to use when `filename` has none
 * @param num_chains The number of names to generate
 * @param id Offset added to the 0-based chain index in the chain tag
 * @return a std vector holding `num_chains` file names
 */
std::vector<std::string> make_filenames(const std::string &filename,
                                        const std::string &tag,
                                        const std::string &type,
                                        unsigned int num_chains,
                                        unsigned int id) {
  auto base_sfx = get_basename_suffix(filename);
  if (base_sfx.second.empty()) {
    base_sfx.second = type;
  }
  std::vector<std::string> names;
  names.reserve(num_chains);
  // The index is unsigned to match num_chains and id: no signed/unsigned
  // comparison in the loop condition and no implicit conversion in i + id.
  for (unsigned int i = 0; i < num_chains; ++i) {
    // A single chain keeps the plain name; multiple chains get "_<i + id>".
    const std::string chain_tag
        = (num_chains == 1) ? std::string() : "_" + std::to_string(i + id);
    names.push_back(base_sfx.first + tag + chain_tag + base_sfx.second);
  }
  return names;
}

void init_callbacks(
argument_parser &parser,
std::vector<stan::callbacks::unique_stream_writer<std::ofstream>>
Expand Down
40 changes: 40 additions & 0 deletions src/test/interface/multi_chain_init_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class CmdStan : public testing::Test {
init_data = {"src", "test", "test-models", "bern_init.json"};
init2_data = {"src", "test", "test-models", "bern_init2.json"};
init3_data = {"src", "test", "test-models", "bern_init2.R"};
init_bad_data = {"src", "test", "test-models", "bern_init_bad.json"};
dev_null_path = {"/dev", "null"};
}
std::vector<std::string> bern_model;
Expand All @@ -26,6 +27,7 @@ class CmdStan : public testing::Test {
std::vector<std::string> init_data;
std::vector<std::string> init2_data;
std::vector<std::string> init3_data;
std::vector<std::string> init_bad_data;
};

TEST_F(CmdStan, multi_chain_single_init_file_good) {
Expand All @@ -52,6 +54,44 @@ TEST_F(CmdStan, multi_chain_multi_init_file_good) {
ASSERT_FALSE(out.hasError);
}

TEST_F(CmdStan, multi_chain_multi_init_file_id_good) {
  // Two chains with id=2: the run is expected to succeed, which presumably
  // requires the per-chain init files for ids 2 and 3 to be present.
  std::stringstream command_line;
  command_line << convert_model_path(bern_model);
  command_line << " data file=" << convert_model_path(bern_data);
  command_line << " output file=" << convert_model_path(dev_null_path);
  command_line << " init=" << convert_model_path(init2_data) << " id=2";
  command_line << " method=sample num_chains=2";
  std::string full_command = command_line.str();
  run_command_output result = run_command(full_command);
  ASSERT_FALSE(result.hasError) << result.output;
}

TEST_F(CmdStan, multi_chain_multi_init_file_id_bad) {
  // this will start by requesting ..._4.json, which doesn't exist
  std::stringstream ss;
  ss << convert_model_path(bern_model)
     << " data file=" << convert_model_path(bern_data)
     << " output file=" << convert_model_path(dev_null_path)
     << " init=" << convert_model_path(init2_data) << " id=4"
     << " method=sample num_chains=3";
  std::string cmd = ss.str();
  run_command_output out = run_command(cmd);
  // Stream the captured output so an unexpected success is diagnosable,
  // matching the other multi-init tests in this file.
  ASSERT_TRUE(out.hasError) << out.output;
}

TEST_F(CmdStan, multi_chain_multi_init_file_actually_used) {
  // The init file for the second chain holds a bad value, so the run can
  // only fail if that second file is really read.
  std::stringstream command_line;
  command_line << convert_model_path(bern_model);
  command_line << " data file=" << convert_model_path(bern_data);
  command_line << " output file=" << convert_model_path(dev_null_path);
  command_line << " init=" << convert_model_path(init_bad_data);
  command_line << " method=sample num_chains=2";
  std::string full_command = command_line.str();
  run_command_output result = run_command(full_command);
  ASSERT_TRUE(result.hasError) << result.output;
}

TEST_F(CmdStan, multi_chain_multi_init_file_R) {
std::stringstream ss;
ss << convert_model_path(bern_model)
Expand Down
3 changes: 3 additions & 0 deletions src/test/test-models/bern_init_bad_1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"theta" : 0.1
}
3 changes: 3 additions & 0 deletions src/test/test-models/bern_init_bad_2.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"theta" : 3.0
}