Analyze a post-training Bernoulli-Bernoulli RBM

This script shows how to analyze a Bernoulli-Bernoulli RBM after it has been trained.

import matplotlib.pyplot as plt
import torch

device = torch.device("cpu")
dtype = torch.float32

Load the dataset

We suppose the RBM was trained on the dummy.h5 dataset file, using 60% of the dataset for training. By default, the dataset split is seeded, so reusing the same train_size and test_size gives the same split for analysis. This behaviour can be changed by passing a different value to the seed keyword.

from rbms.dataset import load_dataset

train_dataset, test_dataset = load_dataset(
    "dummy.h5", train_size=0.6, test_size=0.4, device=device, dtype=dtype
)
num_visibles = train_dataset.get_num_visibles()
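
If a different split is wanted, the seed keyword mentioned above can be set explicitly. A minimal sketch, assuming load_dataset accepts the seed keyword as described (the exact default value is library-specific):

alt_train, alt_test = load_dataset(
    "dummy.h5", train_size=0.6, test_size=0.4, device=device, dtype=dtype, seed=1234
)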

# Principal directions of the training data: SVD of the centered data matrix
U_data, S_data, V_dataT = torch.linalg.svd(
    train_dataset.data - train_dataset.data.mean(0)
)
# Project the data onto the right singular vectors (V_dataT.mT is V),
# rescaled by sqrt(num_visibles)
proj_data = train_dataset.data @ V_dataT.mT / num_visibles**0.5
proj_data = proj_data.cpu().numpy()
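
As a quick sanity check of how informative these directions are, the fraction of variance captured by the first principal components can be read off the singular values. A minimal sketch using only the quantities computed above:

# Squared singular values of the centered data are proportional to the variance along each PC
explained_variance = S_data**2 / (S_data**2).sum()
print(f"Variance captured by the first two PCs: {explained_variance[:2].sum().item():.2%}")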

Load the model

First, we want to know which models have been saved during training.

from rbms.utils import get_saved_updates

filename = "RBM.h5"
saved_updates = get_saved_updates(filename=filename)
print(f"Saved updates: {saved_updates}")
Saved updates: [    1     2     3     4     5     7     8    10    13    16    20    25
    31    39    49    61    75    94   116   145   180   223   277   344
   427   531   659   819  1017  1262  1567  1946  2416  3000  3725  4625
  5743  7131  8853 10992 13648 16946 21040 26123 32434 40270 50000]

Now we load the last saved model as well as the permanent chains used during training. Only the configurations associated with the last saved model are stored for the permanent chains. We also get access to the hyperparameters of the RBM training and to the time elapsed during training.

from rbms.io import load_model

params, permanent_chains, training_time, hyperparameters = load_model(
    filename=filename, index=saved_updates[-1], device=device, dtype=dtype
)

print(f"Training time: {training_time}")
for k in hyperparameters.keys():
    print(f"{k} : {hyperparameters[k]}")
Training time: 66.16361451148987
batch_size : 2000
gibbs_steps : 10
learning_rate : 0.01
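
Earlier checkpoints can be loaded with the same call by changing index. A sketch (since, as noted above, permanent chains are only stored for the last checkpoint, only the parameters are of interest here):

params_mid, _, _, _ = load_model(
    filename=filename,
    index=saved_updates[len(saved_updates) // 2],
    device=device,
    dtype=dtype,
)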

To follow the training of the RBM, let's look at how the singular values of the weight matrix evolve.

from rbms.utils import get_eigenvalues_history

grad_updates, sing_val = get_eigenvalues_history(filename=filename)

fig, ax = plt.subplots(1, 1)
ax.plot(grad_updates, sing_val)
ax.set_xlabel("Training time (gradient updates)")
ax.set_ylabel("Singular values")
ax.loglog()
fig.show()
[Figure: singular values of the weight matrix vs. gradient updates (log-log)]
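
If only the largest modes are of interest, the same history can be sliced before plotting. A sketch, assuming sing_val is a 2D array of shape (num_checkpoints, num_modes) as the plot above suggests:

fig, ax = plt.subplots(1, 1)
ax.plot(grad_updates, sing_val[:, :5])  # first five modes only (assumed to be the largest)
ax.set_xlabel("Training time (gradient updates)")
ax.set_ylabel("Singular values")
ax.loglog()
fig.show()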

Let's compare the permanent chains to the dataset distribution. To do so, we project the chains onto the first principal components of the dataset.

from rbms.plot import plot_PCA

proj_pc = permanent_chains["visible"] @ V_dataT.mT / num_visibles**0.5

plot_PCA(
    proj_data,
    proj_pc.cpu().numpy(),
    labels=["Dataset", "Permanent chains"],
)
[Figure: dataset vs. permanent chains projected on the first principal components]
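
Beyond the visual comparison, a rough quantitative check is to compare the first two projected moments of both clouds. A minimal sketch using only the arrays computed above:

proj_pc_np = proj_pc.cpu().numpy()
for i in range(2):
    # Mean and spread of the dataset and of the permanent chains along PC i
    print(
        f"PC{i + 1}: data mean/std = {proj_data[:, i].mean():.3f} / {proj_data[:, i].std():.3f}, "
        f"chains mean/std = {proj_pc_np[:, i].mean():.3f} / {proj_pc_np[:, i].std():.3f}"
    )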

Sample the RBM

Another interesting check is to compare samples generated from random initial configurations with the dataset.

from rbms.sampling.gibbs import sample_state

num_samples = 2000
chains = params.init_chains(num_samples=num_samples)
proj_gen_init = chains["visible"] @ V_dataT.mT / num_visibles**0.5
plot_PCA(
    proj_data,
    proj_gen_init.cpu().numpy(),
    labels=["Dataset", "Starting position"],
)
plt.tight_layout()
[Figure: dataset vs. random starting configurations projected on the first principal components]

We can now run Gibbs sampling on those chains and compare the distributions again.

n_steps = 100
chains = sample_state(gibbs_steps=n_steps, chains=chains, params=params)

proj_gen = chains["visible"] @ V_dataT.mT / num_visibles**0.5
plot_PCA(
    proj_data,
    proj_gen.cpu().numpy(),
    labels=["Dataset", "Generated samples"],
)
[Figure: dataset vs. generated samples projected on the first principal components]
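
Mixing can be slow, so it is often useful to keep sampling the same chains and watch the projection evolve. A sketch reusing sample_state as above:

for extra_steps in [100, 1000]:
    # Continue the same chains for extra_steps Gibbs steps and re-project them
    chains = sample_state(gibbs_steps=extra_steps, chains=chains, params=params)
    proj_gen = chains["visible"] @ V_dataT.mT / num_visibles**0.5
    plot_PCA(
        proj_data,
        proj_gen.cpu().numpy(),
        labels=["Dataset", f"Generated samples (+{extra_steps} steps)"],
    )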

Compute the AIS estimation of the log-likelihood

So far, we have only evaluated the model qualitatively. Annealed Importance Sampling (AIS) gives an estimate of the log-partition function, from which we can compute the log-likelihood of the data.

from rbms.partition_function.ais import compute_partition_function_ais
from rbms.utils import compute_log_likelihood

log_z_ais = compute_partition_function_ais(num_chains=2000, num_beta=100, params=params)

print(
    compute_log_likelihood(
        train_dataset.data, train_dataset.weights, params=params, log_z=log_z_ais
    )
)
-386.266357421875
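
The same estimate of log Z can be reused to evaluate the held-out split. A sketch, assuming test_dataset exposes the same data and weights attributes as the training split:

print(
    compute_log_likelihood(
        test_dataset.data, test_dataset.weights, params=params, log_z=log_z_ais
    )
)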
