|
17 | 17 | "id": "LOf_GnJovd81" |
18 | 18 | }, |
19 | 19 | "source": [ |
20 | | - "\n", |
21 | | - "\n", |
22 | | - "Online training of an multilayer perceptron (MLP) classifier using conditional moments Gaussian filter (CMGF).\n", |
23 | | - "\n", |
24 | | - "\n", |
25 | | - "We perform sequential (recursive) Bayesian inference for the parameters of a binary MLP classifier.\n", |
| 20 | + "This notebook is similar to the previous notebook that demonstrated how to use CMGF for online Bayesian logistic regression. Here, we perform sequential (recursive) Bayesian inference for the parameters of a binary MLP classifier.\n", |
26 | 21 | "To do this, we treat the parameters of the model as the unknown hidden states.\n", |
27 | 22 | "We assume that these are approximately constant over time (we add a small amount of Gaussian drift,\n", |
28 | 23 | "for numerical stability.)\n", |
|
120 | 115 | "from jax.flatten_util import ravel_pytree" |
121 | 116 | ] |
122 | 117 | }, |
| 118 | + { |
| 119 | + "cell_type": "markdown", |
| 120 | + "metadata": {}, |
| 121 | + "source": [ |
| 122 | + "## Helper function to plot the posterior predictive distribution" |
| 123 | + ] |
| 124 | + }, |
123 | 125 | { |
124 | 126 | "cell_type": "code", |
125 | 127 | "execution_count": 5, |
|
499 | 501 | "id": "ld7GSZ2PsxLh" |
500 | 502 | }, |
501 | 503 | "source": [ |
502 | | - "Finally, we generate a video of the MLP-Classifier being trained." |
503 | | - ] |
504 | | - }, |
505 | | - { |
506 | | - "cell_type": "code", |
507 | | - "execution_count": 16, |
508 | | - "metadata": { |
509 | | - "id": "PaL3hY7lhTd4" |
510 | | - }, |
511 | | - "outputs": [], |
512 | | - "source": [ |
513 | | - "import matplotlib.animation as animation\n", |
514 | | - "from IPython.display import HTML" |
| 504 | + "### Animation\n", |
| 505 | + "Finally, we generate a video of the MLP-Classifier being trained.\n", |
| 506 | + "\n", |
| 507 | + "Note: This code is commented out by default since it takes a while time to run." |
515 | 508 | ] |
516 | 509 | }, |
517 | 510 | { |
|
522 | 515 | }, |
523 | 516 | "outputs": [], |
524 | 517 | "source": [ |
525 | | - "def animate(i):\n", |
526 | | - " ax.cla()\n", |
527 | | - " w_curr = w_means[i]\n", |
528 | | - " Zi = posterior_predictive_grid(input_grid, w_means[i], sigmoid_fn)\n", |
529 | | - " title = f'CMGF-EKF-MLP ({i+1}/500)'\n", |
530 | | - " plot_posterior_predictive(ax, input[:i+1], output[:i+1], title, input_grid, Zi) \n", |
531 | | - " return ax" |
532 | | - ] |
533 | | - }, |
534 | | - { |
535 | | - "cell_type": "code", |
536 | | - "execution_count": 18, |
537 | | - "metadata": { |
538 | | - "colab": { |
539 | | - "base_uri": "https://localhost:8080/", |
540 | | - "height": 319 |
541 | | - }, |
542 | | - "id": "tMbQ8HU9iUr7", |
543 | | - "outputId": "c5c55414-8720-432b-e3f5-a20db8c5ba43" |
544 | | - }, |
545 | | - "outputs": [], |
546 | | - "source": [ |
547 | | - "#fig, ax = plt.subplots(figsize=(6, 5))\n", |
548 | | - "#anim = animation.FuncAnimation(fig, animate, frames=500, interval=50)\n", |
549 | | - "#anim.save(\"cmgf_mlp_classifier.mp4\", dpi=200, bitrate=-1, fps=24)" |
550 | | - ] |
551 | | - }, |
552 | | - { |
553 | | - "cell_type": "code", |
554 | | - "execution_count": 19, |
555 | | - "metadata": { |
556 | | - "colab": { |
557 | | - "base_uri": "https://localhost:8080/", |
558 | | - "height": 381 |
559 | | - }, |
560 | | - "id": "TAMU4Qc6rYM5", |
561 | | - "outputId": "799e598f-4c85-4eb6-b3c0-6a84ae327ea6" |
562 | | - }, |
563 | | - "outputs": [], |
564 | | - "source": [ |
565 | | - "#HTML(anim.to_html5_video())" |
| 518 | + "# import matplotlib.animation as animation\n", |
| 519 | + "# from IPython.display import HTML\n", |
| 520 | + "\n", |
| 521 | + "# def animate(i):\n", |
| 522 | + "# ax.cla()\n", |
| 523 | + "# w_curr = w_means[i]\n", |
| 524 | + "# Zi = posterior_predictive_grid(input_grid, w_means[i], sigmoid_fn)\n", |
| 525 | + "# title = f'CMGF-EKF-MLP ({i+1}/500)'\n", |
| 526 | + "# plot_posterior_predictive(ax, input[:i+1], output[:i+1], title, input_grid, Zi) \n", |
| 527 | + "# return ax\n", |
| 528 | + "\n", |
| 529 | + "# fig, ax = plt.subplots(figsize=(6, 5))\n", |
| 530 | + "# anim = animation.FuncAnimation(fig, animate, frames=500, interval=50)\n", |
| 531 | + "# anim.save(\"cmgf_mlp_classifier.mp4\", dpi=200, bitrate=-1, fps=24)\n", |
| 532 | + "\n", |
| 533 | + "# HTML(anim.to_html5_video())" |
566 | 534 | ] |
567 | 535 | } |
568 | 536 | ], |
|
0 commit comments