.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plot_generated_samples.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_plot_generated_samples.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_generated_samples.py:


Analyze a post-training Bernoulli-Bernoulli RBM
===============================================

This script shows how to analyze a Bernoulli-Bernoulli RBM after it has been trained.

.. GENERATED FROM PYTHON SOURCE LINES 7-15

.. code-block:: Python

    import matplotlib.pyplot as plt
    import torch

    device = torch.device("cpu")
    dtype = torch.float32

.. GENERATED FROM PYTHON SOURCE LINES 16-22

Load the dataset
------------------------

We assume the RBM was trained on the ``dummy.h5`` dataset file, using 60% of the dataset for
training. By default, the dataset split is seeded, so reusing the same ``train_size`` and
``test_size`` reproduces the split used during training. This behaviour can be changed by
passing a different value to the ``seed`` keyword.

.. GENERATED FROM PYTHON SOURCE LINES 22-35

.. code-block:: Python

    from rbms.dataset import load_dataset

    train_dataset, test_dataset = load_dataset(
        "dummy.h5", train_size=0.6, test_size=0.4, device=device, dtype=dtype
    )

    num_visibles = train_dataset.get_num_visibles()

    # Principal directions of the centered training data
    U_data, S_data, V_dataT = torch.linalg.svd(
        train_dataset.data - train_dataset.data.mean(0)
    )
    proj_data = train_dataset.data @ V_dataT.mT / num_visibles**0.5
    proj_data = proj_data.cpu().numpy()

.. GENERATED FROM PYTHON SOURCE LINES 36-39

Load the model
------------------------

First, we want to know at which gradient updates the machine has been saved.

.. GENERATED FROM PYTHON SOURCE LINES 39-45

.. code-block:: Python

    from rbms.utils import get_saved_updates

    filename = "RBM.h5"
    saved_updates = get_saved_updates(filename=filename)
    print(f"Saved updates: {saved_updates}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Saved updates: [ 1 2 3 4 5 7 8 10 13 16 20 25 31 39 49 61 75 94 116 145
     180 223 277 344 427 531 659 765 819 849 990 1017 1024 1262 1567 1568
     1946 2118 2391 2416 2655 3000 3725 3799 4625 5743 7033 7131 8853 10992
     13648 16946 18038 21040 26123 31967 32434 40270 42529 50000]

.. GENERATED FROM PYTHON SOURCE LINES 46-49

Now we load the last saved model, as well as the permanent chains from training.
Only the configurations associated with the last saved model are stored for the permanent chains.
We also get access to the hyperparameters of the RBM training, as well as the time elapsed
during training.

.. GENERATED FROM PYTHON SOURCE LINES 49-59

.. code-block:: Python

    from rbms.io import load_model

    params, permanent_chains, training_time, hyperparameters = load_model(
        filename=filename, index=saved_updates[-1], device=device, dtype=dtype
    )
    print(f"Training time: {training_time}")
    for k, v in hyperparameters.items():
        print(f"{k} : {v}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Training time: 4915.106016874313
    batch_size : 2000
    gibbs_steps : 5
    learning_rate : 0.01
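If we want to inspect the model at an earlier stage of training, any other saved update can be
loaded the same way. A minimal sketch, assuming ``load_model`` accepts any index listed in
``saved_updates``; since the permanent chains are only stored for the last saved update, only
the returned parameters are used here:

.. code-block:: Python

    # Load an intermediate checkpoint, e.g. the median saved update.
    # Only the parameters are kept: the permanent chains are stored
    # for the last saved update only.
    mid_update = saved_updates[len(saved_updates) // 2]
    params_mid, _, _, _ = load_model(
        filename=filename, index=mid_update, device=device, dtype=dtype
    )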
.. GENERATED FROM PYTHON SOURCE LINES 60-61

To follow the training of the RBM, let's look at the evolution of the singular values of the
weight matrix.

.. GENERATED FROM PYTHON SOURCE LINES 61-72

.. code-block:: Python

    from rbms.utils import get_eigenvalues_history

    grad_updates, sing_val = get_eigenvalues_history(filename=filename)

    fig, ax = plt.subplots(1, 1)
    ax.plot(grad_updates, sing_val)
    ax.set_xlabel("Training time (gradient updates)")
    ax.set_ylabel("Singular values")
    ax.loglog()
    fig.show()

.. image-sg:: /auto_examples/images/sphx_glr_plot_generated_samples_001.png
   :alt: plot generated samples
   :srcset: /auto_examples/images/sphx_glr_plot_generated_samples_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 73-75

Let's compare the permanent chains to the dataset distribution. To do so, we project the chains
onto the first principal components of the dataset.

.. GENERATED FROM PYTHON SOURCE LINES 75-85

.. code-block:: Python

    from rbms.plot import plot_PCA

    proj_pc = permanent_chains["visible"] @ V_dataT.mT / num_visibles**0.5

    plot_PCA(
        proj_data,
        proj_pc.cpu().numpy(),
        labels=["Dataset", "Permanent chains"],
    )

.. image-sg:: /auto_examples/images/sphx_glr_plot_generated_samples_002.png
   :alt: plot generated samples
   :srcset: /auto_examples/images/sphx_glr_plot_generated_samples_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 86-90

Sample the RBM
------------------------------

It is also interesting to compare samples generated by the RBM starting from random
configurations. We first look at where those starting positions fall in the PCA plane.

.. GENERATED FROM PYTHON SOURCE LINES 90-102

.. code-block:: Python

    from rbms.sampling.gibbs import sample_state

    num_samples = 2000
    chains = params.init_chains(num_samples=num_samples)

    proj_gen_init = chains["visible"] @ V_dataT.mT / num_visibles**0.5
    plot_PCA(
        proj_data,
        proj_gen_init.cpu().numpy(),
        labels=["Dataset", "Starting position"],
    )
    plt.tight_layout()

.. image-sg:: /auto_examples/images/sphx_glr_plot_generated_samples_003.png
   :alt: plot generated samples
   :srcset: /auto_examples/images/sphx_glr_plot_generated_samples_003.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    /home/nbereux/work/rbms/examples/plot_generated_samples.py:100: UserWarning: Tight layout not applied. tight_layout cannot make Axes width small enough to accommodate all Axes decorations
      plt.tight_layout()

.. GENERATED FROM PYTHON SOURCE LINES 103-104

We can now run Gibbs sampling on those chains and compare the distributions again.

.. GENERATED FROM PYTHON SOURCE LINES 104-114

.. code-block:: Python

    n_steps = 100
    chains = sample_state(gibbs_steps=n_steps, chains=chains, params=params)

    proj_gen = chains["visible"] @ V_dataT.mT / num_visibles**0.5
    plot_PCA(
        proj_data,
        proj_gen.cpu().numpy(),
        labels=["Dataset", "Generated samples"],
    )

.. image-sg:: /auto_examples/images/sphx_glr_plot_generated_samples_004.png
   :alt: plot generated samples
   :srcset: /auto_examples/images/sphx_glr_plot_generated_samples_004.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 115-118

Compute the AIS estimate of the log-likelihood
----------------------------------------------

So far, we have only evaluated the model qualitatively. Annealed Importance Sampling (AIS)
provides an estimate of the log partition function, from which we can compute the
log-likelihood of the model on the training set.

.. GENERATED FROM PYTHON SOURCE LINES 118-128

.. code-block:: Python

    from rbms.partition_function.ais import compute_partition_function_ais
    from rbms.utils import compute_log_likelihood

    log_z_ais = compute_partition_function_ais(
        num_chains=2000, num_beta=100, params=params
    )
    print(
        compute_log_likelihood(
            train_dataset.data, train_dataset.weights, params=params, log_z=log_z_ais
        )
    )

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    -387.64599609375
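The same AIS estimate of the partition function can be reused to score the held-out test set,
which gives a quick check for overfitting. A minimal sketch, assuming the test split exposes
the same ``data`` and ``weights`` attributes as the train split:

.. code-block:: Python

    # Reuse the AIS estimate of log Z computed above. A test log-likelihood
    # far below the train value would indicate overfitting.
    print(
        compute_log_likelihood(
            test_dataset.data, test_dataset.weights, params=params, log_z=log_z_ais
        )
    )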
.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 8.271 seconds)


.. _sphx_glr_download_auto_examples_plot_generated_samples.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_generated_samples.ipynb <plot_generated_samples.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_generated_samples.py <plot_generated_samples.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_generated_samples.zip <plot_generated_samples.zip>`

.. only:: html

  .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_