Reproducing Kingma & Welling (2013), Fig. 2 (MNIST VAE)

This example reproduces Figure 2 from the paper on Auto-Encoding Variational Bayes (Kingma & Welling (2013)). The experiment trains Variational Autoencoders (VAEs) with different latent dimensionalities \(N_z\) on MNIST and tracks the Evidence Lower Bound (ELBO) as a function of the number of training samples processed.

The setup follows the configuration in the paper:

  • one-hidden-layer MLP encoder and decoder with Tanh activations;

  • hidden size \(H = 500\);

  • mini-batch size \(M = 100\);

  • Adagrad learning rate selected from \(\{0.01, 0.02, 0.1\}\) (here: 0.02);

  • small weight decay (approximate \(\mathcal{N}(0, I)\) prior on weights);

  • \(L = 1\) Monte Carlo sample for the latent variable estimator.

Configuration

We begin with imports, global settings, and the reproducibility utilities. This includes the list of latent dimensions used for \(N_z\), hyperparameters, and device configuration.

 1"""
 2Reproduce the MNIST VAE experiment from Kingma & Welling (2013), Fig. 2.
 3
 4We train VAEs with different latent dimensionalities N_z and track the ELBO
 5as a function of the number of training samples seen. 
 6The setup follows the original paper:
 7
 8- One hidden layer MLPs with Tanh activations in encoder/decoder
 9- Hidden size H = 500
10- Mini-batch size M = 100
11- Learning rate selected from {0.01, 0.02, 0.1} (we use 0.02 here)
12- L = 1 Monte Carlo sample for the stochastic latent variable
13"""
14
15from __future__ import annotations
16
17import random
18import time
19from pathlib import Path
20from typing import Dict, List, Tuple
21
22import matplotlib.pyplot as plt
23import numpy as np
24import torch
25import torch.nn as nn
26from matplotlib.ticker import (
27    LogFormatterMathtext,
28    LogLocator,
29    NullFormatter,
30)
31from torch.utils.data import DataLoader
32from torchvision import datasets, transforms
33
34from pyautoencoder.variational import VAE
35
36
# ---------------- Configuration ---------------- #
# Entries tagged "(paper)" follow Kingma & Welling (2013).
LATENT_DIMS: List[int] = [3, 5, 10, 20, 200]  # N_z sweep (paper)
HIDDEN_SIZE: int = 500                        # hidden width H of encoder and decoder (paper, MNIST)
BATCH_SIZE: int = 100                         # mini-batch size M (paper)
LEARNING_RATE: float = 0.02                   # Adagrad step size, one of {0.01, 0.02, 0.1} (paper)
MC_SAMPLES: int = 1                           # latent draws per datapoint, L (paper)
TARGET_TRAIN_SAMPLES: int = 10_000_000        # training budget, counted in samples processed
EVAL_EVERY_SAMPLES: int = 100_000             # evaluation / logging interval, in samples
USE_STOCHASTIC_BINARIZATION: bool = False     # when True, resample pixels as x ~ Bernoulli(p = x)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED: int = 1926
49
50
# ---------------- Utilities ------------------- #
def set_seed(seed: int) -> None:
    """Seed the global RNGs of Python, NumPy and PyTorch with ``seed``.

    When a GPU is available, the CUDA generator is seeded as well.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
58
59
def init_weights_small_normal(module: nn.Module) -> None:
    """Re-initialise an ``nn.Linear``: weights ~ N(0, 0.01^2), bias = 0.

    Intended for ``nn.Module.apply``. Modules of any other type are ignored.
    """
    if not isinstance(module, nn.Linear):
        return
    nn.init.normal_(module.weight, mean=0.0, std=0.01)
    bias = module.bias
    if bias is not None:
        nn.init.zeros_(bias)

Key parameters:

  • LATENT_DIMS – list of latent sizes \(N_z\) evaluated.

  • HIDDEN_SIZE – hidden layer size in encoder and decoder.

  • MC_SAMPLES – number of Monte Carlo samples (\(L=1\) in the paper).

  • set_seed() – deterministic seeding for reproducibility.

Data Loading

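The MNIST dataset is loaded through torchvision.datasets. Optionally (controlled by USE_STOCHASTIC_BINARIZATION, which is off by default), inputs may be stochastically binarized, as in the original AEVB experiments. Each pixel intensity \(p \in [0, 1]\) is replaced by a fresh sample \(x \sim \mathrm{Bernoulli}(p)\) every time an image is loaded.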

 1def make_dataloaders(
 2    batch_size: int,
 3    use_stochastic_binarization: bool = USE_STOCHASTIC_BINARIZATION,
 4) -> Tuple[DataLoader, DataLoader]:
 5
 6    tfms: List[transforms.Compose | transforms.ToTensor] = [transforms.ToTensor()]
 7
 8    if use_stochastic_binarization:
 9        class BernoulliBinarize:
10            def __call__(self, x: torch.Tensor) -> torch.Tensor:
11                # Sample x ~ Bernoulli(p = x); assumes x in [0, 1].
12                return torch.bernoulli(x)
13
14        tfms.append(BernoulliBinarize()) # type: ignore
15
16    transform = transforms.Compose(tfms)
17
18    train_dataset = datasets.MNIST("./data", train=True, download=True, transform=transform)
19    test_dataset  = datasets.MNIST("./data", train=False, download=True, transform=transform)
20
21    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
22    test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False, drop_last=False)
23    return train_loader, test_loader

The make_dataloaders() function returns training and test dataloaders with appropriate preprocessing and optional binarization.

Model Definition

Each VAE consists of:

  • a single-hidden-layer encoder mapping \(x \mapsto (\mu, \log \sigma^2)\),

  • a single-hidden-layer decoder mapping \(z \mapsto \hat{x}\),

  • Tanh nonlinearities,

  • linear output logits interpreted by VAE.compute_loss() under a Bernoulli likelihood.

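Linear-layer weights are initialized from a small zero-mean Gaussian (standard deviation 0.01), and biases are set to zero. This follows the small random initialization described in the paper.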

 1def make_vae(latent_dim: int, hidden: int = HIDDEN_SIZE) -> VAE:
 2    encoder = nn.Sequential(
 3        nn.Flatten(),
 4        nn.Linear(28 * 28, hidden),
 5        nn.Tanh(),
 6    )
 7
 8    decoder = nn.Sequential(
 9        nn.Linear(latent_dim, hidden),
10        nn.Tanh(),
11        nn.Linear(hidden, 28 * 28),
12        nn.Unflatten(-1, (1, 28, 28)),  # keep last layer linear; VAELoss(bernoulli) expects logits
13    )
14
15    model = VAE(encoder=encoder, decoder=decoder, latent_dim=latent_dim)
16    model.build(input_sample=torch.randn(1, 1, 28, 28))
17    model.apply(init_weights_small_normal)
18    return model

The model is built explicitly via VAE.build(), which infers dimension-dependent components from a representative input sample.

ELBO Evaluation

We evaluate the ELBO over a dataloader as a per-sample average. Each batch's ELBO diagnostic returned by VAE.compute_loss() is weighted by the batch size, and the sum is divided by the total number of samples.

 1@torch.no_grad()
 2def average_elbo(dataloader: DataLoader, model: VAE) -> float:
 3    model.eval()
 4    total_elbo = 0.0
 5    n = 0
 6
 7    for x, _ in dataloader:
 8        x = x.to(DEVICE)
 9        out = model(x, S=MC_SAMPLES)
10        loss_info = model.compute_loss(x, out, beta=1, likelihood='bernoulli')
11        elbo_batch = loss_info.diagnostics['elbo']
12        batch_size = x.size(0)
13        total_elbo += elbo_batch * batch_size
14        n += batch_size
15
16    return total_elbo / n

This routine is used during training to record both training and test ELBO values.

Training Loop

For each latent dimension \(N_z\), the VAE is trained until a fixed number of training samples have been processed (TARGET_TRAIN_SAMPLES). Periodic evaluations of the ELBO on both train and test sets are logged.

 1def train_one_setting(
 2    latent_dim: int,
 3    train_loader: DataLoader,
 4    test_loader: DataLoader,
 5) -> Tuple[VAE, List[Dict[str, float]]]:
 6    model = make_vae(latent_dim).to(DEVICE)
 7    optimizer = torch.optim.Adagrad(
 8        model.parameters(),
 9        lr=LEARNING_RATE,
10    )
11
12    logs: List[Dict[str, float]] = []
13    samples_seen = 0
14    next_eval = EVAL_EVERY_SAMPLES
15
16    start_time = time.time()
17    while samples_seen < TARGET_TRAIN_SAMPLES:
18        for x, _ in train_loader:
19            x = x.to(DEVICE)
20            model.train()
21            optimizer.zero_grad()
22
23            out = model(x, S=MC_SAMPLES)
24            loss_info = model.compute_loss(x, out, beta=1, likelihood='bernoulli')
25            loss_info.objective.backward()
26            optimizer.step()
27
28            batch_size = x.size(0)
29            samples_seen += batch_size
30
31            if samples_seen >= next_eval:
32                train_elbo = average_elbo(train_loader, model)
33                test_elbo = average_elbo(test_loader, model)
34                logs.append(
35                    {
36                        "samples": float(samples_seen),
37                        "train_elbo": float(train_elbo),
38                        "test_elbo": float(test_elbo),
39                    }
40                )
41                elapsed = time.time() - start_time
42                print(
43                    f"N_z={latent_dim:3d} | "
44                    f"samples={samples_seen:>9d} | "
45                    f"ELBO_train={train_elbo:.2f}, "
46                    f"ELBO_test={test_elbo:.2f} | "
47                    f"(elapsed {elapsed:.1f}s)"
48                )
49                next_eval += EVAL_EVERY_SAMPLES
50
51            if samples_seen >= TARGET_TRAIN_SAMPLES:
52                break
53
54    return model, logs

Each training step:

  • Draws \(S\) = MC_SAMPLES (the paper's \(L\)) Monte Carlo samples from \(q(z \mid x)\) for each input,

  • Computes the negative ELBO via VAE.compute_loss() with likelihood='bernoulli',

  • Backpropagates gradients through encoder and decoder,

  • Records evaluation metrics periodically.

The ELBO diagnostics contain:

  • elbo – Evidence Lower Bound value,

  • log_likelihood – batch-mean reconstruction term,

  • kl_divergence – batch-mean KL divergence.

Plotting the Results

After all training runs complete, the ELBO curves are plotted as a function of the number of processed samples (log-scaled), following the appearance of Fig. 2 from Kingma & Welling (2013).

 1def plot_elbo_curves(all_logs: Dict[int, List[Dict[str, float]]]) -> Path:
 2    num_settings = len(all_logs)
 3    fig, axes = plt.subplots(1, num_settings, figsize=(3.0 * num_settings, 3.0))
 4
 5    # Ensure axes is iterable even when num_settings == 1
 6    if num_settings == 1:
 7        axes = np.array([axes])
 8
 9    latent_dims_sorted = sorted(all_logs.keys())
10
11    for i, nz in enumerate(latent_dims_sorted):
12        ax = axes[i]
13        xs = [entry["samples"] for entry in all_logs[nz]]
14        ys_train = [entry["train_elbo"] for entry in all_logs[nz]]
15        ys_test = [entry["test_elbo"] for entry in all_logs[nz]]
16
17        ax.plot(xs, ys_train, label="AEVB (train)", color="r")
18        ax.plot(xs, ys_test, linestyle="--", label="AEVB (test)", color="r")
19
20        ax.set_xscale("log")
21        ax.set_ylim(-150, -95)
22        ax.set_title(f"MNIST, $N_z = {nz}$")
23
24        if i == 0:
25            ax.set_xlabel("# training samples evaluated")
26            ax.set_ylabel(r"$\mathcal{L}$ (ELBO)")
27
28        # Log-scale formatting on x-axis
29        ax.xaxis.set_major_locator(LogLocator(base=10))
30        ax.xaxis.set_major_formatter(LogFormatterMathtext(base=10))
31        ax.xaxis.set_minor_locator(LogLocator(base=10, subs=range(2, 10)))
32        ax.xaxis.set_minor_formatter(NullFormatter())
33
34    axes[0].legend(loc="lower right")
35    plt.tight_layout()
36
37    out_path = Path("vae_mnist_fig2_repro.png")
38    plt.savefig(out_path, dpi=200, bbox_inches="tight")
39    plt.close(fig)
40    print(f"Saved figure -> {out_path}")
41    return out_path

The figure is saved as vae_mnist_fig2_repro.png.

Entry Point

The main() function orchestrates the full experiment: seeding, data loading, training VAEs for each \(N_z\), logging results, and plotting the final curves.

 1def main() -> None:
 2    print("=== Reproducing Kingma & Welling (2013) Fig. 2 on MNIST ===")
 3    print(f"Using device: {DEVICE}")
 4
 5    set_seed(SEED)
 6    train_loader, test_loader = make_dataloaders(BATCH_SIZE)
 7
 8    all_logs: Dict[int, List[Dict[str, float]]] = {}
 9    for nz in LATENT_DIMS:
10        print(f"\n--- Training VAE with N_z = {nz} ---")
11        _, logs = train_one_setting(nz, train_loader, test_loader)
12        all_logs[nz] = logs
13
14    plot_elbo_curves(all_logs)
15    print("Done.")

The script can be executed directly:

if __name__ == "__main__":
    # Run the experiment only when the file is executed as a script, not when
    # it is imported (e.g. to reuse make_vae or train_one_setting).
    main()

Full Example Script

For convenience, here is the complete script in one block:

mnist_vae_kingma2013.py
  1"""
  2Reproduce the MNIST VAE experiment from Kingma & Welling (2013), Fig. 2.
  3
  4We train VAEs with different latent dimensionalities N_z and track the ELBO
  5as a function of the number of training samples seen. 
  6The setup follows the original paper:
  7
  8- One hidden layer MLPs with Tanh activations in encoder/decoder
  9- Hidden size H = 500
 10- Mini-batch size M = 100
 11- Learning rate selected from {0.01, 0.02, 0.1} (we use 0.02 here)
 12- L = 1 Monte Carlo sample for the stochastic latent variable
 13"""
 14
 15from __future__ import annotations
 16
 17import random
 18import time
 19from pathlib import Path
 20from typing import Dict, List, Tuple
 21
 22import matplotlib.pyplot as plt
 23import numpy as np
 24import torch
 25import torch.nn as nn
 26from matplotlib.ticker import (
 27    LogFormatterMathtext,
 28    LogLocator,
 29    NullFormatter,
 30)
 31from torch.utils.data import DataLoader
 32from torchvision import datasets, transforms
 33
 34from pyautoencoder.variational import VAE
 35
 36
 37# ---------------- Configuration ---------------- #
 38LATENT_DIMS: List[int] = [3, 5, 10, 20, 200]  # N_z values (paper)
 39HIDDEN_SIZE: int = 500                        # hidden layer size (encoder/decoder, MNIST)
 40BATCH_SIZE: int = 100                         # M = 100 (paper)
 41LEARNING_RATE: float = 0.02                   # chosen from {0.01, 0.02, 0.1} (paper)
 42MC_SAMPLES: int = 1                           # L = 1 (paper)
 43TARGET_TRAIN_SAMPLES: int = int(1e7)          # stop after this many training samples
 44EVAL_EVERY_SAMPLES: int = int(1e5)            # evaluate and log every this many samples
 45USE_STOCHASTIC_BINARIZATION: bool = False     # binarize x ~ Bernoulli(p = p_x) if True
 46
 47DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 48SEED: int = 1926
 49
 50
 51# ---------------- Utilities ------------------- #
 52def set_seed(seed: int) -> None:
 53    random.seed(seed)
 54    np.random.seed(seed)
 55    torch.manual_seed(seed)
 56    if torch.cuda.is_available():
 57        torch.cuda.manual_seed(seed)
 58
 59
 60def init_weights_small_normal(module: nn.Module) -> None:
 61    if isinstance(module, nn.Linear):
 62        nn.init.normal_(module.weight, mean=0.0, std=0.01)
 63        if module.bias is not None:
 64            nn.init.zeros_(module.bias)
 65
 66
 67# ---------------- Data ----------------------- #
 68def make_dataloaders(
 69    batch_size: int,
 70    use_stochastic_binarization: bool = USE_STOCHASTIC_BINARIZATION,
 71) -> Tuple[DataLoader, DataLoader]:
 72
 73    tfms: List[transforms.Compose | transforms.ToTensor] = [transforms.ToTensor()]
 74
 75    if use_stochastic_binarization:
 76        class BernoulliBinarize:
 77            def __call__(self, x: torch.Tensor) -> torch.Tensor:
 78                # Sample x ~ Bernoulli(p = x); assumes x in [0, 1].
 79                return torch.bernoulli(x)
 80
 81        tfms.append(BernoulliBinarize()) # type: ignore
 82
 83    transform = transforms.Compose(tfms)
 84
 85    train_dataset = datasets.MNIST("./data", train=True, download=True, transform=transform)
 86    test_dataset  = datasets.MNIST("./data", train=False, download=True, transform=transform)
 87
 88    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
 89    test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False, drop_last=False)
 90    return train_loader, test_loader
 91
 92
 93# ---------------- Model ---------------------- #
 94def make_vae(latent_dim: int, hidden: int = HIDDEN_SIZE) -> VAE:
 95    encoder = nn.Sequential(
 96        nn.Flatten(),
 97        nn.Linear(28 * 28, hidden),
 98        nn.Tanh(),
 99    )
100
101    decoder = nn.Sequential(
102        nn.Linear(latent_dim, hidden),
103        nn.Tanh(),
104        nn.Linear(hidden, 28 * 28),
105        nn.Unflatten(-1, (1, 28, 28)),  # keep last layer linear; VAELoss(bernoulli) expects logits
106    )
107
108    model = VAE(encoder=encoder, decoder=decoder, latent_dim=latent_dim)
109    model.build(input_sample=torch.randn(1, 1, 28, 28))
110    model.apply(init_weights_small_normal)
111    return model
112
# ---------------- Evaluation ----------------- #
@torch.no_grad()
def average_elbo(dataloader: DataLoader, model: VAE) -> float:
    """Return the per-sample average ELBO of ``model`` over ``dataloader``.

    Each batch's ``diagnostics['elbo']`` (assumed to be a batch mean) is
    weighted by the batch size, so a smaller final batch is counted correctly.
    The ELBO is a Monte Carlo estimate with ``MC_SAMPLES`` latent draws per
    input. Gradients are disabled, and the model runs in eval mode. The
    model's previous train/eval mode is restored afterwards.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    total_elbo = 0.0
    n = 0

    try:
        for x, _ in dataloader:
            x = x.to(DEVICE)
            out = model(x, S=MC_SAMPLES)
            loss_info = model.compute_loss(x, out, beta=1, likelihood='bernoulli')
            batch_size = x.size(0)
            # float() accepts a Python number or a one-element tensor. It keeps
            # the accumulator a float64 Python float instead of a (possibly
            # float32 / GPU) tensor, so the declared ``-> float`` holds.
            total_elbo += float(loss_info.diagnostics['elbo']) * batch_size
            n += batch_size
    finally:
        model.train(was_training)

    if n == 0:
        raise ValueError("dataloader yielded no samples; cannot average the ELBO")
    return total_elbo / n
130
131
# ---------------- Training ------------------- #
def train_one_setting(
    latent_dim: int,
    train_loader: DataLoader,
    test_loader: DataLoader,
    *,
    weight_decay: float | None = None,
) -> Tuple[VAE, List[Dict[str, float]]]:
    """Train one VAE with latent size ``latent_dim`` and record ELBO curves.

    Adagrad (``lr=LEARNING_RATE``) runs until ``TARGET_TRAIN_SAMPLES`` training
    samples have been processed. Every ``EVAL_EVERY_SAMPLES`` samples, the
    average ELBO is computed on the full train and test loaders and logged.

    Args:
        latent_dim: Latent dimensionality N_z.
        train_loader: Loader of ``(image, label)`` training batches.
        test_loader: Loader of ``(image, label)`` test batches.
        weight_decay: L2 coefficient passed to Adagrad. With ``None`` (the
            default), it is ``1 / len(train_loader.dataset)``, the paper's
            N(0, I) weight prior spread over the training set. Pass ``0.0``
            to disable.

    Returns:
        The trained model and a list of records with keys ``"samples"``,
        ``"train_elbo"`` and ``"test_elbo"``.

    Raises:
        ValueError: If ``EVAL_EVERY_SAMPLES`` is not positive, or if
            ``train_loader`` yields no batches (training would never end).
    """
    if EVAL_EVERY_SAMPLES <= 0:
        raise ValueError("EVAL_EVERY_SAMPLES must be positive")

    if weight_decay is None:
        # A N(0, I) prior adds -||theta||^2 / 2 to the full-data objective.
        # Per datapoint, that corresponds to an L2 coefficient of 1 / N.
        # Assumes compute_loss's objective is a batch-mean negative ELBO
        # — TODO confirm against pyautoencoder.
        weight_decay = 1.0 / len(train_loader.dataset)  # type: ignore[arg-type]

    model = make_vae(latent_dim).to(DEVICE)
    optimizer = torch.optim.Adagrad(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=weight_decay,
    )

    logs: List[Dict[str, float]] = []
    samples_seen = 0
    next_eval = EVAL_EVERY_SAMPLES

    start_time = time.time()
    model.train()
    while samples_seen < TARGET_TRAIN_SAMPLES:
        batches_this_epoch = 0
        for x, _ in train_loader:
            batches_this_epoch += 1
            x = x.to(DEVICE)
            optimizer.zero_grad()

            out = model(x, S=MC_SAMPLES)
            loss_info = model.compute_loss(x, out, beta=1, likelihood='bernoulli')
            loss_info.objective.backward()
            optimizer.step()

            samples_seen += x.size(0)

            if samples_seen >= next_eval:
                train_elbo = average_elbo(train_loader, model)
                test_elbo = average_elbo(test_loader, model)
                # average_elbo switches the model to eval mode; resume training.
                model.train()
                logs.append(
                    {
                        "samples": float(samples_seen),
                        "train_elbo": float(train_elbo),
                        "test_elbo": float(test_elbo),
                    }
                )
                elapsed = time.time() - start_time
                print(
                    f"N_z={latent_dim:3d} | "
                    f"samples={samples_seen:>9d} | "
                    f"ELBO_train={train_elbo:.2f}, "
                    f"ELBO_test={test_elbo:.2f} | "
                    f"(elapsed {elapsed:.1f}s)"
                )
                # Jump past samples_seen. Otherwise a batch larger than the
                # interval would leave next_eval behind forever and trigger an
                # evaluation after every step.
                while next_eval <= samples_seen:
                    next_eval += EVAL_EVERY_SAMPLES

            if samples_seen >= TARGET_TRAIN_SAMPLES:
                break

        if batches_this_epoch == 0:
            raise ValueError("train_loader yielded no batches; training cannot progress")

    return model, logs
187
188
189# ---------------- Plotting ------------------- #
190def plot_elbo_curves(all_logs: Dict[int, List[Dict[str, float]]]) -> Path:
191    num_settings = len(all_logs)
192    fig, axes = plt.subplots(1, num_settings, figsize=(3.0 * num_settings, 3.0))
193
194    # Ensure axes is iterable even when num_settings == 1
195    if num_settings == 1:
196        axes = np.array([axes])
197
198    latent_dims_sorted = sorted(all_logs.keys())
199
200    for i, nz in enumerate(latent_dims_sorted):
201        ax = axes[i]
202        xs = [entry["samples"] for entry in all_logs[nz]]
203        ys_train = [entry["train_elbo"] for entry in all_logs[nz]]
204        ys_test = [entry["test_elbo"] for entry in all_logs[nz]]
205
206        ax.plot(xs, ys_train, label="AEVB (train)", color="r")
207        ax.plot(xs, ys_test, linestyle="--", label="AEVB (test)", color="r")
208
209        ax.set_xscale("log")
210        ax.set_ylim(-150, -95)
211        ax.set_title(f"MNIST, $N_z = {nz}$")
212
213        if i == 0:
214            ax.set_xlabel("# training samples evaluated")
215            ax.set_ylabel(r"$\mathcal{L}$ (ELBO)")
216
217        # Log-scale formatting on x-axis
218        ax.xaxis.set_major_locator(LogLocator(base=10))
219        ax.xaxis.set_major_formatter(LogFormatterMathtext(base=10))
220        ax.xaxis.set_minor_locator(LogLocator(base=10, subs=range(2, 10)))
221        ax.xaxis.set_minor_formatter(NullFormatter())
222
223    axes[0].legend(loc="lower right")
224    plt.tight_layout()
225
226    out_path = Path("vae_mnist_fig2_repro.png")
227    plt.savefig(out_path, dpi=200, bbox_inches="tight")
228    plt.close(fig)
229    print(f"Saved figure -> {out_path}")
230    return out_path
231
232
# ---------------- Main ---------------------- #
def main() -> None:
    """Run the whole Fig. 2 reproduction.

    The steps are: seed the RNGs, load MNIST, train one VAE per N_z in
    LATENT_DIMS, then plot the ELBO curves.
    """
    print("=== Reproducing Kingma & Welling (2013) Fig. 2 on MNIST ===")
    print(f"Using device: {DEVICE}")

    set_seed(SEED)
    loaders = make_dataloaders(BATCH_SIZE)

    all_logs: Dict[int, List[Dict[str, float]]] = {}
    for latent_dim in LATENT_DIMS:
        print(f"\n--- Training VAE with N_z = {latent_dim} ---")
        _model, history = train_one_setting(latent_dim, *loaders)
        all_logs[latent_dim] = history

    plot_elbo_curves(all_logs)
    print("Done.")


if __name__ == "__main__":
    # Run the experiment only when executed as a script, not on import.
    main()