diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index d9cf85130..5acf37722 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -4,7 +4,7 @@ import warnings from functools import partial -from typing import Callable, Dict, List, Optional, Sequence, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Union from pymc.initial_point import StartDict from pymc.sampling import RandomSeed, _get_seeds_per_chain, _init_jitter @@ -209,60 +209,69 @@ def one_step(state, rng_key): def sample_blackjax_nuts( - draws=1000, - tune=1000, - chains=4, - target_accept=0.8, - random_seed: RandomSeed = None, - initvals=None, - model=None, - var_names=None, - keep_untransformed=False, - chain_method="parallel", - postprocessing_backend=None, - idata_kwargs=None, -): + draws: int = 1000, + tune: int = 1000, + chains: int = 4, + target_accept: float = 0.8, + random_seed: Optional[RandomSeed] = None, + initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None, + model: Optional[Model] = None, + var_names: Optional[Sequence[str]] = None, + keep_untransformed: bool = False, + chain_method: str = "parallel", + postprocessing_backend: Optional[str] = None, + idata_kwargs: Optional[Dict[str, Any]] = None, +) -> az.InferenceData: """ Draw samples from the posterior using the NUTS method from the ``blackjax`` library. Parameters ---------- draws : int, default 1000 - The number of samples to draw. The number of tuned samples are discarded by default. + The number of samples to draw. The number of tuned samples are discarded by + default. tune : int, default 1000 Number of iterations to tune. Samplers adjust the step sizes, scalings or - similar during tuning. Tuning samples will be drawn in addition to the number specified in - the ``draws`` argument. + similar during tuning. Tuning samples will be drawn in addition to the number + specified in the ``draws`` argument. chains : int, default 4 The number of chains to sample. target_accept : float in [0, 1]. - The step size is tuned such that we approximate this acceptance rate. Higher values like - 0.9 or 0.95 often work better for problematic posteriors. + The step size is tuned such that we approximate this acceptance rate. Higher + values like 0.9 or 0.95 often work better for problematic posteriors. random_seed : int, RandomState or Generator, optional Random seed used by the sampling steps. + initvals: StartDict or Sequence[Optional[StartDict]], optional + Initial values for random variables provided as a dictionary (or sequence of + dictionaries) mapping the random variable (by name or reference) to desired + starting values. model : Model, optional - Model to sample from. The model needs to have free random variables. When inside a ``with`` model - context, it defaults to that model, otherwise the model must be passed explicitly. - var_names : iterable of str, optional - Names of variables for which to compute the posterior samples. Defaults to all variables in the posterior + Model to sample from. The model needs to have free random variables. When inside + a ``with`` model context, it defaults to that model, otherwise the model must be + passed explicitly. + var_names : sequence of str, optional + Names of variables for which to compute the posterior samples. Defaults to all + variables in the posterior. keep_untransformed : bool, default False Include untransformed variables in the posterior samples. Defaults to False. chain_method : str, default "parallel" - Specify how samples should be drawn. The choices include "parallel", and "vectorized". + Specify how samples should be drawn. The choices include "parallel", and + "vectorized". postprocessing_backend : str, optional Specify how postprocessing should be computed. gpu or cpu idata_kwargs : dict, optional - Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value - for the ``log_likelihood`` key to indicate that the pointwise log likelihood should - not be included in the returned object. Values for ``observed_data``, ``constant_data``, - ``coords``, and ``dims`` are inferred from the ``model`` argument if not provided - in ``idata_kwargs``. + Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as + value for the ``log_likelihood`` key to indicate that the pointwise log + likelihood should not be included in the returned object. Values for + ``observed_data``, ``constant_data``, ``coords``, and ``dims`` are inferred from + the ``model`` argument if not provided in ``idata_kwargs``. Returns ------- InferenceData - ArviZ ``InferenceData`` object that contains the posterior samples, together with their respective sample stats and - pointwise log likeihood values (unless skipped with ``idata_kwargs``). + ArviZ ``InferenceData`` object that contains the posterior samples, together + with their respective sample stats and pointwise log likeihood values (unless + skipped with ``idata_kwargs``). """ import blackjax