Skip to content

Commit c13fe76

Browse files
committed
Merge branch 'main' of github.com:probml/dynamax
2 parents 51dfdb0 + a14929c commit c13fe76

27 files changed

+2361
-2135
lines changed

docs/api.rst

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,7 @@ Low-level inference
193193
-------------------
194194

195195
.. autofunction:: dynamax.nonlinear_gaussian_ssm.extended_kalman_filter
196-
.. autofunction:: dynamax.nonlinear_gaussian_ssm.iterated_extended_kalman_filter
197196
.. autofunction:: dynamax.nonlinear_gaussian_ssm.extended_kalman_smoother
198-
.. autofunction:: dynamax.nonlinear_gaussian_ssm.iterated_extended_kalman_smoother
199197

200198
.. autofunction:: dynamax.nonlinear_gaussian_ssm.unscented_kalman_filter
201199
.. autofunction:: dynamax.nonlinear_gaussian_ssm.unscented_kalman_smoother
@@ -219,9 +217,7 @@ Low-level inference
219217
-------------------
220218

221219
.. autofunction:: dynamax.generalized_gaussian_ssm.conditional_moments_gaussian_filter
222-
.. autofunction:: dynamax.generalized_gaussian_ssm.iterated_conditional_moments_gaussian_filter
223220
.. autofunction:: dynamax.generalized_gaussian_ssm.conditional_moments_gaussian_smoother
224-
.. autofunction:: dynamax.generalized_gaussian_ssm.iterated_conditional_moments_gaussian_smoother
225221

226222
Types
227223
-----

docs/notebooks/generalized_gaussian_ssm/cmgf_logistic_regression_demo.ipynb

Lines changed: 364 additions & 350 deletions
Large diffs are not rendered by default.

docs/notebooks/generalized_gaussian_ssm/cmgf_mlp_classification_demo.ipynb

Lines changed: 28 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,7 @@
1717
"id": "LOf_GnJovd81"
1818
},
1919
"source": [
20-
"\n",
21-
"\n",
22-
"Online training of an multilayer perceptron (MLP) classifier using conditional moments Gaussian filter (CMGF).\n",
23-
"\n",
24-
"\n",
25-
"We perform sequential (recursive) Bayesian inference for the parameters of a binary MLP classifier.\n",
20+
"This notebook is similar to the previous notebook that demonstrated how to use CMGF for online Bayesian logistic regression. Here, we perform sequential (recursive) Bayesian inference for the parameters of a binary MLP classifier.\n",
2621
"To do this, we treat the parameters of the model as the unknown hidden states.\n",
2722
"We assume that these are approximately constant over time (we add a small amount of Gaussian drift,\n",
2823
"for numerical stability.)\n",
@@ -120,6 +115,13 @@
120115
"from jax.flatten_util import ravel_pytree"
121116
]
122117
},
118+
{
119+
"cell_type": "markdown",
120+
"metadata": {},
121+
"source": [
122+
"## Helper function to plot the posterior predictive distribution"
123+
]
124+
},
123125
{
124126
"cell_type": "code",
125127
"execution_count": 5,
@@ -499,19 +501,10 @@
499501
"id": "ld7GSZ2PsxLh"
500502
},
501503
"source": [
502-
"Finally, we generate a video of the MLP-Classifier being trained."
503-
]
504-
},
505-
{
506-
"cell_type": "code",
507-
"execution_count": 16,
508-
"metadata": {
509-
"id": "PaL3hY7lhTd4"
510-
},
511-
"outputs": [],
512-
"source": [
513-
"import matplotlib.animation as animation\n",
514-
"from IPython.display import HTML"
504+
"### Animation\n",
505+
"Finally, we generate a video of the MLP-Classifier being trained.\n",
506+
"\n",
507+
"Note: This code is commented out by default since it takes a while time to run."
515508
]
516509
},
517510
{
@@ -522,47 +515,22 @@
522515
},
523516
"outputs": [],
524517
"source": [
525-
"def animate(i):\n",
526-
" ax.cla()\n",
527-
" w_curr = w_means[i]\n",
528-
" Zi = posterior_predictive_grid(input_grid, w_means[i], sigmoid_fn)\n",
529-
" title = f'CMGF-EKF-MLP ({i+1}/500)'\n",
530-
" plot_posterior_predictive(ax, input[:i+1], output[:i+1], title, input_grid, Zi) \n",
531-
" return ax"
532-
]
533-
},
534-
{
535-
"cell_type": "code",
536-
"execution_count": 18,
537-
"metadata": {
538-
"colab": {
539-
"base_uri": "https://localhost:8080/",
540-
"height": 319
541-
},
542-
"id": "tMbQ8HU9iUr7",
543-
"outputId": "c5c55414-8720-432b-e3f5-a20db8c5ba43"
544-
},
545-
"outputs": [],
546-
"source": [
547-
"#fig, ax = plt.subplots(figsize=(6, 5))\n",
548-
"#anim = animation.FuncAnimation(fig, animate, frames=500, interval=50)\n",
549-
"#anim.save(\"cmgf_mlp_classifier.mp4\", dpi=200, bitrate=-1, fps=24)"
550-
]
551-
},
552-
{
553-
"cell_type": "code",
554-
"execution_count": 19,
555-
"metadata": {
556-
"colab": {
557-
"base_uri": "https://localhost:8080/",
558-
"height": 381
559-
},
560-
"id": "TAMU4Qc6rYM5",
561-
"outputId": "799e598f-4c85-4eb6-b3c0-6a84ae327ea6"
562-
},
563-
"outputs": [],
564-
"source": [
565-
"#HTML(anim.to_html5_video())"
518+
"# import matplotlib.animation as animation\n",
519+
"# from IPython.display import HTML\n",
520+
"\n",
521+
"# def animate(i):\n",
522+
"# ax.cla()\n",
523+
"# w_curr = w_means[i]\n",
524+
"# Zi = posterior_predictive_grid(input_grid, w_means[i], sigmoid_fn)\n",
525+
"# title = f'CMGF-EKF-MLP ({i+1}/500)'\n",
526+
"# plot_posterior_predictive(ax, input[:i+1], output[:i+1], title, input_grid, Zi) \n",
527+
"# return ax\n",
528+
"\n",
529+
"# fig, ax = plt.subplots(figsize=(6, 5))\n",
530+
"# anim = animation.FuncAnimation(fig, animate, frames=500, interval=50)\n",
531+
"# anim.save(\"cmgf_mlp_classifier.mp4\", dpi=200, bitrate=-1, fps=24)\n",
532+
"\n",
533+
"# HTML(anim.to_html5_video())"
566534
]
567535
}
568536
],

0 commit comments

Comments
 (0)