113 changes: 68 additions & 45 deletions pymc/sampling/jax.py
@@ -11,9 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import re
import sys

from datetime import datetime
from functools import partial
@@ -53,6 +53,8 @@
get_default_varnames,
)

logger = logging.getLogger(__name__)

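# Expose the host CPU as multiple XLA devices so `jax.pmap` can run chains in
# parallel; any device-count flag already present in XLA_FLAGS is replaced.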
xla_flags_env = os.getenv("XLA_FLAGS", "")
xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags_env).split()
os.environ["XLA_FLAGS"] = " ".join([f"--xla_force_host_platform_device_count={100}"] + xla_flags)
@@ -289,40 +291,46 @@ def _update_coords_and_dims(
dims.update(idata_kwargs.pop("dims"))


@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6))
def _blackjax_inference_loop(
seed,
init_position,
logprob_fn,
draws,
tune,
target_accept,
algorithm=None,
seed, init_position, logprob_fn, draws, tune, target_accept, **adaptation_kwargs
):
import blackjax

if algorithm is None:
algorithm_name = adaptation_kwargs.pop("algorithm", "nuts")
if algorithm_name == "nuts":
algorithm = blackjax.nuts
elif algorithm_name == "hmc":
algorithm = blackjax.hmc
else:
raise ValueError("Only supporting 'nuts' or 'hmc' as algorithm to draw samples.")

adapt = blackjax.window_adaptation(
algorithm=algorithm,
logdensity_fn=logprob_fn,
target_acceptance_rate=target_accept,
**adaptation_kwargs,
)
(last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune)
kernel = algorithm(logprob_fn, **tuned_params).step

def inference_loop(rng_key, initial_state):
def one_step(state, rng_key):
state, info = kernel(rng_key, state)
return state, (state, info)
def _one_step(state, xs):
_, rng_key = xs
state, info = kernel(rng_key, state)
return state, (state, info)

keys = jax.random.split(rng_key, draws)
_, (states, infos) = jax.lax.scan(one_step, initial_state, keys)
progress_bar = adaptation_kwargs.pop("progress_bar", False)
if progress_bar:
from blackjax.progress_bar import progress_bar_scan

logger.info("Sample with tuned parameters")
one_step = jax.jit(progress_bar_scan(draws)(_one_step))
else:
one_step = jax.jit(_one_step)

return states, infos
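# The scan inputs pair an iteration index with each RNG key: `_one_step` ignores
# the index, but blackjax's `progress_bar_scan` wrapper uses it to advance the
# progress bar.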
keys = jax.random.split(seed, draws)
_, (states, infos) = jax.lax.scan(one_step, last_state, (jnp.arange(draws), keys))

return inference_loop(seed, last_state)
return states, infos


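A minimal usage sketch of the new `adaptation_kwargs` pass-through (the model and keyword values here are hypothetical; `algorithm` and `progress_bar` are consumed by the inference loop, and any remaining keys are assumed to be valid `blackjax.window_adaptation` arguments):

import pymc as pm

from pymc.sampling.jax import sample_blackjax_nuts

with pm.Model() as model:
    x = pm.Normal("x")
    y = pm.Normal("y", mu=x, sigma=1.0, observed=0.5)

idata = sample_blackjax_nuts(
    draws=1000,
    tune=1000,
    chains=4,
    chain_method="vectorized",  # the progress bar is only shown for vectorized chains
    adaptation_kwargs={"algorithm": "nuts", "progress_bar": True},
    model=model,
)
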
def sample_blackjax_nuts(
@@ -339,6 +347,7 @@ def sample_blackjax_nuts(
postprocessing_backend: Optional[Literal["cpu", "gpu"]] = None,
postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
idata_kwargs: Optional[Dict[str, Any]] = None,
adaptation_kwargs: Optional[Dict[str, Any]] = None,
postprocessing_chunks=None, # deprecated
) -> az.InferenceData:
"""
Expand Down Expand Up @@ -415,7 +424,7 @@ def sample_blackjax_nuts(
(random_seed,) = _get_seeds_per_chain(random_seed, 1)

tic1 = datetime.now()
print("Compiling...", file=sys.stdout)
logger.info("Compiling...")

init_params = _get_batched_jittered_initial_points(
model=model,
@@ -432,36 +441,48 @@
seed = jax.random.PRNGKey(random_seed)
keys = jax.random.split(seed, chains)

get_posterior_samples = partial(
_blackjax_inference_loop,
logprob_fn=logprob_fn,
tune=tune,
draws=draws,
target_accept=target_accept,
)

tic2 = datetime.now()
print("Compilation time = ", tic2 - tic1, file=sys.stdout)

print("Sampling...", file=sys.stdout)
if adaptation_kwargs is None:
adaptation_kwargs = {}

# Adapted from numpyro
if chain_method == "parallel":
map_fn = jax.pmap
if adaptation_kwargs.get("progress_bar", False):
import warnings

warnings.warn(
"BlackJax currently only displays the progress bar correctly when "
"`chain_method == 'vectorized'`. Setting `progress_bar=False`."
)
adaptation_kwargs["progress_bar"] = False
elif chain_method == "vectorized":
map_fn = jax.vmap
else:
raise ValueError(
"Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"'
)

get_posterior_samples = partial(
_blackjax_inference_loop,
logprob_fn=logprob_fn,
tune=tune,
draws=draws,
target_accept=target_accept,
**adaptation_kwargs,
)

tic2 = datetime.now()
logger.info(f"Compilation time = {tic2 - tic1}")

logger.info("Sampling...")

states, stats = map_fn(get_posterior_samples)(keys, init_params)
raw_mcmc_samples = states.position
potential_energy = states.logdensity
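# JAX dispatches work asynchronously; blocking on the result here ensures the
# sampling time measured below reflects the actual computation, not just dispatch.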
potential_energy = states.logdensity.block_until_ready()
tic3 = datetime.now()
print("Sampling time = ", tic3 - tic2, file=sys.stdout)
logger.info(f"Sampling time = {tic3 - tic2}")

print("Transforming variables...", file=sys.stdout)
logger.info("Transforming variables...")
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
result = _postprocess_samples(
jax_fn,
@@ -472,7 +493,7 @@
mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
mcmc_stats = _blackjax_stats_to_dict(stats, potential_energy)
tic4 = datetime.now()
print("Transformation time = ", tic4 - tic3, file=sys.stdout)
logger.info(f"Transformation time = {tic4 - tic3}")

if idata_kwargs is None:
idata_kwargs = {}
@@ -481,15 +502,15 @@

if idata_kwargs.pop("log_likelihood", False):
tic5 = datetime.now()
print("Computing Log Likelihood...", file=sys.stdout)
logger.info(f"Computing Log Likelihood...")
log_likelihood = _get_log_likelihood(
model,
raw_mcmc_samples,
backend=postprocessing_backend,
postprocessing_vectorize=postprocessing_vectorize,
)
tic6 = datetime.now()
print("Log Likelihood time = ", tic6 - tic5, file=sys.stdout)
logger.info(f"Log Likelihood time = {tic6 - tic5}")
else:
log_likelihood = None

@@ -634,7 +655,7 @@ def sample_numpyro_nuts(
(random_seed,) = _get_seeds_per_chain(random_seed, 1)

tic1 = datetime.now()
print("Compiling...", file=sys.stdout)
logger.info("Compiling...")

init_params = _get_batched_jittered_initial_points(
model=model,
@@ -663,9 +684,9 @@
)

tic2 = datetime.now()
print("Compilation time = ", tic2 - tic1, file=sys.stdout)
logger.info(f"Compilation time = {tic2 - tic1}")

print("Sampling...", file=sys.stdout)
logger.info("Sampling...")

map_seed = jax.random.PRNGKey(random_seed)
if chains > 1:
@@ -687,9 +708,9 @@
raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)

tic3 = datetime.now()
print("Sampling time = ", tic3 - tic2, file=sys.stdout)
logger.info(f"Sampling time = {tic3 - tic2}")

print("Transforming variables...", file=sys.stdout)
logger.info("Transforming variables...")
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
result = _postprocess_samples(
jax_fn,
@@ -700,7 +721,7 @@
mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}

tic4 = datetime.now()
print("Transformation time = ", tic4 - tic3, file=sys.stdout)
logger.info(f"Transformation time = {tic4 - tic3}")

if idata_kwargs is None:
idata_kwargs = {}
@@ -709,15 +730,17 @@

if idata_kwargs.pop("log_likelihood", False):
tic5 = datetime.now()
print("Computing Log Likelihood...", file=sys.stdout)
logger.info(f"Computing Log Likelihood...")
log_likelihood = _get_log_likelihood(
model,
raw_mcmc_samples,
backend=postprocessing_backend,
postprocessing_vectorize=postprocessing_vectorize,
)
tic6 = datetime.now()
print("Log Likelihood time = ", tic6 - tic5, file=sys.stdout)
logger.info(f"Log Likelihood time = {tic6 - tic5}")
else:
log_likelihood = None
