From ededce60bd6f8ac2e14a57ab5452da4067b6eba8 Mon Sep 17 00:00:00 2001 From: Josh Cook Date: Sun, 31 Jul 2022 07:44:43 -0400 Subject: [PATCH 1/4] refactor: typehints for arguments and return of `sample_blackjax_nuts` --- pymc/sampling_jax.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index d9cf85130..9496c86f6 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -4,7 +4,7 @@ import warnings from functools import partial -from typing import Callable, Dict, List, Optional, Sequence, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union from pymc.initial_point import StartDict from pymc.sampling import RandomSeed, _get_seeds_per_chain, _init_jitter @@ -209,19 +209,19 @@ def one_step(state, rng_key): def sample_blackjax_nuts( - draws=1000, - tune=1000, - chains=4, - target_accept=0.8, - random_seed: RandomSeed = None, - initvals=None, - model=None, - var_names=None, - keep_untransformed=False, - chain_method="parallel", - postprocessing_backend=None, - idata_kwargs=None, -): + draws: int = 1000, + tune: int = 1000, + chains: int = 4, + target_accept: float = 0.8, + random_seed: Optional[RandomSeed] = None, + initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None, + model: Optional[Model] = None, + var_names: Optional[Iterable[str]] = None, + keep_untransformed: bool = False, + chain_method: str = "parallel", + postprocessing_backend: Optional[str] = None, + idata_kwargs: Optional[Dict[str, Any]] = None, +) -> az.InferenceData: """ Draw samples from the posterior using the NUTS method from the ``blackjax`` library. From dcb9258e005fa48fc5e12c4012ec6b31c6a0a118 Mon Sep 17 00:00:00 2001 From: Josh Cook Date: Sun, 31 Jul 2022 07:48:11 -0400 Subject: [PATCH 2/4] doc: style and add `initvals` to `sample_blackjax_nuts` docstring --- pymc/sampling_jax.py | 41 +++++++++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index 9496c86f6..19c3c7b46 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -228,41 +228,50 @@ def sample_blackjax_nuts( Parameters ---------- draws : int, default 1000 - The number of samples to draw. The number of tuned samples are discarded by default. + The number of samples to draw. The number of tuned samples are discarded by + default. tune : int, default 1000 Number of iterations to tune. Samplers adjust the step sizes, scalings or - similar during tuning. Tuning samples will be drawn in addition to the number specified in - the ``draws`` argument. + similar during tuning. Tuning samples will be drawn in addition to the number + specified in the ``draws`` argument. chains : int, default 4 The number of chains to sample. target_accept : float in [0, 1]. - The step size is tuned such that we approximate this acceptance rate. Higher values like - 0.9 or 0.95 often work better for problematic posteriors. + The step size is tuned such that we approximate this acceptance rate. Higher + values like 0.9 or 0.95 often work better for problematic posteriors. random_seed : int, RandomState or Generator, optional Random seed used by the sampling steps. + initvals: StartDict or Sequence[Optional[StartDict]], optional + Initial values for random variables provided as a dictionary (or sequence of + dictionaries) mapping the random variable (by name or reference) to desired + starting values. model : Model, optional - Model to sample from. The model needs to have free random variables. When inside a ``with`` model - context, it defaults to that model, otherwise the model must be passed explicitly. + Model to sample from. The model needs to have free random variables. When inside + a ``with`` model context, it defaults to that model, otherwise the model must be + passed explicitly. var_names : iterable of str, optional - Names of variables for which to compute the posterior samples. Defaults to all variables in the posterior + Names of variables for which to compute the posterior samples. Defaults to all + variables in the posterior keep_untransformed : bool, default False Include untransformed variables in the posterior samples. Defaults to False. chain_method : str, default "parallel" - Specify how samples should be drawn. The choices include "parallel", and "vectorized". + Specify how samples should be drawn. The choices include "parallel", and + "vectorized". postprocessing_backend : str, optional Specify how postprocessing should be computed. gpu or cpu idata_kwargs : dict, optional - Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value - for the ``log_likelihood`` key to indicate that the pointwise log likelihood should - not be included in the returned object. Values for ``observed_data``, ``constant_data``, - ``coords``, and ``dims`` are inferred from the ``model`` argument if not provided - in ``idata_kwargs``. + Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as + value for the ``log_likelihood`` key to indicate that the pointwise log + likelihood should not be included in the returned object. Values for + ``observed_data``, ``constant_data``, ``coords``, and ``dims`` are inferred from + the ``model`` argument if not provided in ``idata_kwargs``. Returns ------- InferenceData - ArviZ ``InferenceData`` object that contains the posterior samples, together with their respective sample stats and - pointwise log likeihood values (unless skipped with ``idata_kwargs``). + ArviZ ``InferenceData`` object that contains the posterior samples, together + with their respective sample stats and pointwise log likeihood values (unless + skipped with ``idata_kwargs``). """ import blackjax From c5e317e28e4f0626acb110bc14fcc48966a88664 Mon Sep 17 00:00:00 2001 From: Josh Cook <39419448+jhrcook@users.noreply.github.com> Date: Sun, 31 Jul 2022 10:45:08 -0400 Subject: [PATCH 3/4] refactor: change `var_names` from `Iterable` to `Sequence` typehint Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pymc/sampling_jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index 19c3c7b46..d18bb0120 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -216,7 +216,7 @@ def sample_blackjax_nuts( random_seed: Optional[RandomSeed] = None, initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None, model: Optional[Model] = None, - var_names: Optional[Iterable[str]] = None, + var_names: Optional[Sequence[str]] = None, keep_untransformed: bool = False, chain_method: str = "parallel", postprocessing_backend: Optional[str] = None, From 7209c09bd8a84c227dbb9e2e8040ae5d96eba568 Mon Sep 17 00:00:00 2001 From: Josh Cook Date: Sun, 31 Jul 2022 11:19:27 -0400 Subject: [PATCH 4/4] style: remove Iterable import --- pymc/sampling_jax.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index d18bb0120..5acf37722 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -4,7 +4,7 @@ import warnings from functools import partial -from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Union from pymc.initial_point import StartDict from pymc.sampling import RandomSeed, _get_seeds_per_chain, _init_jitter @@ -249,9 +249,9 @@ def sample_blackjax_nuts( Model to sample from. The model needs to have free random variables. When inside a ``with`` model context, it defaults to that model, otherwise the model must be passed explicitly. - var_names : iterable of str, optional + var_names : sequence of str, optional Names of variables for which to compute the posterior samples. Defaults to all - variables in the posterior + variables in the posterior. keep_untransformed : bool, default False Include untransformed variables in the posterior samples. Defaults to False. chain_method : str, default "parallel"