Skip to content
This repository has been archived by the owner on Mar 19, 2021. It is now read-only.

Commit

Permalink
Merge pull request #498 from stan-dev/doc/stan_threads
Browse files Browse the repository at this point in the history
DOC: Instructions for usage of STAN_THREADS
  • Loading branch information
ahartikainen authored Aug 16, 2018
2 parents 857e8d2 + 88ed5f4 commit f4833a3
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 1 deletion.
2 changes: 1 addition & 1 deletion doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ Documentation
api
conversion
logging
threading_support
windows

Stan documentation
Expand Down Expand Up @@ -56,4 +57,3 @@ Indices and tables
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`

107 changes: 107 additions & 0 deletions doc/threading_support.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
.. _threading_support:

.. currentmodule:: pystan

===================================
Threading Support with Pystan 2.18+
===================================

Notice! This is an experimental feature and is not tested or supported officially with PyStan 2.
Official multithreading support will land with PyStan 3.

By default, `stan-math` is not thread safe. Stan 2.18+ has ability to switch on
threading support with compile time arguments.

See https://github.com/stan-dev/math/wiki/Threading-Support

Due to use of `multiprocessing` to parallelize chains, user needs to be aware of the cpu usage.
This means that each chain will use `STAN_NUM_THREADS` cpu cores and this can have an affect on performance.

Windows
=======

These instructions are invalid on Windows with MingW-W64 compiler and should not be used.
Usage will crash the current Python session, which means that no sampling can be done.

see https://github.com/Alexpux/MINGW-packages/issues/2519 and https://sourceforge.net/p/mingw-w64/bugs/445/

Example
=======

.. code-block:: python
import pystan
import os
import sys
# set environmental variable STAN_NUM_THREADS
# Use 4 cores per chain
os.environ['STAN_NUM_THREADS'] = "4"
# Example model
# see http://discourse.mc-stan.org/t/cant-make-cmdstan-2-18-in-windows/5088/18
stan_code = """
functions {
vector bl_glm(vector mu_sigma, vector beta,
real[] x, int[] y) {
vector[2] mu = mu_sigma[1:2];
vector[2] sigma = mu_sigma[3:4];
real lp = normal_lpdf(beta | mu, sigma);
real ll = bernoulli_logit_lpmf(y | beta[1] + beta[2] * to_vector(x));
return [lp + ll]';
}
}
data {
int<lower = 0> K;
int<lower = 0> N;
vector[N] x;
int<lower = 0, upper = 1> y[N];
}
transformed data {
int<lower = 0> J = N / K;
real x_r[K, J];
int<lower = 0, upper = 1> x_i[K, J];
{
int pos = 1;
for (k in 1:K) {
int end = pos + J - 1;
x_r[k] = to_array_1d(x[pos:end]);
x_i[k] = y[pos:end];
pos += J;
}
}
}
parameters {
vector[2] beta[K];
vector[2] mu;
vector<lower=0>[2] sigma;
}
model {
mu ~ normal(0, 2);
sigma ~ normal(0, 2);
target += sum(map_rect(bl_glm, append_row(mu, sigma),
beta, x_r, x_i));
}
"""
stan_data = dict(
K = 4,
N = 12,
x = [1.204, -0.573, -1.35, -1.157,
-1.29, 0.515, 1.496, 0.918,
0.517, 1.092, -0.485, -2.157],
y = [1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1]
)
extra_compile_args = ['-pthread', '-DSTAN_THREADS']
stan_model = pystan.StanModel(
model_code=stan_code,
extra_compile_args=extra_compile_args
)
# use the default 4 chains == 4 parallel process
# used cores = min(cpu_cores, 4*STAN_NUM_THREADS)
fit = stan_model.sampling(data=stan_data, n_jobs=4)
print(fit)
79 changes: 79 additions & 0 deletions pystan/tests/test_extra_compile_args.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import distutils.errors
import sys
import os
import unittest

import pystan
Expand Down Expand Up @@ -39,3 +40,81 @@ def test_extra_compile_args_failure(self):
with assertRaisesRegex(distutils.errors.CompileError, 'failed with exit status'):
pystan.StanModel(model_code=model_code, model_name="normal1",
extra_compile_args=extra_compile_args)

def test_threading_support(self):
# Dont test with Windows
if sys.platform.startswith("win"):
return
# Set up environmental variable
os.environ['STAN_NUM_THREADS'] = "2"
# Enable threading
extra_compile_args = ['-pthread', '-DSTAN_THREADS']
stan_code = """
functions {
vector bl_glm(vector mu_sigma, vector beta,
real[] x, int[] y) {
vector[2] mu = mu_sigma[1:2];
vector[2] sigma = mu_sigma[3:4];
real lp = normal_lpdf(beta | mu, sigma);
real ll = bernoulli_logit_lpmf(y | beta[1] + beta[2] * to_vector(x));
return [lp + ll]';
}
}
data {
int<lower = 0> K;
int<lower = 0> N;
vector[N] x;
int<lower = 0, upper = 1> y[N];
}
transformed data {
int<lower = 0> J = N / K;
real x_r[K, J];
int<lower = 0, upper = 1> x_i[K, J];
{
int pos = 1;
for (k in 1:K) {
int end = pos + J - 1;
x_r[k] = to_array_1d(x[pos:end]);
x_i[k] = y[pos:end];
pos += J;
}
}
}
parameters {
vector[2] beta[K];
vector[2] mu;
vector<lower=0>[2] sigma;
}
model {
mu ~ normal(0, 2);
sigma ~ normal(0, 2);
target += sum(map_rect(bl_glm, append_row(mu, sigma),
beta, x_r, x_i));
}
"""
stan_data = dict(
K = 4,
N = 12,
x = [1.204, -0.573, -1.35, -1.157,
-1.29, 0.515, 1.496, 0.918,
0.517, 1.092, -0.485, -2.157],
y = [1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1]
)
stan_model = pystan.StanModel(
model_code=stan_code,
extra_compile_args=extra_compile_args
)
fit = stan_model.sampling(data=stan_data, chains=2, n_jobs=1)
self.assertIsNotNone(fit)
fit2 = stan_model.sampling(data=stan_data, chains=2, n_jobs=2)
self.assertIsNotNone(fit2)
draw = fit.extract(pars=fit.model_pars+['lp__'], permuted=False)
lp = {key : values[-1, 0] for key, values in draw.items() if key == 'lp__'}['lp__']
draw = {key : values[-1, 0] for key, values in draw.items() if key != 'lp__'}
draw = fit.unconstrain_pars(draw)
self.assertEqual(fit.log_prob(draw), lp)
draw2 = fit2.extract(pars=fit2.model_pars+['lp__'], permuted=False)
lp2 = {key : values[-1, 0] for key, values in draw2.items() if key == 'lp__'}['lp__']
draw2 = {key : values[-1, 0] for key, values in draw2.items() if key != 'lp__'}
draw2 = fit2.unconstrain_pars(draw2)
self.assertEqual(fit2.log_prob(draw2), lp2)

0 comments on commit f4833a3

Please sign in to comment.